From e646a8f35ec7eff009414b3fd107c9af5cf39a86 Mon Sep 17 00:00:00 2001 From: Michael Tremer Date: Mon, 16 Nov 2020 15:13:28 +0000 Subject: [PATCH] Implement filtering for multiple countries in the enumerator This will allow us to speed up the export of the database if only a few countries should be returned. Signed-off-by: Michael Tremer --- Makefile.am | 2 + src/country-list.c | 138 +++++++++++++++++++++++++++++++++++++++++ src/country.c | 3 + src/database.c | 47 ++++++-------- src/libloc.sym | 15 ++++- src/loc/country-list.h | 43 +++++++++++++ src/loc/database.h | 5 +- src/python/database.c | 57 ++++++++++++++--- 8 files changed, 274 insertions(+), 36 deletions(-) create mode 100644 src/country-list.c create mode 100644 src/loc/country-list.h diff --git a/Makefile.am b/Makefile.am index f0d8c4c..f4ca3c8 100644 --- a/Makefile.am +++ b/Makefile.am @@ -93,6 +93,7 @@ pkginclude_HEADERS = \ src/loc/as.h \ src/loc/compat.h \ src/loc/country.h \ + src/loc/country-list.h \ src/loc/database.h \ src/loc/format.h \ src/loc/network.h \ @@ -109,6 +110,7 @@ src_libloc_la_SOURCES = \ src/libloc.c \ src/as.c \ src/country.c \ + src/country-list.c \ src/database.c \ src/network.c \ src/network-list.c \ diff --git a/src/country-list.c b/src/country-list.c new file mode 100644 index 0000000..ae0d71a --- /dev/null +++ b/src/country-list.c @@ -0,0 +1,138 @@ +/* + libloc - A library to determine the location of someone on the Internet + + Copyright (C) 2020 IPFire Development Team + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. +*/ + +#include +#include + +#include +#include +#include + +struct loc_country_list { + struct loc_ctx* ctx; + int refcount; + + struct loc_country* list[1024]; + size_t size; + size_t max_size; +}; + +LOC_EXPORT int loc_country_list_new(struct loc_ctx* ctx, + struct loc_country_list** list) { + struct loc_country_list* l = calloc(1, sizeof(*l)); + if (!l) + return -ENOMEM; + + l->ctx = loc_ref(ctx); + l->refcount = 1; + + // Do not allow this list to grow larger than this + l->max_size = 1024; + + DEBUG(l->ctx, "Country list allocated at %p\n", l); + *list = l; + + return 0; +} + +LOC_EXPORT struct loc_country_list* loc_country_list_ref(struct loc_country_list* list) { + list->refcount++; + + return list; +} + +static void loc_country_list_free(struct loc_country_list* list) { + DEBUG(list->ctx, "Releasing country list at %p\n", list); + + loc_country_list_clear(list); + + loc_unref(list->ctx); + free(list); +} + +LOC_EXPORT struct loc_country_list* loc_country_list_unref(struct loc_country_list* list) { + if (!list) + return NULL; + + if (--list->refcount > 0) + return list; + + loc_country_list_free(list); + return NULL; +} + +LOC_EXPORT size_t loc_country_list_size(struct loc_country_list* list) { + return list->size; +} + +LOC_EXPORT int loc_country_list_empty(struct loc_country_list* list) { + return list->size == 0; +} + +LOC_EXPORT void loc_country_list_clear(struct loc_country_list* list) { + for (unsigned int i = 0; i < list->size; i++) + loc_country_unref(list->list[i]); +} + +LOC_EXPORT struct loc_country* loc_country_list_get(struct loc_country_list* list, size_t index) { + // Check index + if (index >= list->size) + return NULL; + + return loc_country_ref(list->list[index]); +} + +LOC_EXPORT int loc_country_list_append( + struct loc_country_list* list, struct loc_country* country) { + if (loc_country_list_contains(list, country)) + return 0; + + // Check if we have space left + if (list->size == list->max_size) { + ERROR(list->ctx, "%p: Could not append country to the list. List full\n", list); + return -ENOMEM; + } + + DEBUG(list->ctx, "%p: Appending country %p to list\n", list, country); + + list->list[list->size++] = loc_country_ref(country); + + return 0; +} + +LOC_EXPORT int loc_country_list_contains( + struct loc_country_list* list, struct loc_country* country) { + for (unsigned int i = 0; i < list->size; i++) { + if (loc_country_cmp(country, list->list[i]) == 0) + return 1; + } + + return 0; +} + +LOC_EXPORT int loc_country_list_contains_code( + struct loc_country_list* list, const char* code) { + struct loc_country* country; + + int r = loc_country_new(list->ctx, &country, code); + if (r) + return -1; + + r = loc_country_list_contains(list, country); + loc_country_unref(country); + + return r; +} diff --git a/src/country.c b/src/country.c index 2ba93e6..7aac0db 100644 --- a/src/country.c +++ b/src/country.c @@ -34,6 +34,9 @@ struct loc_country { }; LOC_EXPORT int loc_country_new(struct loc_ctx* ctx, struct loc_country** country, const char* country_code) { + if (!loc_country_code_is_valid(country_code)) + return -EINVAL; + struct loc_country* c = calloc(1, sizeof(*c)); if (!c) return -ENOMEM; diff --git a/src/database.c b/src/database.c index b9d870f..29823b2 100644 --- a/src/database.c +++ b/src/database.c @@ -40,6 +40,7 @@ #include #include #include +#include #include #include #include @@ -99,7 +100,7 @@ struct loc_database_enumerator { // Search string char* string; - char country_code[3]; + struct loc_country_list* countries; uint32_t asn; enum loc_network_flags flags; int family; @@ -1017,33 +1018,20 @@ LOC_EXPORT int loc_database_enumerator_set_string(struct loc_database_enumerator return 0; } -LOC_EXPORT int loc_database_enumerator_set_country_code(struct loc_database_enumerator* enumerator, const char* country_code) { - // Set empty country code - if (!country_code || !*country_code) { - *enumerator->country_code = '\0'; - return 0; - } +LOC_EXPORT struct loc_country_list* loc_database_enumerator_get_countries( + struct loc_database_enumerator* enumerator) { + if (!enumerator->countries) + return NULL; - // Treat A1, A2, A3 as special country codes, - // but perform search for flags instead - if (strcmp(country_code, "A1") == 0) { - return loc_database_enumerator_set_flag(enumerator, - LOC_NETWORK_FLAG_ANONYMOUS_PROXY); - } else if (strcmp(country_code, "A2") == 0) { - return loc_database_enumerator_set_flag(enumerator, - LOC_NETWORK_FLAG_SATELLITE_PROVIDER); - } else if (strcmp(country_code, "A3") == 0) { - return loc_database_enumerator_set_flag(enumerator, - LOC_NETWORK_FLAG_ANYCAST); - } + return loc_country_list_ref(enumerator->countries); +} - // Country codes must be two characters - if (!loc_country_code_is_valid(country_code)) - return -EINVAL; +LOC_EXPORT int loc_database_enumerator_set_countries( + struct loc_database_enumerator* enumerator, struct loc_country_list* countries) { + if (enumerator->countries) + loc_country_list_unref(enumerator->countries); - for (unsigned int i = 0; i < 3; i++) { - enumerator->country_code[i] = country_code[i]; - } + enumerator->countries = loc_country_list_ref(countries); return 0; } @@ -1129,6 +1117,12 @@ static int loc_database_enumerator_stack_push_node( return 0; } +static int loc_network_match_countries(struct loc_network* network, struct loc_country_list* countries) { + const char* country_code = loc_network_get_country_code(network); + + return loc_country_list_contains_code(countries, country_code); +} + static int loc_database_enumerator_filter_network( struct loc_database_enumerator* enumerator, struct loc_network* network) { // Skip if the family does not match @@ -1136,8 +1130,7 @@ static int loc_database_enumerator_filter_network( return 1; // Skip if the country code does not match - if (*enumerator->country_code && - !loc_network_match_country_code(network, enumerator->country_code)) + if (enumerator->countries && !loc_network_match_countries(network, enumerator->countries)) return 1; // Skip if the ASN does not match diff --git a/src/libloc.sym b/src/libloc.sym index 453a1be..40e9f88 100644 --- a/src/libloc.sym +++ b/src/libloc.sym @@ -49,6 +49,18 @@ global: loc_country_set_name; loc_country_unref; + # Country List + loc_country_list_append; + loc_country_list_clear; + loc_country_list_contains; + loc_country_list_contains_code; + loc_country_list_empty; + loc_country_list_get; + loc_country_list_new; + loc_country_list_ref; + loc_country_list_size; + loc_country_list_unref; + # Database loc_database_add_as; loc_database_count_as; @@ -66,13 +78,14 @@ global: loc_database_verify; # Database Enumerator + loc_database_enumerator_get_countries; loc_database_enumerator_new; loc_database_enumerator_next_as; loc_database_enumerator_next_country; loc_database_enumerator_next_network; loc_database_enumerator_ref; loc_database_enumerator_set_asn; - loc_database_enumerator_set_country_code; + loc_database_enumerator_set_countries; loc_database_enumerator_set_family; loc_database_enumerator_set_flag; loc_database_enumerator_set_string; diff --git a/src/loc/country-list.h b/src/loc/country-list.h new file mode 100644 index 0000000..a7f818a --- /dev/null +++ b/src/loc/country-list.h @@ -0,0 +1,43 @@ +/* + libloc - A library to determine the location of someone on the Internet + + Copyright (C) 2017 IPFire Development Team + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. +*/ + +#ifndef LIBLOC_COUNTRY_LIST_H +#define LIBLOC_COUNTRY_LIST_H + +#include + +#include +#include + +struct loc_country_list; + +int loc_country_list_new(struct loc_ctx* ctx, struct loc_country_list** list); +struct loc_country_list* loc_country_list_ref(struct loc_country_list* list); +struct loc_country_list* loc_country_list_unref(struct loc_country_list* list); + +size_t loc_country_list_size(struct loc_country_list* list); +int loc_country_list_empty(struct loc_country_list* list); +void loc_country_list_clear(struct loc_country_list* list); + +struct loc_country* loc_country_list_get(struct loc_country_list* list, size_t index); +int loc_country_list_append(struct loc_country_list* list, struct loc_country* country); + +int loc_country_list_contains( + struct loc_country_list* list, struct loc_country* country); +int loc_country_list_contains_code( + struct loc_country_list* list, const char* code); + +#endif diff --git a/src/loc/database.h b/src/loc/database.h index 14eb5ea..246e5c5 100644 --- a/src/loc/database.h +++ b/src/loc/database.h @@ -25,6 +25,7 @@ #include #include #include +#include struct loc_database; int loc_database_new(struct loc_ctx* ctx, struct loc_database** database, FILE* f); @@ -66,7 +67,9 @@ struct loc_database_enumerator* loc_database_enumerator_ref(struct loc_database_ struct loc_database_enumerator* loc_database_enumerator_unref(struct loc_database_enumerator* enumerator); int loc_database_enumerator_set_string(struct loc_database_enumerator* enumerator, const char* string); -int loc_database_enumerator_set_country_code(struct loc_database_enumerator* enumerator, const char* country_code); +struct loc_country_list* loc_database_enumerator_get_countries(struct loc_database_enumerator* enumerator); +int loc_database_enumerator_set_countries( + struct loc_database_enumerator* enumerator, struct loc_country_list* countries); int loc_database_enumerator_set_asn(struct loc_database_enumerator* enumerator, unsigned int asn); int loc_database_enumerator_set_flag(struct loc_database_enumerator* enumerator, enum loc_network_flags flag); int loc_database_enumerator_set_family(struct loc_database_enumerator* enumerator, int family); diff --git a/src/python/database.c b/src/python/database.c index d169547..e6f6f37 100644 --- a/src/python/database.c +++ b/src/python/database.c @@ -258,14 +258,15 @@ static PyObject* Database_networks_flattened(DatabaseObject *self) { } static PyObject* Database_search_networks(DatabaseObject* self, PyObject* args, PyObject* kwargs) { - char* kwlist[] = { "country_code", "asn", "flags", "family", "flatten", NULL }; - const char* country_code = NULL; + char* kwlist[] = { "country_codes", "asn", "flags", "family", "flatten", NULL }; + PyObject* country_codes = NULL; unsigned int asn = 0; int flags = 0; int family = 0; int flatten = 0; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|siiip", kwlist, &country_code, &asn, &flags, &family, &flatten)) + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O!iiip", kwlist, + &PyList_Type, &country_codes, &asn, &flags, &family, &flatten)) return NULL; struct loc_database_enumerator* enumerator; @@ -277,13 +278,55 @@ static PyObject* Database_search_networks(DatabaseObject* self, PyObject* args, } // Set country code we are searching for - if (country_code) { - r = loc_database_enumerator_set_country_code(enumerator, country_code); - + if (country_codes) { + struct loc_country_list* countries; + r = loc_country_list_new(loc_ctx, &countries); if (r) { - PyErr_SetFromErrno(PyExc_SystemError); + PyErr_SetString(PyExc_SystemError, "Could not create country list"); return NULL; } + + for (unsigned int i = 0; i < PyList_Size(country_codes); i++) { + PyObject* item = PyList_GetItem(country_codes, i); + + if (!PyUnicode_Check(item)) { + PyErr_SetString(PyExc_TypeError, "Country codes must be strings"); + loc_country_list_unref(countries); + return NULL; + } + + const char* country_code = PyUnicode_AsUTF8(item); + + struct loc_country* country; + r = loc_country_new(loc_ctx, &country, country_code); + if (r) { + if (r == -EINVAL) { + PyErr_Format(PyExc_ValueError, "Invalid country code: %s", country_code); + } else { + PyErr_SetString(PyExc_SystemError, "Could not create country"); + } + + loc_country_list_unref(countries); + return NULL; + } + + // Append it to the list + r = loc_country_list_append(countries, country); + if (r) { + PyErr_SetString(PyExc_SystemError, "Could not append country to the list"); + + loc_country_list_unref(countries); + loc_country_unref(country); + return NULL; + } + + loc_country_unref(country); + } + + loc_database_enumerator_set_countries(enumerator, countries); + + Py_DECREF(country_codes); + loc_country_list_unref(countries); } // Set the ASN we are searching for -- 2.39.2