]> git.ipfire.org Git - people/ms/libloc.git/blobdiff - src/python/database.c
python: Do not use any GNU-style initialisers for structs
[people/ms/libloc.git] / src / python / database.c
index 53b454c4f6dde3dd9f1fd28055fbdf64e70793d7..01be0894a1a4fdcf38dfdf5320a6ea291fbcc35f 100644 (file)
 #include "locationmodule.h"
 #include "as.h"
 #include "database.h"
+#include "network.h"
 
 static PyObject* Database_new(PyTypeObject* type, PyObject* args, PyObject* kwds) {
        DatabaseObject* self = (DatabaseObject*)type->tp_alloc(type, 0);
-       if (self) {
-               self->ctx = loc_ref(loc_ctx);
-       }
 
        return (PyObject*)self;
 }
@@ -36,8 +34,8 @@ static void Database_dealloc(DatabaseObject* self) {
        if (self->db)
                loc_database_unref(self->db);
 
-       if (self->ctx)
-               loc_unref(self->ctx);
+       if (self->path)
+               free(self->path);
 
        Py_TYPE(self)->tp_free((PyObject* )self);
 }
@@ -48,13 +46,17 @@ static int Database_init(DatabaseObject* self, PyObject* args, PyObject* kwargs)
        if (!PyArg_ParseTuple(args, "s", &path))
                return -1;
 
+       self->path = strdup(path);
+
        // Open the file for reading
-       FILE* f = fopen(path, "r");
-       if (!f)
+       FILE* f = fopen(self->path, "r");
+       if (!f) {
+               PyErr_SetFromErrno(PyExc_IOError);
                return -1;
+       }
 
        // Load the database
-       int r = loc_database_new(self->ctx, &self->db, f);
+       int r = loc_database_new(loc_ctx, &self->db, f);
        fclose(f);
 
        // Return on any errors
@@ -64,6 +66,10 @@ static int Database_init(DatabaseObject* self, PyObject* args, PyObject* kwargs)
        return 0;
 }
 
+static PyObject* Database_repr(DatabaseObject* self) {
+       return PyUnicode_FromFormat("<Database %s>", self->path);
+}
+
 static PyObject* Database_get_description(DatabaseObject* self) {
        const char* description = loc_database_get_description(self->db);
 
@@ -76,6 +82,12 @@ static PyObject* Database_get_vendor(DatabaseObject* self) {
        return PyUnicode_FromString(vendor);
 }
 
+static PyObject* Database_get_license(DatabaseObject* self) {
+       const char* license = loc_database_get_license(self->db);
+
+       return PyUnicode_FromString(license);
+}
+
 static PyObject* Database_get_created_at(DatabaseObject* self) {
        time_t created_at = loc_database_created_at(self->db);
 
@@ -91,19 +103,84 @@ static PyObject* Database_get_as(DatabaseObject* self, PyObject* args) {
 
        // Try to retrieve the AS
        int r = loc_database_get_as(self->db, &as, number);
-       if (r)
-               return NULL;
 
-       // Create an AS object
-       if (as) {
+       // We got an AS
+       if (r == 0) {
                PyObject* obj = new_as(&ASType, as);
                loc_as_unref(as);
 
                return obj;
+
+       // Nothing found
+       } else if (r == 1) {
+               Py_RETURN_NONE;
        }
 
+       // Unexpected error
+       return NULL;
+}
+
+static PyObject* Database_lookup(DatabaseObject* self, PyObject* args) {
+       struct loc_network* network = NULL;
+       const char* address = NULL;
+
+       if (!PyArg_ParseTuple(args, "s", &address))
+               return NULL;
+
+       // Try to retrieve a matching network
+       int r = loc_database_lookup_from_string(self->db, address, &network);
+
+       // We got a network
+       if (r == 0) {
+               PyObject* obj = new_network(&NetworkType, network);
+               loc_network_unref(network);
+
+               return obj;
+
        // Nothing found
-       Py_RETURN_NONE;
+       } else if (r == 1) {
+               Py_RETURN_NONE;
+
+       // Invalid input
+       } else if (r == -EINVAL) {
+               PyErr_Format(PyExc_ValueError, "Invalid IP address: %s", address);
+               return NULL;
+       }
+
+       // Unexpected error
+       return NULL;
+}
+
+static PyObject* new_database_enumerator(PyTypeObject* type, struct loc_database_enumerator* enumerator) {
+       DatabaseEnumeratorObject* self = (DatabaseEnumeratorObject*)type->tp_alloc(type, 0);
+       if (self) {
+               self->enumerator = loc_database_enumerator_ref(enumerator);
+       }
+
+       return (PyObject*)self;
+}
+
+static PyObject* Database_search_as(DatabaseObject* self, PyObject* args) {
+       const char* string = NULL;
+
+       if (!PyArg_ParseTuple(args, "s", &string))
+               return NULL;
+
+       struct loc_database_enumerator* enumerator;
+
+       int r = loc_database_enumerator_new(&enumerator, self->db);
+       if (r) {
+               PyErr_SetFromErrno(PyExc_SystemError);
+               return NULL;
+       }
+
+       // Search string we are searching for
+       loc_database_enumerator_set_string(enumerator, string);
+
+       PyObject* obj = new_database_enumerator(&DatabaseEnumeratorType, enumerator);
+       loc_database_enumerator_unref(enumerator);
+
+       return obj;
 }
 
 static struct PyMethodDef Database_methods[] = {
@@ -113,6 +190,18 @@ static struct PyMethodDef Database_methods[] = {
                METH_VARARGS,
                NULL,
        },
+       {
+               "lookup",
+               (PyCFunction)Database_lookup,
+               METH_VARARGS,
+               NULL,
+       },
+       {
+               "search_as",
+               (PyCFunction)Database_search_as,
+               METH_VARARGS,
+               NULL,
+       },
        { NULL },
 };
 
@@ -131,6 +220,13 @@ static struct PyGetSetDef Database_getsetters[] = {
                NULL,
                NULL,
        },
+       {
+               "license",
+               (getter)Database_get_license,
+               NULL,
+               NULL,
+               NULL,
+       },
        {
                "vendor",
                (getter)Database_get_vendor,
@@ -143,13 +239,52 @@ static struct PyGetSetDef Database_getsetters[] = {
 
 PyTypeObject DatabaseType = {
        PyVarObject_HEAD_INIT(NULL, 0)
-       tp_name:                "location.Database",
-       tp_basicsize:           sizeof(DatabaseObject),
-       tp_flags:               Py_TPFLAGS_DEFAULT|Py_TPFLAGS_BASETYPE,
-       tp_new:                 Database_new,
-       tp_dealloc:             (destructor)Database_dealloc,
-       tp_init:                (initproc)Database_init,
-       tp_doc:                 "Database object",
-       tp_methods:             Database_methods,
-       tp_getset:              Database_getsetters,
+       .tp_name =               "location.Database",
+       .tp_basicsize =          sizeof(DatabaseObject),
+       .tp_flags =              Py_TPFLAGS_DEFAULT|Py_TPFLAGS_BASETYPE,
+       .tp_new =                Database_new,
+       .tp_dealloc =            (destructor)Database_dealloc,
+       .tp_init =               (initproc)Database_init,
+       .tp_doc =                "Database object",
+       .tp_methods =            Database_methods,
+       .tp_getset =             Database_getsetters,
+       .tp_repr =               (reprfunc)Database_repr,
+};
+
+static PyObject* DatabaseEnumerator_new(PyTypeObject* type, PyObject* args, PyObject* kwds) {
+       DatabaseEnumeratorObject* self = (DatabaseEnumeratorObject*)type->tp_alloc(type, 0);
+
+       return (PyObject*)self;
+}
+
+static void DatabaseEnumerator_dealloc(DatabaseEnumeratorObject* self) {
+       loc_database_enumerator_unref(self->enumerator);
+
+       Py_TYPE(self)->tp_free((PyObject* )self);
+}
+
+static PyObject* DatabaseEnumerator_next(DatabaseEnumeratorObject* self) {
+       struct loc_as* as = loc_database_enumerator_next_as(self->enumerator);
+       if (as) {
+               PyObject* obj = new_as(&ASType, as);
+               loc_as_unref(as);
+
+               return obj;
+       }
+
+       // Nothing found, that means the end
+       PyErr_SetNone(PyExc_StopIteration);
+       return NULL;
+}
+
+PyTypeObject DatabaseEnumeratorType = {
+       PyVarObject_HEAD_INIT(NULL, 0)
+       .tp_name =               "location.DatabaseEnumerator",
+       .tp_basicsize =          sizeof(DatabaseEnumeratorObject),
+       .tp_flags =              Py_TPFLAGS_DEFAULT,
+       .tp_alloc =              PyType_GenericAlloc,
+       .tp_new =                DatabaseEnumerator_new,
+       .tp_dealloc =            (destructor)DatabaseEnumerator_dealloc,
+       .tp_iter =               PyObject_SelfIter,
+       .tp_iternext =           (iternextfunc)DatabaseEnumerator_next,
 };