git.ipfire.org Git - location/libloc.git/blob - src/python/database.c
python: Fix errors for Database.lookup()
[location/libloc.git] / src / python / database.c
1 /*
2 libloc - A library to determine the location of someone on the Internet
3
4 Copyright (C) 2017 IPFire Development Team <info@ipfire.org>
5
6 This library is free software; you can redistribute it and/or
7 modify it under the terms of the GNU Lesser General Public
8 License as published by the Free Software Foundation; either
9 version 2.1 of the License, or (at your option) any later version.
10
11 This library is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 Lesser General Public License for more details.
15 */
16
#include <Python.h>

#include <errno.h>
#include <unistd.h>

#include <libloc/libloc.h>
#include <libloc/as.h>
#include <libloc/as-list.h>
#include <libloc/database.h>

#include "locationmodule.h"
#include "as.h"
#include "country.h"
#include "database.h"
#include "network.h"
29
30 static PyObject* Database_new(PyTypeObject* type, PyObject* args, PyObject* kwds) {
31 DatabaseObject* self = (DatabaseObject*)type->tp_alloc(type, 0);
32
33 return (PyObject*)self;
34 }
35
36 static void Database_dealloc(DatabaseObject* self) {
37 if (self->db)
38 loc_database_unref(self->db);
39
40 if (self->path)
41 free(self->path);
42
43 Py_TYPE(self)->tp_free((PyObject* )self);
44 }
45
46 static int Database_init(DatabaseObject* self, PyObject* args, PyObject* kwargs) {
47 const char* path = NULL;
48 FILE* f = NULL;
49
50 // Parse arguments
51 if (!PyArg_ParseTuple(args, "s", &path))
52 return -1;
53
54 // Copy path
55 self->path = strdup(path);
56 if (!self->path)
57 goto ERROR;
58
59 // Open the file for reading
60 f = fopen(self->path, "r");
61 if (!f)
62 goto ERROR;
63
64 // Load the database
65 int r = loc_database_new(loc_ctx, &self->db, f);
66 if (r)
67 goto ERROR;
68
69 fclose(f);
70 return 0;
71
72 ERROR:
73 if (f)
74 fclose(f);
75
76 PyErr_SetFromErrno(PyExc_OSError);
77 return -1;
78 }
79
80 static PyObject* Database_repr(DatabaseObject* self) {
81 return PyUnicode_FromFormat("<Database %s>", self->path);
82 }
83
84 static PyObject* Database_verify(DatabaseObject* self, PyObject* args) {
85 PyObject* public_key = NULL;
86 FILE* f = NULL;
87
88 // Parse arguments
89 if (!PyArg_ParseTuple(args, "O", &public_key))
90 return NULL;
91
92 // Convert into FILE*
93 int fd = PyObject_AsFileDescriptor(public_key);
94 if (fd < 0)
95 return NULL;
96
97 // Re-open file descriptor
98 f = fdopen(fd, "r");
99 if (!f) {
100 PyErr_SetFromErrno(PyExc_IOError);
101 return NULL;
102 }
103
104 int r = loc_database_verify(self->db, f);
105
106 if (r == 0)
107 Py_RETURN_TRUE;
108
109 Py_RETURN_FALSE;
110 }
111
112 static PyObject* Database_get_description(DatabaseObject* self) {
113 const char* description = loc_database_get_description(self->db);
114 if (!description)
115 Py_RETURN_NONE;
116
117 return PyUnicode_FromString(description);
118 }
119
120 static PyObject* Database_get_vendor(DatabaseObject* self) {
121 const char* vendor = loc_database_get_vendor(self->db);
122 if (!vendor)
123 Py_RETURN_NONE;
124
125 return PyUnicode_FromString(vendor);
126 }
127
128 static PyObject* Database_get_license(DatabaseObject* self) {
129 const char* license = loc_database_get_license(self->db);
130 if (!license)
131 Py_RETURN_NONE;
132
133 return PyUnicode_FromString(license);
134 }
135
136 static PyObject* Database_get_created_at(DatabaseObject* self) {
137 time_t created_at = loc_database_created_at(self->db);
138
139 return PyLong_FromLong(created_at);
140 }
141
142 static PyObject* Database_get_as(DatabaseObject* self, PyObject* args) {
143 struct loc_as* as = NULL;
144 uint32_t number = 0;
145
146 if (!PyArg_ParseTuple(args, "i", &number))
147 return NULL;
148
149 // Try to retrieve the AS
150 int r = loc_database_get_as(self->db, &as, number);
151
152 // We got an AS
153 if (r == 0) {
154 PyObject* obj = new_as(&ASType, as);
155 loc_as_unref(as);
156
157 return obj;
158
159 // Nothing found
160 } else if (r == 1) {
161 Py_RETURN_NONE;
162 }
163
164 // Unexpected error
165 return NULL;
166 }
167
168 static PyObject* Database_get_country(DatabaseObject* self, PyObject* args) {
169 const char* country_code = NULL;
170
171 if (!PyArg_ParseTuple(args, "s", &country_code))
172 return NULL;
173
174 struct loc_country* country;
175 int r = loc_database_get_country(self->db, &country, country_code);
176 if (r) {
177 Py_RETURN_NONE;
178 }
179
180 PyObject* obj = new_country(&CountryType, country);
181 loc_country_unref(country);
182
183 return obj;
184 }
185
186 static PyObject* Database_lookup(DatabaseObject* self, PyObject* args) {
187 struct loc_network* network = NULL;
188 const char* address = NULL;
189
190 if (!PyArg_ParseTuple(args, "s", &address))
191 return NULL;
192
193 // Try to retrieve a matching network
194 int r = loc_database_lookup_from_string(self->db, address, &network);
195
196 // We got a network
197 if (r == 0) {
198 PyObject* obj = new_network(&NetworkType, network);
199 loc_network_unref(network);
200
201 return obj;
202 }
203
204 // Nothing found
205 if (!errno)
206 Py_RETURN_NONE;
207
208 // Handle any errors
209 switch (errno) {
210 case EINVAL:
211 PyErr_Format(PyExc_ValueError, "Invalid IP address: %s", address);
212
213 default:
214 PyErr_SetFromErrno(PyExc_OSError);
215 }
216
217 return NULL;
218 }
219
220 static PyObject* new_database_enumerator(PyTypeObject* type, struct loc_database_enumerator* enumerator) {
221 DatabaseEnumeratorObject* self = (DatabaseEnumeratorObject*)type->tp_alloc(type, 0);
222 if (self) {
223 self->enumerator = loc_database_enumerator_ref(enumerator);
224 }
225
226 return (PyObject*)self;
227 }
228
229 static PyObject* Database_iterate_all(DatabaseObject* self,
230 enum loc_database_enumerator_mode what, int family, int flags) {
231 struct loc_database_enumerator* enumerator;
232
233 int r = loc_database_enumerator_new(&enumerator, self->db, what, flags);
234 if (r) {
235 PyErr_SetFromErrno(PyExc_SystemError);
236 return NULL;
237 }
238
239 // Set family
240 if (family)
241 loc_database_enumerator_set_family(enumerator, family);
242
243 PyObject* obj = new_database_enumerator(&DatabaseEnumeratorType, enumerator);
244 loc_database_enumerator_unref(enumerator);
245
246 return obj;
247 }
248
249 static PyObject* Database_ases(DatabaseObject* self) {
250 return Database_iterate_all(self, LOC_DB_ENUMERATE_ASES, AF_UNSPEC, 0);
251 }
252
253 static PyObject* Database_search_as(DatabaseObject* self, PyObject* args) {
254 const char* string = NULL;
255
256 if (!PyArg_ParseTuple(args, "s", &string))
257 return NULL;
258
259 struct loc_database_enumerator* enumerator;
260
261 int r = loc_database_enumerator_new(&enumerator, self->db, LOC_DB_ENUMERATE_ASES, 0);
262 if (r) {
263 PyErr_SetFromErrno(PyExc_SystemError);
264 return NULL;
265 }
266
267 // Search string we are searching for
268 loc_database_enumerator_set_string(enumerator, string);
269
270 PyObject* obj = new_database_enumerator(&DatabaseEnumeratorType, enumerator);
271 loc_database_enumerator_unref(enumerator);
272
273 return obj;
274 }
275
276 static PyObject* Database_networks(DatabaseObject* self) {
277 return Database_iterate_all(self, LOC_DB_ENUMERATE_NETWORKS, AF_UNSPEC, 0);
278 }
279
280 static PyObject* Database_networks_flattened(DatabaseObject *self) {
281 return Database_iterate_all(self, LOC_DB_ENUMERATE_NETWORKS, AF_UNSPEC,
282 LOC_DB_ENUMERATOR_FLAGS_FLATTEN);
283 }
284
285 static PyObject* Database_search_networks(DatabaseObject* self, PyObject* args, PyObject* kwargs) {
286 char* kwlist[] = { "country_codes", "asns", "flags", "family", "flatten", NULL };
287 PyObject* country_codes = NULL;
288 PyObject* asn_list = NULL;
289 int flags = 0;
290 int family = 0;
291 int flatten = 0;
292
293 if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O!O!iip", kwlist,
294 &PyList_Type, &country_codes, &PyList_Type, &asn_list, &flags, &family, &flatten))
295 return NULL;
296
297 struct loc_database_enumerator* enumerator;
298 int r = loc_database_enumerator_new(&enumerator, self->db, LOC_DB_ENUMERATE_NETWORKS,
299 (flatten) ? LOC_DB_ENUMERATOR_FLAGS_FLATTEN : 0);
300 if (r) {
301 PyErr_SetFromErrno(PyExc_SystemError);
302 return NULL;
303 }
304
305 // Set country code we are searching for
306 if (country_codes) {
307 struct loc_country_list* countries;
308 r = loc_country_list_new(loc_ctx, &countries);
309 if (r) {
310 PyErr_SetString(PyExc_SystemError, "Could not create country list");
311 return NULL;
312 }
313
314 for (int i = 0; i < PyList_Size(country_codes); i++) {
315 PyObject* item = PyList_GetItem(country_codes, i);
316
317 if (!PyUnicode_Check(item)) {
318 PyErr_SetString(PyExc_TypeError, "Country codes must be strings");
319 loc_country_list_unref(countries);
320 return NULL;
321 }
322
323 const char* country_code = PyUnicode_AsUTF8(item);
324
325 struct loc_country* country;
326 r = loc_country_new(loc_ctx, &country, country_code);
327 if (r) {
328 if (r == -EINVAL) {
329 PyErr_Format(PyExc_ValueError, "Invalid country code: %s", country_code);
330 } else {
331 PyErr_SetString(PyExc_SystemError, "Could not create country");
332 }
333
334 loc_country_list_unref(countries);
335 return NULL;
336 }
337
338 // Append it to the list
339 r = loc_country_list_append(countries, country);
340 if (r) {
341 PyErr_SetString(PyExc_SystemError, "Could not append country to the list");
342
343 loc_country_list_unref(countries);
344 loc_country_unref(country);
345 return NULL;
346 }
347
348 loc_country_unref(country);
349 }
350
351 r = loc_database_enumerator_set_countries(enumerator, countries);
352 if (r) {
353 PyErr_SetFromErrno(PyExc_SystemError);
354
355 loc_country_list_unref(countries);
356 return NULL;
357 }
358
359 loc_country_list_unref(countries);
360 }
361
362 // Set the ASN we are searching for
363 if (asn_list) {
364 struct loc_as_list* asns;
365 r = loc_as_list_new(loc_ctx, &asns);
366 if (r) {
367 PyErr_SetFromErrno(PyExc_OSError);
368 return NULL;
369 }
370
371 for (int i = 0; i < PyList_Size(asn_list); i++) {
372 PyObject* item = PyList_GetItem(asn_list, i);
373
374 if (!PyLong_Check(item)) {
375 PyErr_SetString(PyExc_TypeError, "ASNs must be numbers");
376
377 loc_as_list_unref(asns);
378 return NULL;
379 }
380
381 unsigned long number = PyLong_AsLong(item);
382
383 struct loc_as* as;
384 r = loc_as_new(loc_ctx, &as, number);
385 if (r) {
386 PyErr_SetFromErrno(PyExc_OSError);
387
388 loc_as_list_unref(asns);
389 loc_as_unref(as);
390 return NULL;
391 }
392
393 r = loc_as_list_append(asns, as);
394 if (r) {
395 PyErr_SetFromErrno(PyExc_OSError);
396
397 loc_as_list_unref(asns);
398 loc_as_unref(as);
399 return NULL;
400 }
401
402 loc_as_unref(as);
403 }
404
405 r = loc_database_enumerator_set_asns(enumerator, asns);
406 if (r) {
407 PyErr_SetFromErrno(PyExc_OSError);
408
409 loc_as_list_unref(asns);
410 return NULL;
411 }
412
413 loc_as_list_unref(asns);
414 }
415
416 // Set the flags we are searching for
417 if (flags) {
418 r = loc_database_enumerator_set_flag(enumerator, flags);
419
420 if (r) {
421 PyErr_SetFromErrno(PyExc_OSError);
422 return NULL;
423 }
424 }
425
426 // Set the family we are searching for
427 if (family) {
428 r = loc_database_enumerator_set_family(enumerator, family);
429
430 if (r) {
431 PyErr_SetFromErrno(PyExc_OSError);
432 return NULL;
433 }
434 }
435
436 PyObject* obj = new_database_enumerator(&DatabaseEnumeratorType, enumerator);
437 loc_database_enumerator_unref(enumerator);
438
439 return obj;
440 }
441
442 static PyObject* Database_countries(DatabaseObject* self) {
443 return Database_iterate_all(self, LOC_DB_ENUMERATE_COUNTRIES, AF_UNSPEC, 0);
444 }
445
446 static PyObject* Database_list_bogons(DatabaseObject* self, PyObject* args, PyObject* kwargs) {
447 char* kwlist[] = { "family", NULL };
448 int family = AF_UNSPEC;
449
450 // Parse arguments
451 if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|i", kwlist, &family))
452 return NULL;
453
454 return Database_iterate_all(self, LOC_DB_ENUMERATE_BOGONS, family, 0);
455 }
456
457 static struct PyMethodDef Database_methods[] = {
458 {
459 "get_as",
460 (PyCFunction)Database_get_as,
461 METH_VARARGS,
462 NULL,
463 },
464 {
465 "get_country",
466 (PyCFunction)Database_get_country,
467 METH_VARARGS,
468 NULL,
469 },
470 {
471 "list_bogons",
472 (PyCFunction)Database_list_bogons,
473 METH_VARARGS|METH_KEYWORDS,
474 NULL,
475 },
476 {
477 "lookup",
478 (PyCFunction)Database_lookup,
479 METH_VARARGS,
480 NULL,
481 },
482 {
483 "search_as",
484 (PyCFunction)Database_search_as,
485 METH_VARARGS,
486 NULL,
487 },
488 {
489 "search_networks",
490 (PyCFunction)Database_search_networks,
491 METH_VARARGS|METH_KEYWORDS,
492 NULL,
493 },
494 {
495 "verify",
496 (PyCFunction)Database_verify,
497 METH_VARARGS,
498 NULL,
499 },
500 { NULL },
501 };
502
503 static struct PyGetSetDef Database_getsetters[] = {
504 {
505 "ases",
506 (getter)Database_ases,
507 NULL,
508 NULL,
509 NULL,
510 },
511 {
512 "countries",
513 (getter)Database_countries,
514 NULL,
515 NULL,
516 NULL,
517 },
518 {
519 "created_at",
520 (getter)Database_get_created_at,
521 NULL,
522 NULL,
523 NULL,
524 },
525 {
526 "description",
527 (getter)Database_get_description,
528 NULL,
529 NULL,
530 NULL,
531 },
532 {
533 "license",
534 (getter)Database_get_license,
535 NULL,
536 NULL,
537 NULL,
538 },
539 {
540 "networks",
541 (getter)Database_networks,
542 NULL,
543 NULL,
544 NULL,
545 },
546 {
547 "networks_flattened",
548 (getter)Database_networks_flattened,
549 NULL,
550 NULL,
551 NULL,
552 },
553 {
554 "vendor",
555 (getter)Database_get_vendor,
556 NULL,
557 NULL,
558 NULL,
559 },
560 { NULL },
561 };
562
563 PyTypeObject DatabaseType = {
564 PyVarObject_HEAD_INIT(NULL, 0)
565 .tp_name = "location.Database",
566 .tp_basicsize = sizeof(DatabaseObject),
567 .tp_flags = Py_TPFLAGS_DEFAULT|Py_TPFLAGS_BASETYPE,
568 .tp_new = Database_new,
569 .tp_dealloc = (destructor)Database_dealloc,
570 .tp_init = (initproc)Database_init,
571 .tp_doc = "Database object",
572 .tp_methods = Database_methods,
573 .tp_getset = Database_getsetters,
574 .tp_repr = (reprfunc)Database_repr,
575 };
576
577 static PyObject* DatabaseEnumerator_new(PyTypeObject* type, PyObject* args, PyObject* kwds) {
578 DatabaseEnumeratorObject* self = (DatabaseEnumeratorObject*)type->tp_alloc(type, 0);
579
580 return (PyObject*)self;
581 }
582
583 static void DatabaseEnumerator_dealloc(DatabaseEnumeratorObject* self) {
584 loc_database_enumerator_unref(self->enumerator);
585
586 Py_TYPE(self)->tp_free((PyObject* )self);
587 }
588
589 static PyObject* DatabaseEnumerator_next(DatabaseEnumeratorObject* self) {
590 struct loc_network* network = NULL;
591
592 // Enumerate all networks
593 int r = loc_database_enumerator_next_network(self->enumerator, &network);
594 if (r) {
595 PyErr_SetFromErrno(PyExc_ValueError);
596 return NULL;
597 }
598
599 // A network was found
600 if (network) {
601 PyObject* obj = new_network(&NetworkType, network);
602 loc_network_unref(network);
603
604 return obj;
605 }
606
607 // Enumerate all ASes
608 struct loc_as* as = NULL;
609
610 r = loc_database_enumerator_next_as(self->enumerator, &as);
611 if (r) {
612 PyErr_SetFromErrno(PyExc_ValueError);
613 return NULL;
614 }
615
616 if (as) {
617 PyObject* obj = new_as(&ASType, as);
618 loc_as_unref(as);
619
620 return obj;
621 }
622
623 // Enumerate all countries
624 struct loc_country* country = NULL;
625
626 r = loc_database_enumerator_next_country(self->enumerator, &country);
627 if (r) {
628 PyErr_SetFromErrno(PyExc_ValueError);
629 return NULL;
630 }
631
632 if (country) {
633 PyObject* obj = new_country(&CountryType, country);
634 loc_country_unref(country);
635
636 return obj;
637 }
638
639 // Nothing found, that means the end
640 PyErr_SetNone(PyExc_StopIteration);
641 return NULL;
642 }
643
644 PyTypeObject DatabaseEnumeratorType = {
645 PyVarObject_HEAD_INIT(NULL, 0)
646 .tp_name = "location.DatabaseEnumerator",
647 .tp_basicsize = sizeof(DatabaseEnumeratorObject),
648 .tp_flags = Py_TPFLAGS_DEFAULT,
649 .tp_alloc = PyType_GenericAlloc,
650 .tp_new = DatabaseEnumerator_new,
651 .tp_dealloc = (destructor)DatabaseEnumerator_dealloc,
652 .tp_iter = PyObject_SelfIter,
653 .tp_iternext = (iternextfunc)DatabaseEnumerator_next,
654 };