]> git.ipfire.org Git - people/ms/libloc.git/blobdiff - src/database.c
Implement filtering for multiple countries in the enumerator
[people/ms/libloc.git] / src / database.c
index b9d870f44a7dd8ffc84b310f46624d1a040804e0..29823b2b512c8249a82f0c816bdde631beb2d829 100644 (file)
@@ -40,6 +40,7 @@
 #include <loc/as.h>
 #include <loc/compat.h>
 #include <loc/country.h>
+#include <loc/country-list.h>
 #include <loc/database.h>
 #include <loc/format.h>
 #include <loc/network.h>
@@ -99,7 +100,7 @@ struct loc_database_enumerator {
 
        // Search string
        char* string;
-       char country_code[3];
+       struct loc_country_list* countries;
        uint32_t asn;
        enum loc_network_flags flags;
        int family;
@@ -1017,33 +1018,20 @@ LOC_EXPORT int loc_database_enumerator_set_string(struct loc_database_enumerator
        return 0;
 }
 
-LOC_EXPORT int loc_database_enumerator_set_country_code(struct loc_database_enumerator* enumerator, const char* country_code) {
-       // Set empty country code
-       if (!country_code || !*country_code) {
-               *enumerator->country_code = '\0';
-               return 0;
-       }
+LOC_EXPORT struct loc_country_list* loc_database_enumerator_get_countries(
+               struct loc_database_enumerator* enumerator) {
+       if (!enumerator->countries)
+               return NULL;
 
-       // Treat A1, A2, A3 as special country codes,
-       // but perform search for flags instead
-       if (strcmp(country_code, "A1") == 0) {
-               return loc_database_enumerator_set_flag(enumerator,
-                       LOC_NETWORK_FLAG_ANONYMOUS_PROXY);
-       } else if (strcmp(country_code, "A2") == 0) {
-               return loc_database_enumerator_set_flag(enumerator,
-                       LOC_NETWORK_FLAG_SATELLITE_PROVIDER);
-       } else if (strcmp(country_code, "A3") == 0) {
-               return loc_database_enumerator_set_flag(enumerator,
-                       LOC_NETWORK_FLAG_ANYCAST);
-       }
+       return loc_country_list_ref(enumerator->countries);
+}
 
-       // Country codes must be two characters
-       if (!loc_country_code_is_valid(country_code))
-               return -EINVAL;
+LOC_EXPORT int loc_database_enumerator_set_countries(
+               struct loc_database_enumerator* enumerator, struct loc_country_list* countries) {
+       if (enumerator->countries)
+               loc_country_list_unref(enumerator->countries);
 
-       for (unsigned int i = 0; i < 3; i++) {
-               enumerator->country_code[i] = country_code[i];
-       }
+       enumerator->countries = loc_country_list_ref(countries);
 
        return 0;
 }
@@ -1129,6 +1117,12 @@ static int loc_database_enumerator_stack_push_node(
        return 0;
 }
 
+static int loc_network_match_countries(struct loc_network* network, struct loc_country_list* countries) {
+       const char* country_code = loc_network_get_country_code(network);
+
+       return loc_country_list_contains_code(countries, country_code);
+}
+
 static int loc_database_enumerator_filter_network(
                struct loc_database_enumerator* enumerator, struct loc_network* network) {
        // Skip if the family does not match
@@ -1136,8 +1130,7 @@ static int loc_database_enumerator_filter_network(
                return 1;
 
        // Skip if the country code does not match
-       if (*enumerator->country_code &&
-                       !loc_network_match_country_code(network, enumerator->country_code))
+       if (enumerator->countries && !loc_network_match_countries(network, enumerator->countries))
                return 1;
 
        // Skip if the ASN does not match