From ccc7ab4e50a5c454f8a3e13ba1ee0726ae197f4d Mon Sep 17 00:00:00 2001 From: Michael Tremer Date: Thu, 3 Oct 2019 18:02:07 +0000 Subject: [PATCH] Implement listing networks in Python Signed-off-by: Michael Tremer --- src/database.c | 13 +++++++- src/loc/database.h | 8 ++++- src/python/database.c | 59 +++++++++++++++++++++++++++++++++++- src/python/location-query.in | 13 ++++++++ src/test-as.c | 2 +- 5 files changed, 91 insertions(+), 4 deletions(-) diff --git a/src/database.c b/src/database.c index 75fc50a..7f5a58a 100644 --- a/src/database.c +++ b/src/database.c @@ -73,6 +73,7 @@ struct loc_node_stack { struct loc_database_enumerator { struct loc_ctx* ctx; struct loc_database* db; + enum loc_database_enumerator_mode mode; int refcount; // Search string @@ -570,7 +571,8 @@ LOC_EXPORT int loc_database_lookup_from_string(struct loc_database* db, // Enumerator -LOC_EXPORT int loc_database_enumerator_new(struct loc_database_enumerator** enumerator, struct loc_database* db) { +LOC_EXPORT int loc_database_enumerator_new(struct loc_database_enumerator** enumerator, + struct loc_database* db, enum loc_database_enumerator_mode mode) { struct loc_database_enumerator* e = calloc(1, sizeof(*e)); if (!e) return -ENOMEM; @@ -578,6 +580,7 @@ LOC_EXPORT int loc_database_enumerator_new(struct loc_database_enumerator** enum // Reference context e->ctx = loc_ref(db->ctx); e->db = loc_database_ref(db); + e->mode = mode; e->refcount = 1; // Initialise graph search @@ -660,6 +663,10 @@ LOC_EXPORT int loc_database_enumerator_set_asn( } LOC_EXPORT struct loc_as* loc_database_enumerator_next_as(struct loc_database_enumerator* enumerator) { + // Do not do anything if not in AS mode + if (enumerator->mode != LOC_DB_ENUMERATE_ASES) + return NULL; + struct loc_database* db = enumerator->db; struct loc_as* as; @@ -800,6 +807,10 @@ static int loc_database_enumerator_network_depth_first_search( LOC_EXPORT struct loc_network* loc_database_enumerator_next_network( struct loc_database_enumerator* enumerator) { + // Do not do anything if not in network mode + if (enumerator->mode != LOC_DB_ENUMERATE_NETWORKS) + return NULL; + struct loc_network* network = NULL; int r = loc_database_enumerator_network_depth_first_search(enumerator, &network); diff --git a/src/loc/database.h b/src/loc/database.h index 7d0b20c..97650bd 100644 --- a/src/loc/database.h +++ b/src/loc/database.h @@ -43,8 +43,14 @@ int loc_database_lookup(struct loc_database* db, int loc_database_lookup_from_string(struct loc_database* db, const char* string, struct loc_network** network); +enum loc_database_enumerator_mode { + LOC_DB_ENUMERATE_NETWORKS = 1, + LOC_DB_ENUMERATE_ASES = 2, +}; + struct loc_database_enumerator; -int loc_database_enumerator_new(struct loc_database_enumerator** enumerator, struct loc_database* db); +int loc_database_enumerator_new(struct loc_database_enumerator** enumerator, + struct loc_database* db, enum loc_database_enumerator_mode mode); struct loc_database_enumerator* loc_database_enumerator_ref(struct loc_database_enumerator* enumerator); struct loc_database_enumerator* loc_database_enumerator_unref(struct loc_database_enumerator* enumerator); diff --git a/src/python/database.c b/src/python/database.c index 01be089..954b5ac 100644 --- a/src/python/database.c +++ b/src/python/database.c @@ -168,7 +168,7 @@ static PyObject* Database_search_as(DatabaseObject* self, PyObject* args) { struct loc_database_enumerator* enumerator; - int r = loc_database_enumerator_new(&enumerator, self->db); + int r = loc_database_enumerator_new(&enumerator, self->db, LOC_DB_ENUMERATE_ASES); if (r) { PyErr_SetFromErrno(PyExc_SystemError); return NULL; @@ -183,6 +183,47 @@ static PyObject* Database_search_as(DatabaseObject* self, PyObject* args) { return obj; } +static PyObject* Database_search_networks(DatabaseObject* self, PyObject* args, PyObject* kwargs) { + char* kwlist[] = { "country_code", "asn", NULL }; + const char* country_code = NULL; + unsigned int asn = 0; + + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|si", kwlist, &country_code, &asn)) + return NULL; + + struct loc_database_enumerator* enumerator; + int r = loc_database_enumerator_new(&enumerator, self->db, LOC_DB_ENUMERATE_NETWORKS); + if (r) { + PyErr_SetFromErrno(PyExc_SystemError); + return NULL; + } + + // Set country code we are searching for + if (country_code) { + r = loc_database_enumerator_set_country_code(enumerator, country_code); + + if (r) { + PyErr_SetFromErrno(PyExc_SystemError); + return NULL; + } + } + + // Set the ASN we are searching for + if (asn) { + r = loc_database_enumerator_set_asn(enumerator, asn); + + if (r) { + PyErr_SetFromErrno(PyExc_SystemError); + return NULL; + } + } + + PyObject* obj = new_database_enumerator(&DatabaseEnumeratorType, enumerator); + loc_database_enumerator_unref(enumerator); + + return obj; +} + static struct PyMethodDef Database_methods[] = { { "get_as", @@ -202,6 +243,12 @@ static struct PyMethodDef Database_methods[] = { METH_VARARGS, NULL, }, + { + "search_networks", + (PyCFunction)Database_search_networks, + METH_VARARGS|METH_KEYWORDS, + NULL, + }, { NULL }, }; @@ -264,6 +311,16 @@ static void DatabaseEnumerator_dealloc(DatabaseEnumeratorObject* self) { } static PyObject* DatabaseEnumerator_next(DatabaseEnumeratorObject* self) { + // Enumerate all networks + struct loc_network* network = loc_database_enumerator_next_network(self->enumerator); + if (network) { + PyObject* obj = new_network(&NetworkType, network); + loc_network_unref(network); + + return obj; + } + + // Enumerate all ASes struct loc_as* as = loc_database_enumerator_next_as(self->enumerator); if (as) { PyObject* obj = new_as(&ASType, as); diff --git a/src/python/location-query.in b/src/python/location-query.in index 7a7492c..733f068 100644 --- a/src/python/location-query.in +++ b/src/python/location-query.in @@ -73,6 +73,13 @@ class CLI(object): search_as.add_argument("query", nargs=1) search_as.set_defaults(func=self.handle_search_as) + # List all networks in a country + search_as = subparsers.add_parser("list-networks-by-cc", + help=_("Lists all networks in a country"), + ) + search_as.add_argument("country_code", nargs=1) + search_as.set_defaults(func=self.handle_list_networks_by_cc) + return parser.parse_args() def run(self): @@ -170,6 +177,12 @@ class CLI(object): for a in db.search_as(query): print(a) + def handle_list_networks_by_cc(self, db, ns): + for country_code in ns.country_code: + # Print all matching networks + for n in db.search_networks(country_code=country_code): + print(n) + def main(): # Run the command line interface c = CLI() diff --git a/src/test-as.c b/src/test-as.c index 8010a90..d545d9c 100644 --- a/src/test-as.c +++ b/src/test-as.c @@ -98,7 +98,7 @@ int main(int argc, char** argv) { // Enumerator struct loc_database_enumerator* enumerator; - err = loc_database_enumerator_new(&enumerator, db); + err = loc_database_enumerator_new(&enumerator, db, LOC_DB_ENUMERATE_ASES); if (err) { fprintf(stderr, "Could not create a database enumerator\n"); exit(EXIT_FAILURE); -- 2.39.2