]> git.ipfire.org Git - location/libloc.git/blobdiff - src/python/database.c
importer: Drop EDROP as it has been merged into DROP
[location/libloc.git] / src / python / database.c
index f6247cfbc87cdbffd06f44b24ad414b72f9c8b49..d6ee4d02d0ed1d35fcf1ccc4244d25763e222cd3 100644 (file)
 
 #include <Python.h>
 
-#include <loc/libloc.h>
-#include <loc/database.h>
+#include <libloc/libloc.h>
+#include <libloc/as.h>
+#include <libloc/as-list.h>
+#include <libloc/database.h>
 
 #include "locationmodule.h"
 #include "as.h"
@@ -43,36 +45,63 @@ static void Database_dealloc(DatabaseObject* self) {
 
 static int Database_init(DatabaseObject* self, PyObject* args, PyObject* kwargs) {
        const char* path = NULL;
+       FILE* f = NULL;
 
+       // Parse arguments
        if (!PyArg_ParseTuple(args, "s", &path))
                return -1;
 
+       // Copy path
        self->path = strdup(path);
+       if (!self->path)
+               goto ERROR;
 
        // Open the file for reading
-       FILE* f = fopen(self->path, "r");
-       if (!f) {
-               PyErr_SetFromErrno(PyExc_IOError);
-               return -1;
-       }
+       f = fopen(self->path, "r");
+       if (!f)
+               goto ERROR;
 
        // Load the database
        int r = loc_database_new(loc_ctx, &self->db, f);
-       fclose(f);
-
-       // Return on any errors
        if (r)
-               return -1;
+               goto ERROR;
 
+       fclose(f);
        return 0;
+
+ERROR:
+       if (f)
+               fclose(f);
+
+       PyErr_SetFromErrno(PyExc_OSError);
+       return -1;
 }
 
 static PyObject* Database_repr(DatabaseObject* self) {
        return PyUnicode_FromFormat("<Database %s>", self->path);
 }
 
-static PyObject* Database_verify(DatabaseObject* self) {
-       int r = loc_database_verify(self->db);
+static PyObject* Database_verify(DatabaseObject* self, PyObject* args) {
+       PyObject* public_key = NULL;
+       FILE* f = NULL;
+
+       // Parse arguments
+       if (!PyArg_ParseTuple(args, "O", &public_key))
+               return NULL;
+
+       // Convert into FILE*
+       int fd = PyObject_AsFileDescriptor(public_key);
+       if (fd < 0)
+               return NULL;
+
+       // Re-open file descriptor
+       f = fdopen(fd, "r");
+       if (!f) {
+               PyErr_SetFromErrno(PyExc_IOError);
+               return NULL;
+       }
+
+       int r = loc_database_verify(self->db, f);
 
        if (r == 0)
                Py_RETURN_TRUE;
@@ -82,18 +111,24 @@ static PyObject* Database_verify(DatabaseObject* self) {
 
 static PyObject* Database_get_description(DatabaseObject* self) {
        const char* description = loc_database_get_description(self->db);
+       if (!description)
+               Py_RETURN_NONE;
 
        return PyUnicode_FromString(description);
 }
 
 static PyObject* Database_get_vendor(DatabaseObject* self) {
        const char* vendor = loc_database_get_vendor(self->db);
+       if (!vendor)
+               Py_RETURN_NONE;
 
        return PyUnicode_FromString(vendor);
 }
 
 static PyObject* Database_get_license(DatabaseObject* self) {
        const char* license = loc_database_get_license(self->db);
+       if (!license)
+               Py_RETURN_NONE;
 
        return PyUnicode_FromString(license);
 }
@@ -131,17 +166,32 @@ static PyObject* Database_get_as(DatabaseObject* self, PyObject* args) {
 }
 
 static PyObject* Database_get_country(DatabaseObject* self, PyObject* args) {
+       struct loc_country* country = NULL;
        const char* country_code = NULL;
 
        if (!PyArg_ParseTuple(args, "s", &country_code))
                return NULL;
 
-       struct loc_country* country;
+       // Fetch the country
        int r = loc_database_get_country(self->db, &country, country_code);
        if (r) {
-               Py_RETURN_NONE;
+               switch (errno) {
+                       case EINVAL:
+                               PyErr_SetString(PyExc_ValueError, "Invalid country code");
+                               break;
+
+                       default:
+                               PyErr_SetFromErrno(PyExc_OSError);
+                               break;
+               }
+
+               return NULL;
        }
 
+       // No result
+       if (!country)
+               Py_RETURN_NONE;
+
        PyObject* obj = new_country(&CountryType, country);
        loc_country_unref(country);
 
@@ -164,18 +214,21 @@ static PyObject* Database_lookup(DatabaseObject* self, PyObject* args) {
                loc_network_unref(network);
 
                return obj;
+       }
 
        // Nothing found
-       } else if (r == 1) {
+       if (!errno)
                Py_RETURN_NONE;
 
-       // Invalid input
-       } else if (r == -EINVAL) {
-               PyErr_Format(PyExc_ValueError, "Invalid IP address: %s", address);
-               return NULL;
+       // Handle any errors
+       switch (errno) {
+               case EINVAL:
+                       PyErr_Format(PyExc_ValueError, "Invalid IP address: %s", address);
+
+               default:
+                       PyErr_SetFromErrno(PyExc_OSError);
        }
 
-       // Unexpected error
        return NULL;
 }
 
@@ -188,6 +241,30 @@ static PyObject* new_database_enumerator(PyTypeObject* type, struct loc_database
        return (PyObject*)self;
 }
 
+static PyObject* Database_iterate_all(DatabaseObject* self,
+               enum loc_database_enumerator_mode what, int family, int flags) {
+       struct loc_database_enumerator* enumerator;
+
+       int r = loc_database_enumerator_new(&enumerator, self->db, what, flags);
+       if (r) {
+               PyErr_SetFromErrno(PyExc_SystemError);
+               return NULL;
+       }
+
+       // Set family
+       if (family)
+               loc_database_enumerator_set_family(enumerator, family);
+
+       PyObject* obj = new_database_enumerator(&DatabaseEnumeratorType, enumerator);
+       loc_database_enumerator_unref(enumerator);
+
+       return obj;
+}
+
+static PyObject* Database_ases(DatabaseObject* self) {
+       return Database_iterate_all(self, LOC_DB_ENUMERATE_ASES, AF_UNSPEC, 0);
+}
+
 static PyObject* Database_search_as(DatabaseObject* self, PyObject* args) {
        const char* string = NULL;
 
@@ -196,7 +273,7 @@ static PyObject* Database_search_as(DatabaseObject* self, PyObject* args) {
 
        struct loc_database_enumerator* enumerator;
 
-       int r = loc_database_enumerator_new(&enumerator, self->db, LOC_DB_ENUMERATE_ASES);
+       int r = loc_database_enumerator_new(&enumerator, self->db, LOC_DB_ENUMERATE_ASES, 0);
        if (r) {
                PyErr_SetFromErrno(PyExc_SystemError);
                return NULL;
@@ -211,40 +288,144 @@ static PyObject* Database_search_as(DatabaseObject* self, PyObject* args) {
        return obj;
 }
 
+static PyObject* Database_networks(DatabaseObject* self) {
+       return Database_iterate_all(self, LOC_DB_ENUMERATE_NETWORKS, AF_UNSPEC, 0);
+}
+
+static PyObject* Database_networks_flattened(DatabaseObject *self) {
+       return Database_iterate_all(self, LOC_DB_ENUMERATE_NETWORKS, AF_UNSPEC,
+               LOC_DB_ENUMERATOR_FLAGS_FLATTEN);
+}
+
 static PyObject* Database_search_networks(DatabaseObject* self, PyObject* args, PyObject* kwargs) {
-       char* kwlist[] = { "country_code", "asn", "flags", NULL };
-       const char* country_code = NULL;
-       unsigned int asn = 0;
+       char* kwlist[] = { "country_codes", "asns", "flags", "family", "flatten", NULL };
+       PyObject* country_codes = NULL;
+       PyObject* asn_list = NULL;
        int flags = 0;
+       int family = 0;
+       int flatten = 0;
 
-    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|sii", kwlist, &country_code, &asn, &flags))
+       if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O!O!iip", kwlist,
+                       &PyList_Type, &country_codes, &PyList_Type, &asn_list, &flags, &family, &flatten))
                return NULL;
 
        struct loc_database_enumerator* enumerator;
-       int r = loc_database_enumerator_new(&enumerator, self->db, LOC_DB_ENUMERATE_NETWORKS);
+       int r = loc_database_enumerator_new(&enumerator, self->db, LOC_DB_ENUMERATE_NETWORKS,
+               (flatten) ? LOC_DB_ENUMERATOR_FLAGS_FLATTEN : 0);
        if (r) {
                PyErr_SetFromErrno(PyExc_SystemError);
                return NULL;
        }
 
        // Set country code we are searching for
-       if (country_code) {
-               r = loc_database_enumerator_set_country_code(enumerator, country_code);
+       if (country_codes) {
+               struct loc_country_list* countries;
+               r = loc_country_list_new(loc_ctx, &countries);
+               if (r) {
+                       PyErr_SetString(PyExc_SystemError, "Could not create country list");
+                       return NULL;
+               }
+
+               for (int i = 0; i < PyList_Size(country_codes); i++) {
+                       PyObject* item = PyList_GetItem(country_codes, i);
+
+                       if (!PyUnicode_Check(item)) {
+                               PyErr_SetString(PyExc_TypeError, "Country codes must be strings");
+                               loc_country_list_unref(countries);
+                               return NULL;
+                       }
+
+                       const char* country_code = PyUnicode_AsUTF8(item);
+
+                       struct loc_country* country;
+                       r = loc_country_new(loc_ctx, &country, country_code);
+                       if (r) {
+                               if (r == -EINVAL) {
+                                       PyErr_Format(PyExc_ValueError, "Invalid country code: %s", country_code);
+                               } else {
+                                       PyErr_SetString(PyExc_SystemError, "Could not create country");
+                               }
+
+                               loc_country_list_unref(countries);
+                               return NULL;
+                       }
+
+                       // Append it to the list
+                       r = loc_country_list_append(countries, country);
+                       if (r) {
+                               PyErr_SetString(PyExc_SystemError, "Could not append country to the list");
+
+                               loc_country_list_unref(countries);
+                               loc_country_unref(country);
+                               return NULL;
+                       }
+
+                       loc_country_unref(country);
+               }
 
+               r = loc_database_enumerator_set_countries(enumerator, countries);
                if (r) {
                        PyErr_SetFromErrno(PyExc_SystemError);
+
+                       loc_country_list_unref(countries);
                        return NULL;
                }
+
+               loc_country_list_unref(countries);
        }
 
        // Set the ASN we are searching for
-       if (asn) {
-               r = loc_database_enumerator_set_asn(enumerator, asn);
+       if (asn_list) {
+               struct loc_as_list* asns;
+               r = loc_as_list_new(loc_ctx, &asns);
+               if (r) {
+                       PyErr_SetFromErrno(PyExc_OSError);
+                       return NULL;
+               }
+
+               for (int i = 0; i < PyList_Size(asn_list); i++) {
+                       PyObject* item = PyList_GetItem(asn_list, i);
+
+                       if (!PyLong_Check(item)) {
+                               PyErr_SetString(PyExc_TypeError, "ASNs must be numbers");
+
+                               loc_as_list_unref(asns);
+                               return NULL;
+                       }
+
+                       unsigned long number = PyLong_AsLong(item);
+
+                       struct loc_as* as;
+                       r = loc_as_new(loc_ctx, &as, number);
+                       if (r) {
+                               PyErr_SetFromErrno(PyExc_OSError);
+
+                               loc_as_list_unref(asns);
+                               loc_as_unref(as);
+                               return NULL;
+                       }
+
+                       r = loc_as_list_append(asns, as);
+                       if (r) {
+                               PyErr_SetFromErrno(PyExc_OSError);
 
+                               loc_as_list_unref(asns);
+                               loc_as_unref(as);
+                               return NULL;
+                       }
+
+                       loc_as_unref(as);
+               }
+
+               r = loc_database_enumerator_set_asns(enumerator, asns);
                if (r) {
-                       PyErr_SetFromErrno(PyExc_SystemError);
+                       PyErr_SetFromErrno(PyExc_OSError);
+
+                       loc_as_list_unref(asns);
                        return NULL;
                }
+
+               loc_as_list_unref(asns);
        }
 
        // Set the flags we are searching for
@@ -252,7 +433,17 @@ static PyObject* Database_search_networks(DatabaseObject* self, PyObject* args,
                r = loc_database_enumerator_set_flag(enumerator, flags);
 
                if (r) {
-                       PyErr_SetFromErrno(PyExc_SystemError);
+                       PyErr_SetFromErrno(PyExc_OSError);
+                       return NULL;
+               }
+       }
+
+       // Set the family we are searching for
+       if (family) {
+               r = loc_database_enumerator_set_family(enumerator, family);
+
+               if (r) {
+                       PyErr_SetFromErrno(PyExc_OSError);
                        return NULL;
                }
        }
@@ -263,6 +454,21 @@ static PyObject* Database_search_networks(DatabaseObject* self, PyObject* args,
        return obj;
 }
 
+static PyObject* Database_countries(DatabaseObject* self) {
+       return Database_iterate_all(self, LOC_DB_ENUMERATE_COUNTRIES, AF_UNSPEC, 0);
+}
+
+static PyObject* Database_list_bogons(DatabaseObject* self, PyObject* args, PyObject* kwargs) {
+       char* kwlist[] = { "family", NULL };
+       int family = AF_UNSPEC;
+
+       // Parse arguments
+       if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|i", kwlist, &family))
+               return NULL;
+
+       return Database_iterate_all(self, LOC_DB_ENUMERATE_BOGONS, family, 0);
+}
+
 static struct PyMethodDef Database_methods[] = {
        {
                "get_as",
@@ -276,6 +482,12 @@ static struct PyMethodDef Database_methods[] = {
                METH_VARARGS,
                NULL,
        },
+       {
+               "list_bogons",
+               (PyCFunction)Database_list_bogons,
+               METH_VARARGS|METH_KEYWORDS,
+               NULL,
+       },
        {
                "lookup",
                (PyCFunction)Database_lookup,
@@ -297,13 +509,27 @@ static struct PyMethodDef Database_methods[] = {
        {
                "verify",
                (PyCFunction)Database_verify,
-               METH_NOARGS,
+               METH_VARARGS,
                NULL,
        },
        { NULL },
 };
 
 static struct PyGetSetDef Database_getsetters[] = {
+       {
+               "ases",
+               (getter)Database_ases,
+               NULL,
+               NULL,
+               NULL,
+       },
+       {
+               "countries",
+               (getter)Database_countries,
+               NULL,
+               NULL,
+               NULL,
+       },
        {
                "created_at",
                (getter)Database_get_created_at,
@@ -325,6 +551,20 @@ static struct PyGetSetDef Database_getsetters[] = {
                NULL,
                NULL,
        },
+       {
+               "networks",
+               (getter)Database_networks,
+               NULL,
+               NULL,
+               NULL,
+       },
+       {
+               "networks_flattened",
+               (getter)Database_networks_flattened,
+               NULL,
+               NULL,
+               NULL,
+       },
        {
                "vendor",
                (getter)Database_get_vendor,
@@ -395,6 +635,22 @@ static PyObject* DatabaseEnumerator_next(DatabaseEnumeratorObject* self) {
                return obj;
        }
 
+       // Enumerate all countries
+       struct loc_country* country = NULL;
+
+       r = loc_database_enumerator_next_country(self->enumerator, &country);
+       if (r) {
+               PyErr_SetFromErrno(PyExc_ValueError);
+               return NULL;
+       }
+
+       if (country) {
+               PyObject* obj = new_country(&CountryType, country);
+               loc_country_unref(country);
+
+               return obj;
+       }
+
        // Nothing found, that means the end
        PyErr_SetNone(PyExc_StopIteration);
        return NULL;