]> git.ipfire.org Git - location/libloc.git/blob - src/python/database.c
ed222759102b5b2ae5bb44c5c5c2447630320b1e
[location/libloc.git] / src / python / database.c
1 /*
2 libloc - A library to determine the location of someone on the Internet
3
4 Copyright (C) 2017 IPFire Development Team <info@ipfire.org>
5
6 This library is free software; you can redistribute it and/or
7 modify it under the terms of the GNU Lesser General Public
8 License as published by the Free Software Foundation; either
9 version 2.1 of the License, or (at your option) any later version.
10
11 This library is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 Lesser General Public License for more details.
15 */
16
// Python.h must be included before any standard header
#include <Python.h>

#include <errno.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <unistd.h>

#include <loc/libloc.h>
#include <loc/as.h>
#include <loc/as-list.h>
#include <loc/database.h>

#include "locationmodule.h"
#include "as.h"
#include "country.h"
#include "database.h"
#include "network.h"
29
30 static PyObject* Database_new(PyTypeObject* type, PyObject* args, PyObject* kwds) {
31 DatabaseObject* self = (DatabaseObject*)type->tp_alloc(type, 0);
32
33 return (PyObject*)self;
34 }
35
36 static void Database_dealloc(DatabaseObject* self) {
37 if (self->db)
38 loc_database_unref(self->db);
39
40 if (self->path)
41 free(self->path);
42
43 Py_TYPE(self)->tp_free((PyObject* )self);
44 }
45
46 static int Database_init(DatabaseObject* self, PyObject* args, PyObject* kwargs) {
47 const char* path = NULL;
48
49 if (!PyArg_ParseTuple(args, "s", &path))
50 return -1;
51
52 self->path = strdup(path);
53
54 // Open the file for reading
55 FILE* f = fopen(self->path, "r");
56 if (!f) {
57 PyErr_SetFromErrno(PyExc_IOError);
58 return -1;
59 }
60
61 // Load the database
62 int r = loc_database_new(loc_ctx, &self->db, f);
63 fclose(f);
64
65 // Return on any errors
66 if (r)
67 return -1;
68
69 return 0;
70 }
71
72 static PyObject* Database_repr(DatabaseObject* self) {
73 return PyUnicode_FromFormat("<Database %s>", self->path);
74 }
75
76 static PyObject* Database_verify(DatabaseObject* self, PyObject* args) {
77 PyObject* public_key = NULL;
78 FILE* f = NULL;
79
80 // Parse arguments
81 if (!PyArg_ParseTuple(args, "O", &public_key))
82 return NULL;
83
84 // Convert into FILE*
85 int fd = PyObject_AsFileDescriptor(public_key);
86 if (fd < 0)
87 return NULL;
88
89 // Re-open file descriptor
90 f = fdopen(fd, "r");
91 if (!f) {
92 PyErr_SetFromErrno(PyExc_IOError);
93 return NULL;
94 }
95
96 int r = loc_database_verify(self->db, f);
97
98 if (r == 0)
99 Py_RETURN_TRUE;
100
101 Py_RETURN_FALSE;
102 }
103
104 static PyObject* Database_get_description(DatabaseObject* self) {
105 const char* description = loc_database_get_description(self->db);
106
107 return PyUnicode_FromString(description);
108 }
109
110 static PyObject* Database_get_vendor(DatabaseObject* self) {
111 const char* vendor = loc_database_get_vendor(self->db);
112
113 return PyUnicode_FromString(vendor);
114 }
115
116 static PyObject* Database_get_license(DatabaseObject* self) {
117 const char* license = loc_database_get_license(self->db);
118
119 return PyUnicode_FromString(license);
120 }
121
122 static PyObject* Database_get_created_at(DatabaseObject* self) {
123 time_t created_at = loc_database_created_at(self->db);
124
125 return PyLong_FromLong(created_at);
126 }
127
128 static PyObject* Database_get_as(DatabaseObject* self, PyObject* args) {
129 struct loc_as* as = NULL;
130 uint32_t number = 0;
131
132 if (!PyArg_ParseTuple(args, "i", &number))
133 return NULL;
134
135 // Try to retrieve the AS
136 int r = loc_database_get_as(self->db, &as, number);
137
138 // We got an AS
139 if (r == 0) {
140 PyObject* obj = new_as(&ASType, as);
141 loc_as_unref(as);
142
143 return obj;
144
145 // Nothing found
146 } else if (r == 1) {
147 Py_RETURN_NONE;
148 }
149
150 // Unexpected error
151 return NULL;
152 }
153
154 static PyObject* Database_get_country(DatabaseObject* self, PyObject* args) {
155 const char* country_code = NULL;
156
157 if (!PyArg_ParseTuple(args, "s", &country_code))
158 return NULL;
159
160 struct loc_country* country;
161 int r = loc_database_get_country(self->db, &country, country_code);
162 if (r) {
163 Py_RETURN_NONE;
164 }
165
166 PyObject* obj = new_country(&CountryType, country);
167 loc_country_unref(country);
168
169 return obj;
170 }
171
172 static PyObject* Database_lookup(DatabaseObject* self, PyObject* args) {
173 struct loc_network* network = NULL;
174 const char* address = NULL;
175
176 if (!PyArg_ParseTuple(args, "s", &address))
177 return NULL;
178
179 // Try to retrieve a matching network
180 int r = loc_database_lookup_from_string(self->db, address, &network);
181
182 // We got a network
183 if (r == 0) {
184 PyObject* obj = new_network(&NetworkType, network);
185 loc_network_unref(network);
186
187 return obj;
188
189 // Nothing found
190 } else if (r == 1) {
191 Py_RETURN_NONE;
192
193 // Invalid input
194 } else if (r == -EINVAL) {
195 PyErr_Format(PyExc_ValueError, "Invalid IP address: %s", address);
196 return NULL;
197 }
198
199 // Unexpected error
200 return NULL;
201 }
202
203 static PyObject* new_database_enumerator(PyTypeObject* type, struct loc_database_enumerator* enumerator) {
204 DatabaseEnumeratorObject* self = (DatabaseEnumeratorObject*)type->tp_alloc(type, 0);
205 if (self) {
206 self->enumerator = loc_database_enumerator_ref(enumerator);
207 }
208
209 return (PyObject*)self;
210 }
211
212 static PyObject* Database_iterate_all(DatabaseObject* self, enum loc_database_enumerator_mode what, int flags) {
213 struct loc_database_enumerator* enumerator;
214
215 int r = loc_database_enumerator_new(&enumerator, self->db, what, flags);
216 if (r) {
217 PyErr_SetFromErrno(PyExc_SystemError);
218 return NULL;
219 }
220
221 PyObject* obj = new_database_enumerator(&DatabaseEnumeratorType, enumerator);
222 loc_database_enumerator_unref(enumerator);
223
224 return obj;
225 }
226
227 static PyObject* Database_ases(DatabaseObject* self) {
228 return Database_iterate_all(self, LOC_DB_ENUMERATE_ASES, 0);
229 }
230
231 static PyObject* Database_search_as(DatabaseObject* self, PyObject* args) {
232 const char* string = NULL;
233
234 if (!PyArg_ParseTuple(args, "s", &string))
235 return NULL;
236
237 struct loc_database_enumerator* enumerator;
238
239 int r = loc_database_enumerator_new(&enumerator, self->db, LOC_DB_ENUMERATE_ASES, 0);
240 if (r) {
241 PyErr_SetFromErrno(PyExc_SystemError);
242 return NULL;
243 }
244
245 // Search string we are searching for
246 loc_database_enumerator_set_string(enumerator, string);
247
248 PyObject* obj = new_database_enumerator(&DatabaseEnumeratorType, enumerator);
249 loc_database_enumerator_unref(enumerator);
250
251 return obj;
252 }
253
254 static PyObject* Database_networks(DatabaseObject* self) {
255 return Database_iterate_all(self, LOC_DB_ENUMERATE_NETWORKS, 0);
256 }
257
258 static PyObject* Database_networks_flattened(DatabaseObject *self) {
259 return Database_iterate_all(self, LOC_DB_ENUMERATE_NETWORKS, LOC_DB_ENUMERATOR_FLAGS_FLATTEN);
260 }
261
262 static PyObject* Database_search_networks(DatabaseObject* self, PyObject* args, PyObject* kwargs) {
263 char* kwlist[] = { "country_codes", "asns", "flags", "family", "flatten", NULL };
264 PyObject* country_codes = NULL;
265 PyObject* asn_list = NULL;
266 int flags = 0;
267 int family = 0;
268 int flatten = 0;
269
270 if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O!O!iip", kwlist,
271 &PyList_Type, &country_codes, &PyList_Type, &asn_list, &flags, &family, &flatten))
272 return NULL;
273
274 struct loc_database_enumerator* enumerator;
275 int r = loc_database_enumerator_new(&enumerator, self->db, LOC_DB_ENUMERATE_NETWORKS,
276 (flatten) ? LOC_DB_ENUMERATOR_FLAGS_FLATTEN : 0);
277 if (r) {
278 PyErr_SetFromErrno(PyExc_SystemError);
279 return NULL;
280 }
281
282 // Set country code we are searching for
283 if (country_codes) {
284 struct loc_country_list* countries;
285 r = loc_country_list_new(loc_ctx, &countries);
286 if (r) {
287 PyErr_SetString(PyExc_SystemError, "Could not create country list");
288 return NULL;
289 }
290
291 for (unsigned int i = 0; i < PyList_Size(country_codes); i++) {
292 PyObject* item = PyList_GetItem(country_codes, i);
293
294 if (!PyUnicode_Check(item)) {
295 PyErr_SetString(PyExc_TypeError, "Country codes must be strings");
296 loc_country_list_unref(countries);
297 return NULL;
298 }
299
300 const char* country_code = PyUnicode_AsUTF8(item);
301
302 struct loc_country* country;
303 r = loc_country_new(loc_ctx, &country, country_code);
304 if (r) {
305 if (r == -EINVAL) {
306 PyErr_Format(PyExc_ValueError, "Invalid country code: %s", country_code);
307 } else {
308 PyErr_SetString(PyExc_SystemError, "Could not create country");
309 }
310
311 loc_country_list_unref(countries);
312 return NULL;
313 }
314
315 // Append it to the list
316 r = loc_country_list_append(countries, country);
317 if (r) {
318 PyErr_SetString(PyExc_SystemError, "Could not append country to the list");
319
320 loc_country_list_unref(countries);
321 loc_country_unref(country);
322 return NULL;
323 }
324
325 loc_country_unref(country);
326 }
327
328 r = loc_database_enumerator_set_countries(enumerator, countries);
329 if (r) {
330 PyErr_SetFromErrno(PyExc_SystemError);
331
332 loc_as_list_unref(countries);
333 return NULL;
334 }
335
336 loc_country_list_unref(countries);
337 }
338
339 // Set the ASN we are searching for
340 if (asn_list) {
341 struct loc_as_list* asns;
342 r = loc_as_list_new(loc_ctx, &asns);
343 if (r) {
344 PyErr_SetString(PyExc_SystemError, "Could not create AS list");
345 return NULL;
346 }
347
348 for (unsigned int i = 0; i < PyList_Size(asn_list); i++) {
349 PyObject* item = PyList_GetItem(asn_list, i);
350
351 if (!PyLong_Check(item)) {
352 PyErr_SetString(PyExc_TypeError, "ASNs must be numbers");
353
354 loc_as_list_unref(asns);
355 return NULL;
356 }
357
358 unsigned long number = PyLong_AsLong(item);
359
360 struct loc_as* as;
361 r = loc_as_new(loc_ctx, &as, number);
362 if (r) {
363 PyErr_SetString(PyExc_SystemError, "Could not create AS");
364
365 loc_as_list_unref(asns);
366 loc_as_unref(as);
367 return NULL;
368 }
369
370 r = loc_as_list_append(asns, as);
371 if (r) {
372 PyErr_SetString(PyExc_SystemError, "Could not append AS to the list");
373
374 loc_as_list_unref(asns);
375 loc_as_unref(as);
376 return NULL;
377 }
378
379 loc_as_unref(as);
380 }
381
382 r = loc_database_enumerator_set_asns(enumerator, asns);
383 if (r) {
384 PyErr_SetFromErrno(PyExc_SystemError);
385
386 loc_as_list_unref(asns);
387 return NULL;
388 }
389
390 loc_as_list_unref(asns);
391 }
392
393 // Set the flags we are searching for
394 if (flags) {
395 r = loc_database_enumerator_set_flag(enumerator, flags);
396
397 if (r) {
398 PyErr_SetFromErrno(PyExc_SystemError);
399 return NULL;
400 }
401 }
402
403 // Set the family we are searching for
404 if (family) {
405 r = loc_database_enumerator_set_family(enumerator, family);
406
407 if (r) {
408 PyErr_SetFromErrno(PyExc_SystemError);
409 return NULL;
410 }
411 }
412
413 PyObject* obj = new_database_enumerator(&DatabaseEnumeratorType, enumerator);
414 loc_database_enumerator_unref(enumerator);
415
416 return obj;
417 }
418
419 static PyObject* Database_countries(DatabaseObject* self) {
420 return Database_iterate_all(self, LOC_DB_ENUMERATE_COUNTRIES, 0);
421 }
422
423 static struct PyMethodDef Database_methods[] = {
424 {
425 "get_as",
426 (PyCFunction)Database_get_as,
427 METH_VARARGS,
428 NULL,
429 },
430 {
431 "get_country",
432 (PyCFunction)Database_get_country,
433 METH_VARARGS,
434 NULL,
435 },
436 {
437 "lookup",
438 (PyCFunction)Database_lookup,
439 METH_VARARGS,
440 NULL,
441 },
442 {
443 "search_as",
444 (PyCFunction)Database_search_as,
445 METH_VARARGS,
446 NULL,
447 },
448 {
449 "search_networks",
450 (PyCFunction)Database_search_networks,
451 METH_VARARGS|METH_KEYWORDS,
452 NULL,
453 },
454 {
455 "verify",
456 (PyCFunction)Database_verify,
457 METH_VARARGS,
458 NULL,
459 },
460 { NULL },
461 };
462
463 static struct PyGetSetDef Database_getsetters[] = {
464 {
465 "ases",
466 (getter)Database_ases,
467 NULL,
468 NULL,
469 NULL,
470 },
471 {
472 "countries",
473 (getter)Database_countries,
474 NULL,
475 NULL,
476 NULL,
477 },
478 {
479 "created_at",
480 (getter)Database_get_created_at,
481 NULL,
482 NULL,
483 NULL,
484 },
485 {
486 "description",
487 (getter)Database_get_description,
488 NULL,
489 NULL,
490 NULL,
491 },
492 {
493 "license",
494 (getter)Database_get_license,
495 NULL,
496 NULL,
497 NULL,
498 },
499 {
500 "networks",
501 (getter)Database_networks,
502 NULL,
503 NULL,
504 NULL,
505 },
506 {
507 "networks_flattened",
508 (getter)Database_networks_flattened,
509 NULL,
510 NULL,
511 NULL,
512 },
513 {
514 "vendor",
515 (getter)Database_get_vendor,
516 NULL,
517 NULL,
518 NULL,
519 },
520 { NULL },
521 };
522
523 PyTypeObject DatabaseType = {
524 PyVarObject_HEAD_INIT(NULL, 0)
525 .tp_name = "location.Database",
526 .tp_basicsize = sizeof(DatabaseObject),
527 .tp_flags = Py_TPFLAGS_DEFAULT|Py_TPFLAGS_BASETYPE,
528 .tp_new = Database_new,
529 .tp_dealloc = (destructor)Database_dealloc,
530 .tp_init = (initproc)Database_init,
531 .tp_doc = "Database object",
532 .tp_methods = Database_methods,
533 .tp_getset = Database_getsetters,
534 .tp_repr = (reprfunc)Database_repr,
535 };
536
537 static PyObject* DatabaseEnumerator_new(PyTypeObject* type, PyObject* args, PyObject* kwds) {
538 DatabaseEnumeratorObject* self = (DatabaseEnumeratorObject*)type->tp_alloc(type, 0);
539
540 return (PyObject*)self;
541 }
542
543 static void DatabaseEnumerator_dealloc(DatabaseEnumeratorObject* self) {
544 loc_database_enumerator_unref(self->enumerator);
545
546 Py_TYPE(self)->tp_free((PyObject* )self);
547 }
548
549 static PyObject* DatabaseEnumerator_next(DatabaseEnumeratorObject* self) {
550 struct loc_network* network = NULL;
551
552 // Enumerate all networks
553 int r = loc_database_enumerator_next_network(self->enumerator, &network);
554 if (r) {
555 PyErr_SetFromErrno(PyExc_ValueError);
556 return NULL;
557 }
558
559 // A network was found
560 if (network) {
561 PyObject* obj = new_network(&NetworkType, network);
562 loc_network_unref(network);
563
564 return obj;
565 }
566
567 // Enumerate all ASes
568 struct loc_as* as = NULL;
569
570 r = loc_database_enumerator_next_as(self->enumerator, &as);
571 if (r) {
572 PyErr_SetFromErrno(PyExc_ValueError);
573 return NULL;
574 }
575
576 if (as) {
577 PyObject* obj = new_as(&ASType, as);
578 loc_as_unref(as);
579
580 return obj;
581 }
582
583 // Enumerate all countries
584 struct loc_country* country = NULL;
585
586 r = loc_database_enumerator_next_country(self->enumerator, &country);
587 if (r) {
588 PyErr_SetFromErrno(PyExc_ValueError);
589 return NULL;
590 }
591
592 if (country) {
593 PyObject* obj = new_country(&CountryType, country);
594 loc_country_unref(country);
595
596 return obj;
597 }
598
599 // Nothing found, that means the end
600 PyErr_SetNone(PyExc_StopIteration);
601 return NULL;
602 }
603
604 PyTypeObject DatabaseEnumeratorType = {
605 PyVarObject_HEAD_INIT(NULL, 0)
606 .tp_name = "location.DatabaseEnumerator",
607 .tp_basicsize = sizeof(DatabaseEnumeratorObject),
608 .tp_flags = Py_TPFLAGS_DEFAULT,
609 .tp_alloc = PyType_GenericAlloc,
610 .tp_new = DatabaseEnumerator_new,
611 .tp_dealloc = (destructor)DatabaseEnumerator_dealloc,
612 .tp_iter = PyObject_SelfIter,
613 .tp_iternext = (iternextfunc)DatabaseEnumerator_next,
614 };