]> git.ipfire.org Git - location/libloc.git/blob - src/python/database.c
203bb3964eaed4a089a81fbf21647823f6e6d869
[location/libloc.git] / src / python / database.c
1 /*
2 libloc - A library to determine the location of someone on the Internet
3
4 Copyright (C) 2017 IPFire Development Team <info@ipfire.org>
5
6 This library is free software; you can redistribute it and/or
7 modify it under the terms of the GNU Lesser General Public
8 License as published by the Free Software Foundation; either
9 version 2.1 of the License, or (at your option) any later version.
10
11 This library is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 Lesser General Public License for more details.
15 */
16
#include <Python.h>

#include <errno.h>
#include <stdint.h>
#include <unistd.h>

#include <libloc/libloc.h>
#include <libloc/as.h>
#include <libloc/as-list.h>
#include <libloc/database.h>

#include "locationmodule.h"
#include "as.h"
#include "country.h"
#include "database.h"
#include "network.h"
29
30 static PyObject* Database_new(PyTypeObject* type, PyObject* args, PyObject* kwds) {
31 DatabaseObject* self = (DatabaseObject*)type->tp_alloc(type, 0);
32
33 return (PyObject*)self;
34 }
35
36 static void Database_dealloc(DatabaseObject* self) {
37 if (self->db)
38 loc_database_unref(self->db);
39
40 if (self->path)
41 free(self->path);
42
43 Py_TYPE(self)->tp_free((PyObject* )self);
44 }
45
46 static int Database_init(DatabaseObject* self, PyObject* args, PyObject* kwargs) {
47 const char* path = NULL;
48 FILE* f = NULL;
49
50 // Parse arguments
51 if (!PyArg_ParseTuple(args, "s", &path))
52 return -1;
53
54 // Copy path
55 self->path = strdup(path);
56 if (!self->path)
57 goto ERROR;
58
59 // Open the file for reading
60 f = fopen(self->path, "r");
61 if (!f)
62 goto ERROR;
63
64 // Load the database
65 int r = loc_database_new(loc_ctx, &self->db, f);
66 if (r)
67 goto ERROR;
68
69 fclose(f);
70 return 0;
71
72 ERROR:
73 if (f)
74 fclose(f);
75
76 PyErr_SetFromErrno(PyExc_OSError);
77 return -1;
78 }
79
80 static PyObject* Database_repr(DatabaseObject* self) {
81 return PyUnicode_FromFormat("<Database %s>", self->path);
82 }
83
84 static PyObject* Database_verify(DatabaseObject* self, PyObject* args) {
85 PyObject* public_key = NULL;
86 FILE* f = NULL;
87
88 // Parse arguments
89 if (!PyArg_ParseTuple(args, "O", &public_key))
90 return NULL;
91
92 // Convert into FILE*
93 int fd = PyObject_AsFileDescriptor(public_key);
94 if (fd < 0)
95 return NULL;
96
97 // Re-open file descriptor
98 f = fdopen(fd, "r");
99 if (!f) {
100 PyErr_SetFromErrno(PyExc_IOError);
101 return NULL;
102 }
103
104 int r = loc_database_verify(self->db, f);
105
106 if (r == 0)
107 Py_RETURN_TRUE;
108
109 Py_RETURN_FALSE;
110 }
111
112 static PyObject* Database_get_description(DatabaseObject* self) {
113 const char* description = loc_database_get_description(self->db);
114 if (!description)
115 Py_RETURN_NONE;
116
117 return PyUnicode_FromString(description);
118 }
119
120 static PyObject* Database_get_vendor(DatabaseObject* self) {
121 const char* vendor = loc_database_get_vendor(self->db);
122 if (!vendor)
123 Py_RETURN_NONE;
124
125 return PyUnicode_FromString(vendor);
126 }
127
128 static PyObject* Database_get_license(DatabaseObject* self) {
129 const char* license = loc_database_get_license(self->db);
130 if (!license)
131 Py_RETURN_NONE;
132
133 return PyUnicode_FromString(license);
134 }
135
136 static PyObject* Database_get_created_at(DatabaseObject* self) {
137 time_t created_at = loc_database_created_at(self->db);
138
139 return PyLong_FromLong(created_at);
140 }
141
142 static PyObject* Database_get_as(DatabaseObject* self, PyObject* args) {
143 struct loc_as* as = NULL;
144 uint32_t number = 0;
145
146 if (!PyArg_ParseTuple(args, "i", &number))
147 return NULL;
148
149 // Try to retrieve the AS
150 int r = loc_database_get_as(self->db, &as, number);
151
152 // We got an AS
153 if (r == 0) {
154 PyObject* obj = new_as(&ASType, as);
155 loc_as_unref(as);
156
157 return obj;
158
159 // Nothing found
160 } else if (r == 1) {
161 Py_RETURN_NONE;
162 }
163
164 // Unexpected error
165 return NULL;
166 }
167
168 static PyObject* Database_get_country(DatabaseObject* self, PyObject* args) {
169 struct loc_country* country = NULL;
170 const char* country_code = NULL;
171
172 if (!PyArg_ParseTuple(args, "s", &country_code))
173 return NULL;
174
175 // Fetch the country
176 int r = loc_database_get_country(self->db, &country, country_code);
177 if (r) {
178 switch (errno) {
179 case EINVAL:
180 PyErr_SetString(PyExc_ValueError, "Invalid country code");
181 break;
182
183 default:
184 PyErr_SetFromErrno(PyExc_OSError);
185 break;
186 }
187
188 return NULL;
189 }
190
191 // No result
192 if (!country)
193 Py_RETURN_NONE;
194
195 PyObject* obj = new_country(&CountryType, country);
196 loc_country_unref(country);
197
198 return obj;
199 }
200
201 static PyObject* Database_lookup(DatabaseObject* self, PyObject* args) {
202 struct loc_network* network = NULL;
203 const char* address = NULL;
204 int r;
205
206 if (!PyArg_ParseTuple(args, "s", &address))
207 return NULL;
208
209 // Try to retrieve a matching network
210 r = loc_database_lookup_from_string(self->db, address, &network);
211 if (r) {
212 // Handle any errors
213 switch (errno) {
214 case EINVAL:
215 PyErr_Format(PyExc_ValueError, "Invalid IP address: %s", address);
216 break;
217
218 default:
219 PyErr_SetFromErrno(PyExc_OSError);
220 break;
221 }
222
223 return NULL;
224 }
225
226 // Nothing found
227 if (!network)
228 Py_RETURN_NONE;
229
230 // We got a network
231 PyObject* obj = new_network(&NetworkType, network);
232 loc_network_unref(network);
233
234 return obj;
235 }
236
237 static PyObject* new_database_enumerator(PyTypeObject* type, struct loc_database_enumerator* enumerator) {
238 DatabaseEnumeratorObject* self = (DatabaseEnumeratorObject*)type->tp_alloc(type, 0);
239 if (self) {
240 self->enumerator = loc_database_enumerator_ref(enumerator);
241 }
242
243 return (PyObject*)self;
244 }
245
246 static PyObject* Database_iterate_all(DatabaseObject* self,
247 enum loc_database_enumerator_mode what, int family, int flags) {
248 struct loc_database_enumerator* enumerator;
249
250 int r = loc_database_enumerator_new(&enumerator, self->db, what, flags);
251 if (r) {
252 PyErr_SetFromErrno(PyExc_SystemError);
253 return NULL;
254 }
255
256 // Set family
257 if (family)
258 loc_database_enumerator_set_family(enumerator, family);
259
260 PyObject* obj = new_database_enumerator(&DatabaseEnumeratorType, enumerator);
261 loc_database_enumerator_unref(enumerator);
262
263 return obj;
264 }
265
266 static PyObject* Database_ases(DatabaseObject* self) {
267 return Database_iterate_all(self, LOC_DB_ENUMERATE_ASES, AF_UNSPEC, 0);
268 }
269
270 static PyObject* Database_search_as(DatabaseObject* self, PyObject* args) {
271 const char* string = NULL;
272
273 if (!PyArg_ParseTuple(args, "s", &string))
274 return NULL;
275
276 struct loc_database_enumerator* enumerator;
277
278 int r = loc_database_enumerator_new(&enumerator, self->db, LOC_DB_ENUMERATE_ASES, 0);
279 if (r) {
280 PyErr_SetFromErrno(PyExc_SystemError);
281 return NULL;
282 }
283
284 // Search string we are searching for
285 loc_database_enumerator_set_string(enumerator, string);
286
287 PyObject* obj = new_database_enumerator(&DatabaseEnumeratorType, enumerator);
288 loc_database_enumerator_unref(enumerator);
289
290 return obj;
291 }
292
293 static PyObject* Database_networks(DatabaseObject* self) {
294 return Database_iterate_all(self, LOC_DB_ENUMERATE_NETWORKS, AF_UNSPEC, 0);
295 }
296
297 static PyObject* Database_networks_flattened(DatabaseObject *self) {
298 return Database_iterate_all(self, LOC_DB_ENUMERATE_NETWORKS, AF_UNSPEC,
299 LOC_DB_ENUMERATOR_FLAGS_FLATTEN);
300 }
301
302 static PyObject* Database_search_networks(DatabaseObject* self, PyObject* args, PyObject* kwargs) {
303 const char* kwlist[] = { "country_codes", "asns", "flags", "family", "flatten", NULL };
304 PyObject* country_codes = NULL;
305 PyObject* asn_list = NULL;
306 int flags = 0;
307 int family = 0;
308 int flatten = 0;
309
310 if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O!O!iip", (char**)kwlist,
311 &PyList_Type, &country_codes, &PyList_Type, &asn_list, &flags, &family, &flatten))
312 return NULL;
313
314 struct loc_database_enumerator* enumerator;
315 int r = loc_database_enumerator_new(&enumerator, self->db, LOC_DB_ENUMERATE_NETWORKS,
316 (flatten) ? LOC_DB_ENUMERATOR_FLAGS_FLATTEN : 0);
317 if (r) {
318 PyErr_SetFromErrno(PyExc_SystemError);
319 return NULL;
320 }
321
322 // Set country code we are searching for
323 if (country_codes) {
324 struct loc_country_list* countries;
325 r = loc_country_list_new(loc_ctx, &countries);
326 if (r) {
327 PyErr_SetString(PyExc_SystemError, "Could not create country list");
328 return NULL;
329 }
330
331 for (int i = 0; i < PyList_Size(country_codes); i++) {
332 PyObject* item = PyList_GetItem(country_codes, i);
333
334 if (!PyUnicode_Check(item)) {
335 PyErr_SetString(PyExc_TypeError, "Country codes must be strings");
336 loc_country_list_unref(countries);
337 return NULL;
338 }
339
340 const char* country_code = PyUnicode_AsUTF8(item);
341
342 struct loc_country* country;
343 r = loc_country_new(loc_ctx, &country, country_code);
344 if (r) {
345 if (r == -EINVAL) {
346 PyErr_Format(PyExc_ValueError, "Invalid country code: %s", country_code);
347 } else {
348 PyErr_SetString(PyExc_SystemError, "Could not create country");
349 }
350
351 loc_country_list_unref(countries);
352 return NULL;
353 }
354
355 // Append it to the list
356 r = loc_country_list_append(countries, country);
357 if (r) {
358 PyErr_SetString(PyExc_SystemError, "Could not append country to the list");
359
360 loc_country_list_unref(countries);
361 loc_country_unref(country);
362 return NULL;
363 }
364
365 loc_country_unref(country);
366 }
367
368 r = loc_database_enumerator_set_countries(enumerator, countries);
369 if (r) {
370 PyErr_SetFromErrno(PyExc_SystemError);
371
372 loc_country_list_unref(countries);
373 return NULL;
374 }
375
376 loc_country_list_unref(countries);
377 }
378
379 // Set the ASN we are searching for
380 if (asn_list) {
381 struct loc_as_list* asns;
382 r = loc_as_list_new(loc_ctx, &asns);
383 if (r) {
384 PyErr_SetFromErrno(PyExc_OSError);
385 return NULL;
386 }
387
388 for (int i = 0; i < PyList_Size(asn_list); i++) {
389 PyObject* item = PyList_GetItem(asn_list, i);
390
391 if (!PyLong_Check(item)) {
392 PyErr_SetString(PyExc_TypeError, "ASNs must be numbers");
393
394 loc_as_list_unref(asns);
395 return NULL;
396 }
397
398 unsigned long number = PyLong_AsLong(item);
399
400 struct loc_as* as;
401 r = loc_as_new(loc_ctx, &as, number);
402 if (r) {
403 PyErr_SetFromErrno(PyExc_OSError);
404
405 loc_as_list_unref(asns);
406 loc_as_unref(as);
407 return NULL;
408 }
409
410 r = loc_as_list_append(asns, as);
411 if (r) {
412 PyErr_SetFromErrno(PyExc_OSError);
413
414 loc_as_list_unref(asns);
415 loc_as_unref(as);
416 return NULL;
417 }
418
419 loc_as_unref(as);
420 }
421
422 r = loc_database_enumerator_set_asns(enumerator, asns);
423 if (r) {
424 PyErr_SetFromErrno(PyExc_OSError);
425
426 loc_as_list_unref(asns);
427 return NULL;
428 }
429
430 loc_as_list_unref(asns);
431 }
432
433 // Set the flags we are searching for
434 if (flags) {
435 r = loc_database_enumerator_set_flag(enumerator, flags);
436
437 if (r) {
438 PyErr_SetFromErrno(PyExc_OSError);
439 return NULL;
440 }
441 }
442
443 // Set the family we are searching for
444 if (family) {
445 r = loc_database_enumerator_set_family(enumerator, family);
446
447 if (r) {
448 PyErr_SetFromErrno(PyExc_OSError);
449 return NULL;
450 }
451 }
452
453 PyObject* obj = new_database_enumerator(&DatabaseEnumeratorType, enumerator);
454 loc_database_enumerator_unref(enumerator);
455
456 return obj;
457 }
458
459 static PyObject* Database_countries(DatabaseObject* self) {
460 return Database_iterate_all(self, LOC_DB_ENUMERATE_COUNTRIES, AF_UNSPEC, 0);
461 }
462
463 static PyObject* Database_list_bogons(DatabaseObject* self, PyObject* args, PyObject* kwargs) {
464 const char* kwlist[] = { "family", NULL };
465 int family = AF_UNSPEC;
466
467 // Parse arguments
468 if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|i", (char**)kwlist, &family))
469 return NULL;
470
471 return Database_iterate_all(self, LOC_DB_ENUMERATE_BOGONS, family, 0);
472 }
473
474 static struct PyMethodDef Database_methods[] = {
475 {
476 "get_as",
477 (PyCFunction)Database_get_as,
478 METH_VARARGS,
479 NULL,
480 },
481 {
482 "get_country",
483 (PyCFunction)Database_get_country,
484 METH_VARARGS,
485 NULL,
486 },
487 {
488 "list_bogons",
489 (PyCFunction)Database_list_bogons,
490 METH_VARARGS|METH_KEYWORDS,
491 NULL,
492 },
493 {
494 "lookup",
495 (PyCFunction)Database_lookup,
496 METH_VARARGS,
497 NULL,
498 },
499 {
500 "search_as",
501 (PyCFunction)Database_search_as,
502 METH_VARARGS,
503 NULL,
504 },
505 {
506 "search_networks",
507 (PyCFunction)Database_search_networks,
508 METH_VARARGS|METH_KEYWORDS,
509 NULL,
510 },
511 {
512 "verify",
513 (PyCFunction)Database_verify,
514 METH_VARARGS,
515 NULL,
516 },
517 { NULL },
518 };
519
520 static struct PyGetSetDef Database_getsetters[] = {
521 {
522 "ases",
523 (getter)Database_ases,
524 NULL,
525 NULL,
526 NULL,
527 },
528 {
529 "countries",
530 (getter)Database_countries,
531 NULL,
532 NULL,
533 NULL,
534 },
535 {
536 "created_at",
537 (getter)Database_get_created_at,
538 NULL,
539 NULL,
540 NULL,
541 },
542 {
543 "description",
544 (getter)Database_get_description,
545 NULL,
546 NULL,
547 NULL,
548 },
549 {
550 "license",
551 (getter)Database_get_license,
552 NULL,
553 NULL,
554 NULL,
555 },
556 {
557 "networks",
558 (getter)Database_networks,
559 NULL,
560 NULL,
561 NULL,
562 },
563 {
564 "networks_flattened",
565 (getter)Database_networks_flattened,
566 NULL,
567 NULL,
568 NULL,
569 },
570 {
571 "vendor",
572 (getter)Database_get_vendor,
573 NULL,
574 NULL,
575 NULL,
576 },
577 { NULL },
578 };
579
580 PyTypeObject DatabaseType = {
581 PyVarObject_HEAD_INIT(NULL, 0)
582 .tp_name = "location.Database",
583 .tp_basicsize = sizeof(DatabaseObject),
584 .tp_flags = Py_TPFLAGS_DEFAULT|Py_TPFLAGS_BASETYPE,
585 .tp_new = Database_new,
586 .tp_dealloc = (destructor)Database_dealloc,
587 .tp_init = (initproc)Database_init,
588 .tp_doc = "Database object",
589 .tp_methods = Database_methods,
590 .tp_getset = Database_getsetters,
591 .tp_repr = (reprfunc)Database_repr,
592 };
593
594 static PyObject* DatabaseEnumerator_new(PyTypeObject* type, PyObject* args, PyObject* kwds) {
595 DatabaseEnumeratorObject* self = (DatabaseEnumeratorObject*)type->tp_alloc(type, 0);
596
597 return (PyObject*)self;
598 }
599
600 static void DatabaseEnumerator_dealloc(DatabaseEnumeratorObject* self) {
601 loc_database_enumerator_unref(self->enumerator);
602
603 Py_TYPE(self)->tp_free((PyObject* )self);
604 }
605
606 static PyObject* DatabaseEnumerator_next(DatabaseEnumeratorObject* self) {
607 struct loc_network* network = NULL;
608
609 // Enumerate all networks
610 int r = loc_database_enumerator_next_network(self->enumerator, &network);
611 if (r) {
612 PyErr_SetFromErrno(PyExc_ValueError);
613 return NULL;
614 }
615
616 // A network was found
617 if (network) {
618 PyObject* obj = new_network(&NetworkType, network);
619 loc_network_unref(network);
620
621 return obj;
622 }
623
624 // Enumerate all ASes
625 struct loc_as* as = NULL;
626
627 r = loc_database_enumerator_next_as(self->enumerator, &as);
628 if (r) {
629 PyErr_SetFromErrno(PyExc_ValueError);
630 return NULL;
631 }
632
633 if (as) {
634 PyObject* obj = new_as(&ASType, as);
635 loc_as_unref(as);
636
637 return obj;
638 }
639
640 // Enumerate all countries
641 struct loc_country* country = NULL;
642
643 r = loc_database_enumerator_next_country(self->enumerator, &country);
644 if (r) {
645 PyErr_SetFromErrno(PyExc_ValueError);
646 return NULL;
647 }
648
649 if (country) {
650 PyObject* obj = new_country(&CountryType, country);
651 loc_country_unref(country);
652
653 return obj;
654 }
655
656 // Nothing found, that means the end
657 PyErr_SetNone(PyExc_StopIteration);
658 return NULL;
659 }
660
661 PyTypeObject DatabaseEnumeratorType = {
662 PyVarObject_HEAD_INIT(NULL, 0)
663 .tp_name = "location.DatabaseEnumerator",
664 .tp_basicsize = sizeof(DatabaseEnumeratorObject),
665 .tp_flags = Py_TPFLAGS_DEFAULT,
666 .tp_alloc = PyType_GenericAlloc,
667 .tp_new = DatabaseEnumerator_new,
668 .tp_dealloc = (destructor)DatabaseEnumerator_dealloc,
669 .tp_iter = PyObject_SelfIter,
670 .tp_iternext = (iternextfunc)DatabaseEnumerator_next,
671 };