]> git.ipfire.org Git - location/libloc.git/commitdiff
Implement filtering for multiple countries in the enumerator
authorMichael Tremer <michael.tremer@ipfire.org>
Mon, 16 Nov 2020 15:13:28 +0000 (15:13 +0000)
committerMichael Tremer <michael.tremer@ipfire.org>
Mon, 16 Nov 2020 15:13:28 +0000 (15:13 +0000)
This will allow us to speed up the export of the database
if only a few countries should be returned.

Signed-off-by: Michael Tremer <michael.tremer@ipfire.org>
Makefile.am
src/country-list.c [new file with mode: 0644]
src/country.c
src/database.c
src/libloc.sym
src/loc/country-list.h [new file with mode: 0644]
src/loc/database.h
src/python/database.c

index f0d8c4ce5bb2f0e0ba1bea47d704b91b828ada3b..f4ca3c802881298fcbe86f845f6c607b473b7e9a 100644 (file)
@@ -93,6 +93,7 @@ pkginclude_HEADERS = \
        src/loc/as.h \
        src/loc/compat.h \
        src/loc/country.h \
+       src/loc/country-list.h \
        src/loc/database.h \
        src/loc/format.h \
        src/loc/network.h \
@@ -109,6 +110,7 @@ src_libloc_la_SOURCES = \
        src/libloc.c \
        src/as.c \
        src/country.c \
+       src/country-list.c \
        src/database.c \
        src/network.c \
        src/network-list.c \
diff --git a/src/country-list.c b/src/country-list.c
new file mode 100644 (file)
index 0000000..ae0d71a
--- /dev/null
@@ -0,0 +1,138 @@
+/*
+       libloc - A library to determine the location of someone on the Internet
+
+       Copyright (C) 2020 IPFire Development Team <info@ipfire.org>
+
+       This library is free software; you can redistribute it and/or
+       modify it under the terms of the GNU Lesser General Public
+       License as published by the Free Software Foundation; either
+       version 2.1 of the License, or (at your option) any later version.
+
+       This library is distributed in the hope that it will be useful,
+       but WITHOUT ANY WARRANTY; without even the implied warranty of
+       MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+       Lesser General Public License for more details.
+*/
+
+#include <errno.h>
+#include <stdlib.h>
+
+#include <loc/country.h>
+#include <loc/country-list.h>
+#include <loc/private.h>
+
+struct loc_country_list {
+       struct loc_ctx* ctx;
+       int refcount;
+
+       struct loc_country* list[1024];
+       size_t size;
+       size_t max_size;
+};
+
+LOC_EXPORT int loc_country_list_new(struct loc_ctx* ctx,
+               struct loc_country_list** list) {
+       struct loc_country_list* l = calloc(1, sizeof(*l));
+       if (!l)
+               return -ENOMEM;
+
+       l->ctx = loc_ref(ctx);
+       l->refcount = 1;
+
+       // Do not allow this list to grow larger than this
+       l->max_size = 1024;
+
+       DEBUG(l->ctx, "Country list allocated at %p\n", l);
+       *list = l;
+
+       return 0;
+}
+
+LOC_EXPORT struct loc_country_list* loc_country_list_ref(struct loc_country_list* list) {
+       list->refcount++;
+
+       return list;
+}
+
+static void loc_country_list_free(struct loc_country_list* list) {
+       DEBUG(list->ctx, "Releasing country list at %p\n", list);
+
+       loc_country_list_clear(list);
+
+       loc_unref(list->ctx);
+       free(list);
+}
+
+LOC_EXPORT struct loc_country_list* loc_country_list_unref(struct loc_country_list* list) {
+       if (!list)
+               return NULL;
+
+       if (--list->refcount > 0)
+               return list;
+
+       loc_country_list_free(list);
+       return NULL;
+}
+
+LOC_EXPORT size_t loc_country_list_size(struct loc_country_list* list) {
+       return list->size;
+}
+
+LOC_EXPORT int loc_country_list_empty(struct loc_country_list* list) {
+       return list->size == 0;
+}
+
+LOC_EXPORT void loc_country_list_clear(struct loc_country_list* list) {
+       for (unsigned int i = 0; i < list->size; i++)
+               loc_country_unref(list->list[i]);
+}
+
+LOC_EXPORT struct loc_country* loc_country_list_get(struct loc_country_list* list, size_t index) {
+       // Check index
+       if (index >= list->size)
+               return NULL;
+
+       return loc_country_ref(list->list[index]);
+}
+
+LOC_EXPORT int loc_country_list_append(
+               struct loc_country_list* list, struct loc_country* country) {
+       if (loc_country_list_contains(list, country))
+               return 0;
+
+       // Check if we have space left
+       if (list->size == list->max_size) {
+               ERROR(list->ctx, "%p: Could not append country to the list. List full\n", list);
+               return -ENOMEM;
+       }
+
+       DEBUG(list->ctx, "%p: Appending country %p to list\n", list, country);
+
+       list->list[list->size++] = loc_country_ref(country);
+
+       return 0;
+}
+
+LOC_EXPORT int loc_country_list_contains(
+               struct loc_country_list* list, struct loc_country* country) {
+       for (unsigned int i = 0; i < list->size; i++) {
+               if (loc_country_cmp(country, list->list[i]) == 0)
+                       return 1;
+       }
+
+       return 0;
+}
+
+LOC_EXPORT int loc_country_list_contains_code(
+               struct loc_country_list* list, const char* code) {
+       struct loc_country* country;
+
+       int r = loc_country_new(list->ctx, &country, code);
+       if (r)
+               return -1;
+
+       r = loc_country_list_contains(list, country);
+       loc_country_unref(country);
+
+       return r;
+}
index 2ba93e6ef7d5c225d032b8b307ad76942659b0e0..7aac0dba3ae50acffffb52284cdcfeb7fc73fe92 100644 (file)
@@ -34,6 +34,9 @@ struct loc_country {
 };
 
 LOC_EXPORT int loc_country_new(struct loc_ctx* ctx, struct loc_country** country, const char* country_code) {
+       if (!loc_country_code_is_valid(country_code))
+               return -EINVAL;
+
        struct loc_country* c = calloc(1, sizeof(*c));
        if (!c)
                return -ENOMEM;
index b9d870f44a7dd8ffc84b310f46624d1a040804e0..29823b2b512c8249a82f0c816bdde631beb2d829 100644 (file)
@@ -40,6 +40,7 @@
 #include <loc/as.h>
 #include <loc/compat.h>
 #include <loc/country.h>
+#include <loc/country-list.h>
 #include <loc/database.h>
 #include <loc/format.h>
 #include <loc/network.h>
@@ -99,7 +100,7 @@ struct loc_database_enumerator {
 
        // Search string
        char* string;
-       char country_code[3];
+       struct loc_country_list* countries;
        uint32_t asn;
        enum loc_network_flags flags;
        int family;
@@ -1017,33 +1018,20 @@ LOC_EXPORT int loc_database_enumerator_set_string(struct loc_database_enumerator
        return 0;
 }
 
-LOC_EXPORT int loc_database_enumerator_set_country_code(struct loc_database_enumerator* enumerator, const char* country_code) {
-       // Set empty country code
-       if (!country_code || !*country_code) {
-               *enumerator->country_code = '\0';
-               return 0;
-       }
+LOC_EXPORT struct loc_country_list* loc_database_enumerator_get_countries(
+               struct loc_database_enumerator* enumerator) {
+       if (!enumerator->countries)
+               return NULL;
 
-       // Treat A1, A2, A3 as special country codes,
-       // but perform search for flags instead
-       if (strcmp(country_code, "A1") == 0) {
-               return loc_database_enumerator_set_flag(enumerator,
-                       LOC_NETWORK_FLAG_ANONYMOUS_PROXY);
-       } else if (strcmp(country_code, "A2") == 0) {
-               return loc_database_enumerator_set_flag(enumerator,
-                       LOC_NETWORK_FLAG_SATELLITE_PROVIDER);
-       } else if (strcmp(country_code, "A3") == 0) {
-               return loc_database_enumerator_set_flag(enumerator,
-                       LOC_NETWORK_FLAG_ANYCAST);
-       }
+       return loc_country_list_ref(enumerator->countries);
+}
 
-       // Country codes must be two characters
-       if (!loc_country_code_is_valid(country_code))
-               return -EINVAL;
+LOC_EXPORT int loc_database_enumerator_set_countries(
+               struct loc_database_enumerator* enumerator, struct loc_country_list* countries) {
+       if (enumerator->countries)
+               loc_country_list_unref(enumerator->countries);
 
-       for (unsigned int i = 0; i < 3; i++) {
-               enumerator->country_code[i] = country_code[i];
-       }
+       enumerator->countries = loc_country_list_ref(countries);
 
        return 0;
 }
@@ -1129,6 +1117,12 @@ static int loc_database_enumerator_stack_push_node(
        return 0;
 }
 
+static int loc_network_match_countries(struct loc_network* network, struct loc_country_list* countries) {
+       const char* country_code = loc_network_get_country_code(network);
+
+       return loc_country_list_contains_code(countries, country_code);
+}
+
 static int loc_database_enumerator_filter_network(
                struct loc_database_enumerator* enumerator, struct loc_network* network) {
        // Skip if the family does not match
@@ -1136,8 +1130,7 @@ static int loc_database_enumerator_filter_network(
                return 1;
 
        // Skip if the country code does not match
-       if (*enumerator->country_code &&
-                       !loc_network_match_country_code(network, enumerator->country_code))
+       if (enumerator->countries && !loc_network_match_countries(network, enumerator->countries))
                return 1;
 
        // Skip if the ASN does not match
index 453a1beaf3c25eeb0735195a19a935b80fef9229..40e9f88105a90f92022d4e82545792bb42ed9cf8 100644 (file)
@@ -49,6 +49,18 @@ global:
        loc_country_set_name;
        loc_country_unref;
 
+       # Country List
+       loc_country_list_append;
+       loc_country_list_clear;
+       loc_country_list_contains;
+       loc_country_list_contains_code;
+       loc_country_list_empty;
+       loc_country_list_get;
+       loc_country_list_new;
+       loc_country_list_ref;
+       loc_country_list_size;
+       loc_country_list_unref;
+
        # Database
        loc_database_add_as;
        loc_database_count_as;
@@ -66,13 +78,14 @@ global:
        loc_database_verify;
 
        # Database Enumerator
+       loc_database_enumerator_get_countries;
        loc_database_enumerator_new;
        loc_database_enumerator_next_as;
        loc_database_enumerator_next_country;
        loc_database_enumerator_next_network;
        loc_database_enumerator_ref;
        loc_database_enumerator_set_asn;
-       loc_database_enumerator_set_country_code;
+       loc_database_enumerator_set_countries;
        loc_database_enumerator_set_family;
        loc_database_enumerator_set_flag;
        loc_database_enumerator_set_string;
diff --git a/src/loc/country-list.h b/src/loc/country-list.h
new file mode 100644 (file)
index 0000000..a7f818a
--- /dev/null
@@ -0,0 +1,43 @@
+/*
+       libloc - A library to determine the location of someone on the Internet
+
+       Copyright (C) 2017 IPFire Development Team <info@ipfire.org>
+
+       This library is free software; you can redistribute it and/or
+       modify it under the terms of the GNU Lesser General Public
+       License as published by the Free Software Foundation; either
+       version 2.1 of the License, or (at your option) any later version.
+
+       This library is distributed in the hope that it will be useful,
+       but WITHOUT ANY WARRANTY; without even the implied warranty of
+       MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+       Lesser General Public License for more details.
+*/
+
+#ifndef LIBLOC_COUNTRY_LIST_H
+#define LIBLOC_COUNTRY_LIST_H
+
+#include <stdlib.h>
+
+#include <loc/libloc.h>
+#include <loc/country.h>
+
+struct loc_country_list;
+
+int loc_country_list_new(struct loc_ctx* ctx, struct loc_country_list** list);
+struct loc_country_list* loc_country_list_ref(struct loc_country_list* list);
+struct loc_country_list* loc_country_list_unref(struct loc_country_list* list);
+
+size_t loc_country_list_size(struct loc_country_list* list);
+int loc_country_list_empty(struct loc_country_list* list);
+void loc_country_list_clear(struct loc_country_list* list);
+
+struct loc_country* loc_country_list_get(struct loc_country_list* list, size_t index);
+int loc_country_list_append(struct loc_country_list* list, struct loc_country* country);
+
+int loc_country_list_contains(
+       struct loc_country_list* list, struct loc_country* country);
+int loc_country_list_contains_code(
+       struct loc_country_list* list, const char* code);
+
+#endif
index 14eb5ea273d80453e959e36c4f5cb6fc1e0a4742..246e5c550cfb7a418e53d44eb47bb4ed406f665c 100644 (file)
@@ -25,6 +25,7 @@
 #include <loc/network.h>
 #include <loc/as.h>
 #include <loc/country.h>
+#include <loc/country-list.h>
 
 struct loc_database;
 int loc_database_new(struct loc_ctx* ctx, struct loc_database** database, FILE* f);
@@ -66,7 +67,9 @@ struct loc_database_enumerator* loc_database_enumerator_ref(struct loc_database_
 struct loc_database_enumerator* loc_database_enumerator_unref(struct loc_database_enumerator* enumerator);
 
 int loc_database_enumerator_set_string(struct loc_database_enumerator* enumerator, const char* string);
-int loc_database_enumerator_set_country_code(struct loc_database_enumerator* enumerator, const char* country_code);
+struct loc_country_list* loc_database_enumerator_get_countries(struct loc_database_enumerator* enumerator);
+int loc_database_enumerator_set_countries(
+       struct loc_database_enumerator* enumerator, struct loc_country_list* countries);
 int loc_database_enumerator_set_asn(struct loc_database_enumerator* enumerator, unsigned int asn);
 int loc_database_enumerator_set_flag(struct loc_database_enumerator* enumerator, enum loc_network_flags flag);
 int loc_database_enumerator_set_family(struct loc_database_enumerator* enumerator, int family);
index d169547dfa9782a00526c5e55738cfe8d4ef0dc5..e6f6f37e2bc437a55c5b3c84ef5672a0faa7353f 100644 (file)
@@ -258,14 +258,15 @@ static PyObject* Database_networks_flattened(DatabaseObject *self) {
 }
 
 static PyObject* Database_search_networks(DatabaseObject* self, PyObject* args, PyObject* kwargs) {
-       char* kwlist[] = { "country_code", "asn", "flags", "family", "flatten", NULL };
-       const char* country_code = NULL;
+       char* kwlist[] = { "country_codes", "asn", "flags", "family", "flatten", NULL };
+       PyObject* country_codes = NULL;
        unsigned int asn = 0;
        int flags = 0;
        int family = 0;
        int flatten = 0;
 
-       if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|siiip", kwlist, &country_code, &asn, &flags, &family, &flatten))
+       if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O!iiip", kwlist,
+                       &PyList_Type, &country_codes, &asn, &flags, &family, &flatten))
                return NULL;
 
        struct loc_database_enumerator* enumerator;
@@ -277,13 +278,55 @@ static PyObject* Database_search_networks(DatabaseObject* self, PyObject* args,
        }
 
        // Set country code we are searching for
-       if (country_code) {
-               r = loc_database_enumerator_set_country_code(enumerator, country_code);
-
+       if (country_codes) {
+               struct loc_country_list* countries;
+               r = loc_country_list_new(loc_ctx, &countries);
                if (r) {
-                       PyErr_SetFromErrno(PyExc_SystemError);
+                       PyErr_SetString(PyExc_SystemError, "Could not create country list");
                        return NULL;
                }
+
+               for (unsigned int i = 0; i < PyList_Size(country_codes); i++) {
+                       PyObject* item = PyList_GetItem(country_codes, i);
+
+                       if (!PyUnicode_Check(item)) {
+                               PyErr_SetString(PyExc_TypeError, "Country codes must be strings");
+                               loc_country_list_unref(countries);
+                               return NULL;
+                       }
+
+                       const char* country_code = PyUnicode_AsUTF8(item);
+
+                       struct loc_country* country;
+                       r = loc_country_new(loc_ctx, &country, country_code);
+                       if (r) {
+                               if (r == -EINVAL) {
+                                       PyErr_Format(PyExc_ValueError, "Invalid country code: %s", country_code);
+                               } else {
+                                       PyErr_SetString(PyExc_SystemError, "Could not create country");
+                               }
+
+                               loc_country_list_unref(countries);
+                               return NULL;
+                       }
+
+                       // Append it to the list
+                       r = loc_country_list_append(countries, country);
+                       if (r) {
+                               PyErr_SetString(PyExc_SystemError, "Could not append country to the list");
+
+                               loc_country_list_unref(countries);
+                               loc_country_unref(country);
+                               return NULL;
+                       }
+
+                       loc_country_unref(country);
+               }
+
+               loc_database_enumerator_set_countries(enumerator, countries);
+
+               Py_DECREF(country_codes);
+               loc_country_list_unref(countries);
        }
 
        // Set the ASN we are searching for