]> git.ipfire.org Git - location/libloc.git/blobdiff - src/python/database.c
Implement listing networks in Python
[location/libloc.git] / src / python / database.c
index 3c9e02420dca8f2d25f3bdad593f44124686ae72..954b5aceae975760e596c56c0ed945a7686119fa 100644 (file)
@@ -50,8 +50,10 @@ static int Database_init(DatabaseObject* self, PyObject* args, PyObject* kwargs)
 
        // Open the file for reading
        FILE* f = fopen(self->path, "r");
-       if (!f)
+       if (!f) {
+               PyErr_SetFromErrno(PyExc_IOError);
                return -1;
+       }
 
        // Load the database
        int r = loc_database_new(loc_ctx, &self->db, f);
@@ -166,7 +168,7 @@ static PyObject* Database_search_as(DatabaseObject* self, PyObject* args) {
 
        struct loc_database_enumerator* enumerator;
 
-       int r = loc_database_enumerator_new(&enumerator, self->db);
+       int r = loc_database_enumerator_new(&enumerator, self->db, LOC_DB_ENUMERATE_ASES);
        if (r) {
                PyErr_SetFromErrno(PyExc_SystemError);
                return NULL;
@@ -181,6 +183,47 @@ static PyObject* Database_search_as(DatabaseObject* self, PyObject* args) {
        return obj;
 }
 
+static PyObject* Database_search_networks(DatabaseObject* self, PyObject* args, PyObject* kwargs) {
+       char* kwlist[] = { "country_code", "asn", NULL };
+       const char* country_code = NULL;
+       unsigned int asn = 0;
+
+    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|si", kwlist, &country_code, &asn))
+               return NULL;
+
+       struct loc_database_enumerator* enumerator;
+       int r = loc_database_enumerator_new(&enumerator, self->db, LOC_DB_ENUMERATE_NETWORKS);
+       if (r) {
+               PyErr_SetFromErrno(PyExc_SystemError);
+               return NULL;
+       }
+
+       // Set country code we are searching for
+       if (country_code) {
+               r = loc_database_enumerator_set_country_code(enumerator, country_code);
+
+               if (r) {
+                       PyErr_SetFromErrno(PyExc_SystemError);
+                       return NULL;
+               }
+       }
+
+       // Set the ASN we are searching for
+       if (asn) {
+               r = loc_database_enumerator_set_asn(enumerator, asn);
+
+               if (r) {
+                       PyErr_SetFromErrno(PyExc_SystemError);
+                       return NULL;
+               }
+       }
+
+       PyObject* obj = new_database_enumerator(&DatabaseEnumeratorType, enumerator);
+       loc_database_enumerator_unref(enumerator);
+
+       return obj;
+}
+
 static struct PyMethodDef Database_methods[] = {
        {
                "get_as",
@@ -200,6 +243,12 @@ static struct PyMethodDef Database_methods[] = {
                METH_VARARGS,
                NULL,
        },
+       {
+               "search_networks",
+               (PyCFunction)Database_search_networks,
+               METH_VARARGS|METH_KEYWORDS,
+               NULL,
+       },
        { NULL },
 };
 
@@ -237,16 +286,16 @@ static struct PyGetSetDef Database_getsetters[] = {
 
 PyTypeObject DatabaseType = {
        PyVarObject_HEAD_INIT(NULL, 0)
-       tp_name:                "location.Database",
-       tp_basicsize:           sizeof(DatabaseObject),
-       tp_flags:               Py_TPFLAGS_DEFAULT|Py_TPFLAGS_BASETYPE,
-       tp_new:                 Database_new,
-       tp_dealloc:             (destructor)Database_dealloc,
-       tp_init:                (initproc)Database_init,
-       tp_doc:                 "Database object",
-       tp_methods:             Database_methods,
-       tp_getset:              Database_getsetters,
-       tp_repr:                (reprfunc)Database_repr,
+       .tp_name =               "location.Database",
+       .tp_basicsize =          sizeof(DatabaseObject),
+       .tp_flags =              Py_TPFLAGS_DEFAULT|Py_TPFLAGS_BASETYPE,
+       .tp_new =                Database_new,
+       .tp_dealloc =            (destructor)Database_dealloc,
+       .tp_init =               (initproc)Database_init,
+       .tp_doc =                "Database object",
+       .tp_methods =            Database_methods,
+       .tp_getset =             Database_getsetters,
+       .tp_repr =               (reprfunc)Database_repr,
 };
 
 static PyObject* DatabaseEnumerator_new(PyTypeObject* type, PyObject* args, PyObject* kwds) {
@@ -262,6 +311,16 @@ static void DatabaseEnumerator_dealloc(DatabaseEnumeratorObject* self) {
 }
 
 static PyObject* DatabaseEnumerator_next(DatabaseEnumeratorObject* self) {
+       // Enumerate all networks
+       struct loc_network* network = loc_database_enumerator_next_network(self->enumerator);
+       if (network) {
+               PyObject* obj = new_network(&NetworkType, network);
+               loc_network_unref(network);
+
+               return obj;
+       }
+
+       // Enumerate all ASes
        struct loc_as* as = loc_database_enumerator_next_as(self->enumerator);
        if (as) {
                PyObject* obj = new_as(&ASType, as);
@@ -277,12 +336,12 @@ static PyObject* DatabaseEnumerator_next(DatabaseEnumeratorObject* self) {
 
 PyTypeObject DatabaseEnumeratorType = {
        PyVarObject_HEAD_INIT(NULL, 0)
-       tp_name:                "location.DatabaseEnumerator",
-       tp_basicsize:           sizeof(DatabaseEnumeratorObject),
-       tp_flags:               Py_TPFLAGS_DEFAULT,
-       tp_alloc:               PyType_GenericAlloc,
-       tp_new:                 DatabaseEnumerator_new,
-       tp_dealloc:             (destructor)DatabaseEnumerator_dealloc,
-       tp_iter:                PyObject_SelfIter,
-       tp_iternext:            (iternextfunc)DatabaseEnumerator_next,
+       .tp_name =               "location.DatabaseEnumerator",
+       .tp_basicsize =          sizeof(DatabaseEnumeratorObject),
+       .tp_flags =              Py_TPFLAGS_DEFAULT,
+       .tp_alloc =              PyType_GenericAlloc,
+       .tp_new =                DatabaseEnumerator_new,
+       .tp_dealloc =            (destructor)DatabaseEnumerator_dealloc,
+       .tp_iter =               PyObject_SelfIter,
+       .tp_iternext =           (iternextfunc)DatabaseEnumerator_next,
 };