]> git.ipfire.org Git - location/libloc.git/blobdiff - src/python/database.c
location-query: Allow filtering networks by family
[location/libloc.git] / src / python / database.c
index ea476c807da591e572fed635f4def3d69b400685..2f0a3b0fb1b3c5cba75e68018d405cb96760f4de 100644 (file)
@@ -71,6 +71,34 @@ static PyObject* Database_repr(DatabaseObject* self) {
        return PyUnicode_FromFormat("<Database %s>", self->path);
 }
 
+static PyObject* Database_verify(DatabaseObject* self, PyObject* args) {
+       PyObject* public_key = NULL;
+       FILE* f = NULL;
+
+       // Parse arguments
+       if (!PyArg_ParseTuple(args, "O", &public_key))
+               return NULL;
+
+       // Convert into FILE*
+       int fd = PyObject_AsFileDescriptor(public_key);
+       if (fd < 0)
+               return NULL;
+
+       // Re-open file descriptor
+       f = fdopen(fd, "r");
+       if (!f) {
+               PyErr_SetFromErrno(PyExc_IOError);
+               return NULL;
+       }
+
+       int r = loc_database_verify(self->db, f);
+
+       if (r == 0)
+               Py_RETURN_TRUE;
+
+       Py_RETURN_FALSE;
+}
+
 static PyObject* Database_get_description(DatabaseObject* self) {
        const char* description = loc_database_get_description(self->db);
 
@@ -203,11 +231,13 @@ static PyObject* Database_search_as(DatabaseObject* self, PyObject* args) {
 }
 
 static PyObject* Database_search_networks(DatabaseObject* self, PyObject* args, PyObject* kwargs) {
-       char* kwlist[] = { "country_code", "asn", NULL };
+       char* kwlist[] = { "country_code", "asn", "flags", "family", NULL };
        const char* country_code = NULL;
        unsigned int asn = 0;
+       int flags = 0;
+       int family = 0;
 
-    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|si", kwlist, &country_code, &asn))
+       if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|siii", kwlist, &country_code, &asn, &flags, &family))
                return NULL;
 
        struct loc_database_enumerator* enumerator;
@@ -237,6 +267,26 @@ static PyObject* Database_search_networks(DatabaseObject* self, PyObject* args,
                }
        }
 
+       // Set the flags we are searching for
+       if (flags) {
+               r = loc_database_enumerator_set_flag(enumerator, flags);
+
+               if (r) {
+                       PyErr_SetFromErrno(PyExc_SystemError);
+                       return NULL;
+               }
+       }
+
+       // Set the family we are searching for
+       if (family) {
+               r = loc_database_enumerator_set_family(enumerator, family);
+
+               if (r) {
+                       PyErr_SetFromErrno(PyExc_SystemError);
+                       return NULL;
+               }
+       }
+
        PyObject* obj = new_database_enumerator(&DatabaseEnumeratorType, enumerator);
        loc_database_enumerator_unref(enumerator);
 
@@ -274,6 +324,12 @@ static struct PyMethodDef Database_methods[] = {
                METH_VARARGS|METH_KEYWORDS,
                NULL,
        },
+       {
+               "verify",
+               (PyCFunction)Database_verify,
+               METH_VARARGS,
+               NULL,
+       },
        { NULL },
 };
 
@@ -341,6 +397,7 @@ static PyObject* DatabaseEnumerator_next(DatabaseEnumeratorObject* self) {
        // Enumerate all networks
        int r = loc_database_enumerator_next_network(self->enumerator, &network);
        if (r) {
+               PyErr_SetFromErrno(PyExc_ValueError);
                return NULL;
        }
 
@@ -357,6 +414,7 @@ static PyObject* DatabaseEnumerator_next(DatabaseEnumeratorObject* self) {
 
        r = loc_database_enumerator_next_as(self->enumerator, &as);
        if (r) {
+               PyErr_SetFromErrno(PyExc_ValueError);
                return NULL;
        }