]> git.ipfire.org Git - people/ms/libloc.git/blame - src/python/database.c
python: Permit passing family to database enumerator
[people/ms/libloc.git] / src / python / database.c
CommitLineData
9cdf6c53
MT
1/*
2 libloc - A library to determine the location of someone on the Internet
3
4 Copyright (C) 2017 IPFire Development Team <info@ipfire.org>
5
6 This library is free software; you can redistribute it and/or
7 modify it under the terms of the GNU Lesser General Public
8 License as published by the Free Software Foundation; either
9 version 2.1 of the License, or (at your option) any later version.
10
11 This library is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 Lesser General Public License for more details.
15*/
16
#include <Python.h>

#include <errno.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <unistd.h>

#include <loc/libloc.h>
#include <loc/as.h>
#include <loc/as-list.h>
#include <loc/database.h>

#include "locationmodule.h"
#include "as.h"
#include "country.h"
#include "database.h"
#include "network.h"
9cdf6c53
MT
29
30static PyObject* Database_new(PyTypeObject* type, PyObject* args, PyObject* kwds) {
9cdf6c53 31 DatabaseObject* self = (DatabaseObject*)type->tp_alloc(type, 0);
9cdf6c53
MT
32
33 return (PyObject*)self;
34}
35
36static void Database_dealloc(DatabaseObject* self) {
37 if (self->db)
38 loc_database_unref(self->db);
39
6fd96715
MT
40 if (self->path)
41 free(self->path);
42
9cdf6c53
MT
43 Py_TYPE(self)->tp_free((PyObject* )self);
44}
45
46static int Database_init(DatabaseObject* self, PyObject* args, PyObject* kwargs) {
47 const char* path = NULL;
48
49 if (!PyArg_ParseTuple(args, "s", &path))
50 return -1;
51
6fd96715
MT
52 self->path = strdup(path);
53
9cdf6c53 54 // Open the file for reading
6fd96715 55 FILE* f = fopen(self->path, "r");
774eea26
MT
56 if (!f) {
57 PyErr_SetFromErrno(PyExc_IOError);
9cdf6c53 58 return -1;
774eea26 59 }
9cdf6c53
MT
60
61 // Load the database
38e07ee0 62 int r = loc_database_new(loc_ctx, &self->db, f);
9cdf6c53
MT
63 fclose(f);
64
65 // Return on any errors
66 if (r)
67 return -1;
68
69 return 0;
70}
71
6fd96715
MT
72static PyObject* Database_repr(DatabaseObject* self) {
73 return PyUnicode_FromFormat("<Database %s>", self->path);
74}
75
726f9984
MT
76static PyObject* Database_verify(DatabaseObject* self, PyObject* args) {
77 PyObject* public_key = NULL;
78 FILE* f = NULL;
79
80 // Parse arguments
81 if (!PyArg_ParseTuple(args, "O", &public_key))
82 return NULL;
83
84 // Convert into FILE*
85 int fd = PyObject_AsFileDescriptor(public_key);
86 if (fd < 0)
87 return NULL;
88
89 // Re-open file descriptor
90 f = fdopen(fd, "r");
91 if (!f) {
92 PyErr_SetFromErrno(PyExc_IOError);
93 return NULL;
94 }
95
96 int r = loc_database_verify(self->db, f);
b1720435
MT
97
98 if (r == 0)
99 Py_RETURN_TRUE;
100
101 Py_RETURN_FALSE;
102}
103
d99b0256
MT
104static PyObject* Database_get_description(DatabaseObject* self) {
105 const char* description = loc_database_get_description(self->db);
106
107 return PyUnicode_FromString(description);
108}
109
110static PyObject* Database_get_vendor(DatabaseObject* self) {
111 const char* vendor = loc_database_get_vendor(self->db);
112
113 return PyUnicode_FromString(vendor);
114}
115
4bf49d00
MT
116static PyObject* Database_get_license(DatabaseObject* self) {
117 const char* license = loc_database_get_license(self->db);
118
119 return PyUnicode_FromString(license);
120}
121
53524b2d
MT
122static PyObject* Database_get_created_at(DatabaseObject* self) {
123 time_t created_at = loc_database_created_at(self->db);
124
125 return PyLong_FromLong(created_at);
126}
127
86ca7ef7
MT
128static PyObject* Database_get_as(DatabaseObject* self, PyObject* args) {
129 struct loc_as* as = NULL;
130 uint32_t number = 0;
131
132 if (!PyArg_ParseTuple(args, "i", &number))
133 return NULL;
134
135 // Try to retrieve the AS
136 int r = loc_database_get_as(self->db, &as, number);
86ca7ef7 137
4a0a0f7e
MT
138 // We got an AS
139 if (r == 0) {
86ca7ef7
MT
140 PyObject* obj = new_as(&ASType, as);
141 loc_as_unref(as);
142
143 return obj;
86ca7ef7
MT
144
145 // Nothing found
4a0a0f7e
MT
146 } else if (r == 1) {
147 Py_RETURN_NONE;
148 }
149
150 // Unexpected error
151 return NULL;
86ca7ef7
MT
152}
153
7c922e9c
MT
154static PyObject* Database_get_country(DatabaseObject* self, PyObject* args) {
155 const char* country_code = NULL;
156
157 if (!PyArg_ParseTuple(args, "s", &country_code))
158 return NULL;
159
160 struct loc_country* country;
161 int r = loc_database_get_country(self->db, &country, country_code);
162 if (r) {
163 Py_RETURN_NONE;
164 }
165
166 PyObject* obj = new_country(&CountryType, country);
167 loc_country_unref(country);
168
169 return obj;
170}
171
31edab76
MT
172static PyObject* Database_lookup(DatabaseObject* self, PyObject* args) {
173 struct loc_network* network = NULL;
174 const char* address = NULL;
175
176 if (!PyArg_ParseTuple(args, "s", &address))
177 return NULL;
178
179 // Try to retrieve a matching network
180 int r = loc_database_lookup_from_string(self->db, address, &network);
181
182 // We got a network
183 if (r == 0) {
184 PyObject* obj = new_network(&NetworkType, network);
185 loc_network_unref(network);
186
187 return obj;
188
189 // Nothing found
190 } else if (r == 1) {
191 Py_RETURN_NONE;
927e82f2
MT
192
193 // Invalid input
194 } else if (r == -EINVAL) {
195 PyErr_Format(PyExc_ValueError, "Invalid IP address: %s", address);
196 return NULL;
31edab76
MT
197 }
198
199 // Unexpected error
200 return NULL;
201}
202
afb426df
MT
203static PyObject* new_database_enumerator(PyTypeObject* type, struct loc_database_enumerator* enumerator) {
204 DatabaseEnumeratorObject* self = (DatabaseEnumeratorObject*)type->tp_alloc(type, 0);
205 if (self) {
206 self->enumerator = loc_database_enumerator_ref(enumerator);
207 }
208
209 return (PyObject*)self;
210}
211
d34b669c
MT
212static PyObject* Database_iterate_all(DatabaseObject* self,
213 enum loc_database_enumerator_mode what, int family, int flags) {
a68a46f5
MT
214 struct loc_database_enumerator* enumerator;
215
681ff05c 216 int r = loc_database_enumerator_new(&enumerator, self->db, what, flags);
a68a46f5
MT
217 if (r) {
218 PyErr_SetFromErrno(PyExc_SystemError);
219 return NULL;
220 }
221
d34b669c
MT
222 // Set family
223 if (family)
224 loc_database_enumerator_set_family(enumerator, family);
225
a68a46f5
MT
226 PyObject* obj = new_database_enumerator(&DatabaseEnumeratorType, enumerator);
227 loc_database_enumerator_unref(enumerator);
228
229 return obj;
230}
231
232static PyObject* Database_ases(DatabaseObject* self) {
d34b669c 233 return Database_iterate_all(self, LOC_DB_ENUMERATE_ASES, AF_UNSPEC, 0);
a68a46f5
MT
234}
235
afb426df
MT
236static PyObject* Database_search_as(DatabaseObject* self, PyObject* args) {
237 const char* string = NULL;
238
239 if (!PyArg_ParseTuple(args, "s", &string))
240 return NULL;
241
242 struct loc_database_enumerator* enumerator;
243
681ff05c 244 int r = loc_database_enumerator_new(&enumerator, self->db, LOC_DB_ENUMERATE_ASES, 0);
afb426df
MT
245 if (r) {
246 PyErr_SetFromErrno(PyExc_SystemError);
247 return NULL;
248 }
249
250 // Search string we are searching for
251 loc_database_enumerator_set_string(enumerator, string);
252
253 PyObject* obj = new_database_enumerator(&DatabaseEnumeratorType, enumerator);
254 loc_database_enumerator_unref(enumerator);
255
256 return obj;
257}
258
a68a46f5 259static PyObject* Database_networks(DatabaseObject* self) {
d34b669c 260 return Database_iterate_all(self, LOC_DB_ENUMERATE_NETWORKS, AF_UNSPEC, 0);
681ff05c
MT
261}
262
263static PyObject* Database_networks_flattened(DatabaseObject *self) {
d34b669c
MT
264 return Database_iterate_all(self, LOC_DB_ENUMERATE_NETWORKS, AF_UNSPEC,
265 LOC_DB_ENUMERATOR_FLAGS_FLATTEN);
a68a46f5
MT
266}
267
ccc7ab4e 268static PyObject* Database_search_networks(DatabaseObject* self, PyObject* args, PyObject* kwargs) {
84a2f0c2 269 char* kwlist[] = { "country_codes", "asns", "flags", "family", "flatten", NULL };
e646a8f3 270 PyObject* country_codes = NULL;
84a2f0c2 271 PyObject* asn_list = NULL;
bbdb2e0a 272 int flags = 0;
44e5ef71 273 int family = 0;
c242f732 274 int flatten = 0;
ccc7ab4e 275
84a2f0c2
MT
276 if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O!O!iip", kwlist,
277 &PyList_Type, &country_codes, &PyList_Type, &asn_list, &flags, &family, &flatten))
ccc7ab4e
MT
278 return NULL;
279
280 struct loc_database_enumerator* enumerator;
c242f732
MT
281 int r = loc_database_enumerator_new(&enumerator, self->db, LOC_DB_ENUMERATE_NETWORKS,
282 (flatten) ? LOC_DB_ENUMERATOR_FLAGS_FLATTEN : 0);
ccc7ab4e
MT
283 if (r) {
284 PyErr_SetFromErrno(PyExc_SystemError);
285 return NULL;
286 }
287
288 // Set country code we are searching for
e646a8f3
MT
289 if (country_codes) {
290 struct loc_country_list* countries;
291 r = loc_country_list_new(loc_ctx, &countries);
ccc7ab4e 292 if (r) {
e646a8f3 293 PyErr_SetString(PyExc_SystemError, "Could not create country list");
ccc7ab4e
MT
294 return NULL;
295 }
e646a8f3 296
fa9eb94e 297 for (int i = 0; i < PyList_Size(country_codes); i++) {
e646a8f3
MT
298 PyObject* item = PyList_GetItem(country_codes, i);
299
300 if (!PyUnicode_Check(item)) {
301 PyErr_SetString(PyExc_TypeError, "Country codes must be strings");
302 loc_country_list_unref(countries);
303 return NULL;
304 }
305
306 const char* country_code = PyUnicode_AsUTF8(item);
307
308 struct loc_country* country;
309 r = loc_country_new(loc_ctx, &country, country_code);
310 if (r) {
311 if (r == -EINVAL) {
312 PyErr_Format(PyExc_ValueError, "Invalid country code: %s", country_code);
313 } else {
314 PyErr_SetString(PyExc_SystemError, "Could not create country");
315 }
316
317 loc_country_list_unref(countries);
318 return NULL;
319 }
320
321 // Append it to the list
322 r = loc_country_list_append(countries, country);
323 if (r) {
324 PyErr_SetString(PyExc_SystemError, "Could not append country to the list");
325
326 loc_country_list_unref(countries);
327 loc_country_unref(country);
328 return NULL;
329 }
330
331 loc_country_unref(country);
332 }
333
c98ebf8a
MT
334 r = loc_database_enumerator_set_countries(enumerator, countries);
335 if (r) {
336 PyErr_SetFromErrno(PyExc_SystemError);
337
a1a00053 338 loc_country_list_unref(countries);
c98ebf8a
MT
339 return NULL;
340 }
e646a8f3 341
e646a8f3 342 loc_country_list_unref(countries);
ccc7ab4e
MT
343 }
344
345 // Set the ASN we are searching for
84a2f0c2
MT
346 if (asn_list) {
347 struct loc_as_list* asns;
348 r = loc_as_list_new(loc_ctx, &asns);
349 if (r) {
350 PyErr_SetString(PyExc_SystemError, "Could not create AS list");
351 return NULL;
352 }
353
fa9eb94e 354 for (int i = 0; i < PyList_Size(asn_list); i++) {
84a2f0c2
MT
355 PyObject* item = PyList_GetItem(asn_list, i);
356
357 if (!PyLong_Check(item)) {
358 PyErr_SetString(PyExc_TypeError, "ASNs must be numbers");
359
360 loc_as_list_unref(asns);
361 return NULL;
362 }
363
364 unsigned long number = PyLong_AsLong(item);
ccc7ab4e 365
84a2f0c2
MT
366 struct loc_as* as;
367 r = loc_as_new(loc_ctx, &as, number);
368 if (r) {
369 PyErr_SetString(PyExc_SystemError, "Could not create AS");
370
371 loc_as_list_unref(asns);
372 loc_as_unref(as);
373 return NULL;
374 }
375
376 r = loc_as_list_append(asns, as);
377 if (r) {
378 PyErr_SetString(PyExc_SystemError, "Could not append AS to the list");
379
380 loc_as_list_unref(asns);
381 loc_as_unref(as);
382 return NULL;
383 }
384
385 loc_as_unref(as);
386 }
387
388 r = loc_database_enumerator_set_asns(enumerator, asns);
ccc7ab4e
MT
389 if (r) {
390 PyErr_SetFromErrno(PyExc_SystemError);
84a2f0c2
MT
391
392 loc_as_list_unref(asns);
ccc7ab4e
MT
393 return NULL;
394 }
84a2f0c2
MT
395
396 loc_as_list_unref(asns);
ccc7ab4e
MT
397 }
398
bbdb2e0a
MT
399 // Set the flags we are searching for
400 if (flags) {
401 r = loc_database_enumerator_set_flag(enumerator, flags);
402
403 if (r) {
404 PyErr_SetFromErrno(PyExc_SystemError);
405 return NULL;
406 }
407 }
408
44e5ef71
MT
409 // Set the family we are searching for
410 if (family) {
411 r = loc_database_enumerator_set_family(enumerator, family);
412
413 if (r) {
414 PyErr_SetFromErrno(PyExc_SystemError);
415 return NULL;
416 }
417 }
418
ccc7ab4e
MT
419 PyObject* obj = new_database_enumerator(&DatabaseEnumeratorType, enumerator);
420 loc_database_enumerator_unref(enumerator);
421
422 return obj;
423}
424
fa9a3663 425static PyObject* Database_countries(DatabaseObject* self) {
d34b669c 426 return Database_iterate_all(self, LOC_DB_ENUMERATE_COUNTRIES, AF_UNSPEC, 0);
fa9a3663
MT
427}
428
9cdf6c53 429static struct PyMethodDef Database_methods[] = {
86ca7ef7
MT
430 {
431 "get_as",
432 (PyCFunction)Database_get_as,
433 METH_VARARGS,
434 NULL,
435 },
7c922e9c
MT
436 {
437 "get_country",
438 (PyCFunction)Database_get_country,
439 METH_VARARGS,
440 NULL,
441 },
31edab76
MT
442 {
443 "lookup",
444 (PyCFunction)Database_lookup,
445 METH_VARARGS,
446 NULL,
447 },
afb426df
MT
448 {
449 "search_as",
450 (PyCFunction)Database_search_as,
451 METH_VARARGS,
452 NULL,
453 },
ccc7ab4e
MT
454 {
455 "search_networks",
456 (PyCFunction)Database_search_networks,
457 METH_VARARGS|METH_KEYWORDS,
458 NULL,
459 },
b1720435
MT
460 {
461 "verify",
462 (PyCFunction)Database_verify,
726f9984 463 METH_VARARGS,
b1720435
MT
464 NULL,
465 },
9cdf6c53
MT
466 { NULL },
467};
468
d99b0256 469static struct PyGetSetDef Database_getsetters[] = {
a68a46f5
MT
470 {
471 "ases",
472 (getter)Database_ases,
473 NULL,
474 NULL,
475 NULL,
476 },
fa9a3663
MT
477 {
478 "countries",
479 (getter)Database_countries,
480 NULL,
481 NULL,
482 NULL,
483 },
53524b2d
MT
484 {
485 "created_at",
486 (getter)Database_get_created_at,
487 NULL,
488 NULL,
489 NULL,
490 },
d99b0256
MT
491 {
492 "description",
493 (getter)Database_get_description,
494 NULL,
495 NULL,
496 NULL,
497 },
4bf49d00
MT
498 {
499 "license",
500 (getter)Database_get_license,
501 NULL,
502 NULL,
503 NULL,
504 },
a68a46f5
MT
505 {
506 "networks",
507 (getter)Database_networks,
508 NULL,
509 NULL,
510 NULL,
511 },
681ff05c
MT
512 {
513 "networks_flattened",
514 (getter)Database_networks_flattened,
515 NULL,
516 NULL,
517 NULL,
518 },
d99b0256
MT
519 {
520 "vendor",
521 (getter)Database_get_vendor,
522 NULL,
523 NULL,
53524b2d 524 NULL,
d99b0256
MT
525 },
526 { NULL },
527};
528
9cdf6c53
MT
529PyTypeObject DatabaseType = {
530 PyVarObject_HEAD_INIT(NULL, 0)
d42e1dcd
MT
531 .tp_name = "location.Database",
532 .tp_basicsize = sizeof(DatabaseObject),
533 .tp_flags = Py_TPFLAGS_DEFAULT|Py_TPFLAGS_BASETYPE,
534 .tp_new = Database_new,
535 .tp_dealloc = (destructor)Database_dealloc,
536 .tp_init = (initproc)Database_init,
537 .tp_doc = "Database object",
538 .tp_methods = Database_methods,
539 .tp_getset = Database_getsetters,
540 .tp_repr = (reprfunc)Database_repr,
9cdf6c53 541};
afb426df
MT
542
543static PyObject* DatabaseEnumerator_new(PyTypeObject* type, PyObject* args, PyObject* kwds) {
544 DatabaseEnumeratorObject* self = (DatabaseEnumeratorObject*)type->tp_alloc(type, 0);
545
546 return (PyObject*)self;
547}
548
549static void DatabaseEnumerator_dealloc(DatabaseEnumeratorObject* self) {
550 loc_database_enumerator_unref(self->enumerator);
551
552 Py_TYPE(self)->tp_free((PyObject* )self);
553}
554
555static PyObject* DatabaseEnumerator_next(DatabaseEnumeratorObject* self) {
15f79e2d
MT
556 struct loc_network* network = NULL;
557
ccc7ab4e 558 // Enumerate all networks
15f79e2d
MT
559 int r = loc_database_enumerator_next_network(self->enumerator, &network);
560 if (r) {
20bb6d0c 561 PyErr_SetFromErrno(PyExc_ValueError);
15f79e2d
MT
562 return NULL;
563 }
564
565 // A network was found
ccc7ab4e
MT
566 if (network) {
567 PyObject* obj = new_network(&NetworkType, network);
568 loc_network_unref(network);
569
570 return obj;
571 }
572
573 // Enumerate all ASes
15f79e2d
MT
574 struct loc_as* as = NULL;
575
576 r = loc_database_enumerator_next_as(self->enumerator, &as);
577 if (r) {
20bb6d0c 578 PyErr_SetFromErrno(PyExc_ValueError);
15f79e2d
MT
579 return NULL;
580 }
581
afb426df
MT
582 if (as) {
583 PyObject* obj = new_as(&ASType, as);
584 loc_as_unref(as);
585
586 return obj;
587 }
588
fa9a3663
MT
589 // Enumerate all countries
590 struct loc_country* country = NULL;
591
592 r = loc_database_enumerator_next_country(self->enumerator, &country);
593 if (r) {
594 PyErr_SetFromErrno(PyExc_ValueError);
595 return NULL;
596 }
597
598 if (country) {
599 PyObject* obj = new_country(&CountryType, country);
600 loc_country_unref(country);
601
602 return obj;
603 }
604
afb426df
MT
605 // Nothing found, that means the end
606 PyErr_SetNone(PyExc_StopIteration);
607 return NULL;
608}
609
610PyTypeObject DatabaseEnumeratorType = {
611 PyVarObject_HEAD_INIT(NULL, 0)
d42e1dcd
MT
612 .tp_name = "location.DatabaseEnumerator",
613 .tp_basicsize = sizeof(DatabaseEnumeratorObject),
614 .tp_flags = Py_TPFLAGS_DEFAULT,
615 .tp_alloc = PyType_GenericAlloc,
616 .tp_new = DatabaseEnumerator_new,
617 .tp_dealloc = (destructor)DatabaseEnumerator_dealloc,
618 .tp_iter = PyObject_SelfIter,
619 .tp_iternext = (iternextfunc)DatabaseEnumerator_next,
afb426df 620};