]> git.ipfire.org Git - people/ms/libloc.git/blob - src/python/database.c
38a804cd8a8936e925601498f0c52c3fc1e3f6b5
[people/ms/libloc.git] / src / python / database.c
1 /*
2 libloc - A library to determine the location of someone on the Internet
3
4 Copyright (C) 2017 IPFire Development Team <info@ipfire.org>
5
6 This library is free software; you can redistribute it and/or
7 modify it under the terms of the GNU Lesser General Public
8 License as published by the Free Software Foundation; either
9 version 2.1 of the License, or (at your option) any later version.
10
11 This library is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 Lesser General Public License for more details.
15 */
16
#include <Python.h>

#include <errno.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <loc/libloc.h>
#include <loc/as.h>
#include <loc/as-list.h>
#include <loc/database.h>

#include "locationmodule.h"
#include "as.h"
#include "country.h"
#include "database.h"
#include "network.h"
29
30 static PyObject* Database_new(PyTypeObject* type, PyObject* args, PyObject* kwds) {
31 DatabaseObject* self = (DatabaseObject*)type->tp_alloc(type, 0);
32
33 return (PyObject*)self;
34 }
35
36 static void Database_dealloc(DatabaseObject* self) {
37 if (self->db)
38 loc_database_unref(self->db);
39
40 if (self->path)
41 free(self->path);
42
43 Py_TYPE(self)->tp_free((PyObject* )self);
44 }
45
46 static int Database_init(DatabaseObject* self, PyObject* args, PyObject* kwargs) {
47 const char* path = NULL;
48
49 if (!PyArg_ParseTuple(args, "s", &path))
50 return -1;
51
52 self->path = strdup(path);
53
54 // Open the file for reading
55 FILE* f = fopen(self->path, "r");
56 if (!f) {
57 PyErr_SetFromErrno(PyExc_IOError);
58 return -1;
59 }
60
61 // Load the database
62 int r = loc_database_new(loc_ctx, &self->db, f);
63 fclose(f);
64
65 // Return on any errors
66 if (r)
67 return -1;
68
69 return 0;
70 }
71
72 static PyObject* Database_repr(DatabaseObject* self) {
73 return PyUnicode_FromFormat("<Database %s>", self->path);
74 }
75
76 static PyObject* Database_verify(DatabaseObject* self, PyObject* args) {
77 PyObject* public_key = NULL;
78 FILE* f = NULL;
79
80 // Parse arguments
81 if (!PyArg_ParseTuple(args, "O", &public_key))
82 return NULL;
83
84 // Convert into FILE*
85 int fd = PyObject_AsFileDescriptor(public_key);
86 if (fd < 0)
87 return NULL;
88
89 // Re-open file descriptor
90 f = fdopen(fd, "r");
91 if (!f) {
92 PyErr_SetFromErrno(PyExc_IOError);
93 return NULL;
94 }
95
96 int r = loc_database_verify(self->db, f);
97
98 if (r == 0)
99 Py_RETURN_TRUE;
100
101 Py_RETURN_FALSE;
102 }
103
104 static PyObject* Database_get_description(DatabaseObject* self) {
105 const char* description = loc_database_get_description(self->db);
106
107 return PyUnicode_FromString(description);
108 }
109
110 static PyObject* Database_get_vendor(DatabaseObject* self) {
111 const char* vendor = loc_database_get_vendor(self->db);
112
113 return PyUnicode_FromString(vendor);
114 }
115
116 static PyObject* Database_get_license(DatabaseObject* self) {
117 const char* license = loc_database_get_license(self->db);
118
119 return PyUnicode_FromString(license);
120 }
121
122 static PyObject* Database_get_created_at(DatabaseObject* self) {
123 time_t created_at = loc_database_created_at(self->db);
124
125 return PyLong_FromLong(created_at);
126 }
127
128 static PyObject* Database_get_as(DatabaseObject* self, PyObject* args) {
129 struct loc_as* as = NULL;
130 uint32_t number = 0;
131
132 if (!PyArg_ParseTuple(args, "i", &number))
133 return NULL;
134
135 // Try to retrieve the AS
136 int r = loc_database_get_as(self->db, &as, number);
137
138 // We got an AS
139 if (r == 0) {
140 PyObject* obj = new_as(&ASType, as);
141 loc_as_unref(as);
142
143 return obj;
144
145 // Nothing found
146 } else if (r == 1) {
147 Py_RETURN_NONE;
148 }
149
150 // Unexpected error
151 return NULL;
152 }
153
154 static PyObject* Database_get_country(DatabaseObject* self, PyObject* args) {
155 const char* country_code = NULL;
156
157 if (!PyArg_ParseTuple(args, "s", &country_code))
158 return NULL;
159
160 struct loc_country* country;
161 int r = loc_database_get_country(self->db, &country, country_code);
162 if (r) {
163 Py_RETURN_NONE;
164 }
165
166 PyObject* obj = new_country(&CountryType, country);
167 loc_country_unref(country);
168
169 return obj;
170 }
171
172 static PyObject* Database_lookup(DatabaseObject* self, PyObject* args) {
173 struct loc_network* network = NULL;
174 const char* address = NULL;
175
176 if (!PyArg_ParseTuple(args, "s", &address))
177 return NULL;
178
179 // Try to retrieve a matching network
180 int r = loc_database_lookup_from_string(self->db, address, &network);
181
182 // We got a network
183 if (r == 0) {
184 PyObject* obj = new_network(&NetworkType, network);
185 loc_network_unref(network);
186
187 return obj;
188
189 // Nothing found
190 } else if (r == 1) {
191 Py_RETURN_NONE;
192
193 // Invalid input
194 } else if (r == -EINVAL) {
195 PyErr_Format(PyExc_ValueError, "Invalid IP address: %s", address);
196 return NULL;
197 }
198
199 // Unexpected error
200 return NULL;
201 }
202
203 static PyObject* new_database_enumerator(PyTypeObject* type, struct loc_database_enumerator* enumerator) {
204 DatabaseEnumeratorObject* self = (DatabaseEnumeratorObject*)type->tp_alloc(type, 0);
205 if (self) {
206 self->enumerator = loc_database_enumerator_ref(enumerator);
207 }
208
209 return (PyObject*)self;
210 }
211
212 static PyObject* Database_iterate_all(DatabaseObject* self, enum loc_database_enumerator_mode what, int flags) {
213 struct loc_database_enumerator* enumerator;
214
215 int r = loc_database_enumerator_new(&enumerator, self->db, what, flags);
216 if (r) {
217 PyErr_SetFromErrno(PyExc_SystemError);
218 return NULL;
219 }
220
221 PyObject* obj = new_database_enumerator(&DatabaseEnumeratorType, enumerator);
222 loc_database_enumerator_unref(enumerator);
223
224 return obj;
225 }
226
227 static PyObject* Database_ases(DatabaseObject* self) {
228 return Database_iterate_all(self, LOC_DB_ENUMERATE_ASES, 0);
229 }
230
231 static PyObject* Database_search_as(DatabaseObject* self, PyObject* args) {
232 const char* string = NULL;
233
234 if (!PyArg_ParseTuple(args, "s", &string))
235 return NULL;
236
237 struct loc_database_enumerator* enumerator;
238
239 int r = loc_database_enumerator_new(&enumerator, self->db, LOC_DB_ENUMERATE_ASES, 0);
240 if (r) {
241 PyErr_SetFromErrno(PyExc_SystemError);
242 return NULL;
243 }
244
245 // Search string we are searching for
246 loc_database_enumerator_set_string(enumerator, string);
247
248 PyObject* obj = new_database_enumerator(&DatabaseEnumeratorType, enumerator);
249 loc_database_enumerator_unref(enumerator);
250
251 return obj;
252 }
253
254 static PyObject* Database_networks(DatabaseObject* self) {
255 return Database_iterate_all(self, LOC_DB_ENUMERATE_NETWORKS, 0);
256 }
257
258 static PyObject* Database_networks_flattened(DatabaseObject *self) {
259 return Database_iterate_all(self, LOC_DB_ENUMERATE_NETWORKS, LOC_DB_ENUMERATOR_FLAGS_FLATTEN);
260 }
261
262 static PyObject* Database_search_networks(DatabaseObject* self, PyObject* args, PyObject* kwargs) {
263 char* kwlist[] = { "country_codes", "asns", "flags", "family", "flatten", NULL };
264 PyObject* country_codes = NULL;
265 PyObject* asn_list = NULL;
266 int flags = 0;
267 int family = 0;
268 int flatten = 0;
269
270 if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O!O!iip", kwlist,
271 &PyList_Type, &country_codes, &PyList_Type, &asn_list, &flags, &family, &flatten))
272 return NULL;
273
274 struct loc_database_enumerator* enumerator;
275 int r = loc_database_enumerator_new(&enumerator, self->db, LOC_DB_ENUMERATE_NETWORKS,
276 (flatten) ? LOC_DB_ENUMERATOR_FLAGS_FLATTEN : 0);
277 if (r) {
278 PyErr_SetFromErrno(PyExc_SystemError);
279 return NULL;
280 }
281
282 // Set country code we are searching for
283 if (country_codes) {
284 struct loc_country_list* countries;
285 r = loc_country_list_new(loc_ctx, &countries);
286 if (r) {
287 PyErr_SetString(PyExc_SystemError, "Could not create country list");
288 return NULL;
289 }
290
291 for (unsigned int i = 0; i < PyList_Size(country_codes); i++) {
292 PyObject* item = PyList_GetItem(country_codes, i);
293
294 if (!PyUnicode_Check(item)) {
295 PyErr_SetString(PyExc_TypeError, "Country codes must be strings");
296 loc_country_list_unref(countries);
297 return NULL;
298 }
299
300 const char* country_code = PyUnicode_AsUTF8(item);
301
302 struct loc_country* country;
303 r = loc_country_new(loc_ctx, &country, country_code);
304 if (r) {
305 if (r == -EINVAL) {
306 PyErr_Format(PyExc_ValueError, "Invalid country code: %s", country_code);
307 } else {
308 PyErr_SetString(PyExc_SystemError, "Could not create country");
309 }
310
311 loc_country_list_unref(countries);
312 return NULL;
313 }
314
315 // Append it to the list
316 r = loc_country_list_append(countries, country);
317 if (r) {
318 PyErr_SetString(PyExc_SystemError, "Could not append country to the list");
319
320 loc_country_list_unref(countries);
321 loc_country_unref(country);
322 return NULL;
323 }
324
325 loc_country_unref(country);
326 }
327
328 loc_database_enumerator_set_countries(enumerator, countries);
329
330 Py_DECREF(country_codes);
331 loc_country_list_unref(countries);
332 }
333
334 // Set the ASN we are searching for
335 if (asn_list) {
336 struct loc_as_list* asns;
337 r = loc_as_list_new(loc_ctx, &asns);
338 if (r) {
339 PyErr_SetString(PyExc_SystemError, "Could not create AS list");
340 return NULL;
341 }
342
343 for (unsigned int i = 0; i < PyList_Size(asn_list); i++) {
344 PyObject* item = PyList_GetItem(asn_list, i);
345
346 if (!PyLong_Check(item)) {
347 PyErr_SetString(PyExc_TypeError, "ASNs must be numbers");
348
349 loc_as_list_unref(asns);
350 return NULL;
351 }
352
353 unsigned long number = PyLong_AsLong(item);
354
355 struct loc_as* as;
356 r = loc_as_new(loc_ctx, &as, number);
357 if (r) {
358 PyErr_SetString(PyExc_SystemError, "Could not create AS");
359
360 loc_as_list_unref(asns);
361 loc_as_unref(as);
362 return NULL;
363 }
364
365 r = loc_as_list_append(asns, as);
366 if (r) {
367 PyErr_SetString(PyExc_SystemError, "Could not append AS to the list");
368
369 loc_as_list_unref(asns);
370 loc_as_unref(as);
371 return NULL;
372 }
373
374 loc_as_unref(as);
375 }
376
377 r = loc_database_enumerator_set_asns(enumerator, asns);
378 if (r) {
379 PyErr_SetFromErrno(PyExc_SystemError);
380
381 loc_as_list_unref(asns);
382 return NULL;
383 }
384
385 loc_as_list_unref(asns);
386 }
387
388 // Set the flags we are searching for
389 if (flags) {
390 r = loc_database_enumerator_set_flag(enumerator, flags);
391
392 if (r) {
393 PyErr_SetFromErrno(PyExc_SystemError);
394 return NULL;
395 }
396 }
397
398 // Set the family we are searching for
399 if (family) {
400 r = loc_database_enumerator_set_family(enumerator, family);
401
402 if (r) {
403 PyErr_SetFromErrno(PyExc_SystemError);
404 return NULL;
405 }
406 }
407
408 PyObject* obj = new_database_enumerator(&DatabaseEnumeratorType, enumerator);
409 loc_database_enumerator_unref(enumerator);
410
411 return obj;
412 }
413
414 static PyObject* Database_countries(DatabaseObject* self) {
415 return Database_iterate_all(self, LOC_DB_ENUMERATE_COUNTRIES, 0);
416 }
417
418 static struct PyMethodDef Database_methods[] = {
419 {
420 "get_as",
421 (PyCFunction)Database_get_as,
422 METH_VARARGS,
423 NULL,
424 },
425 {
426 "get_country",
427 (PyCFunction)Database_get_country,
428 METH_VARARGS,
429 NULL,
430 },
431 {
432 "lookup",
433 (PyCFunction)Database_lookup,
434 METH_VARARGS,
435 NULL,
436 },
437 {
438 "search_as",
439 (PyCFunction)Database_search_as,
440 METH_VARARGS,
441 NULL,
442 },
443 {
444 "search_networks",
445 (PyCFunction)Database_search_networks,
446 METH_VARARGS|METH_KEYWORDS,
447 NULL,
448 },
449 {
450 "verify",
451 (PyCFunction)Database_verify,
452 METH_VARARGS,
453 NULL,
454 },
455 { NULL },
456 };
457
458 static struct PyGetSetDef Database_getsetters[] = {
459 {
460 "ases",
461 (getter)Database_ases,
462 NULL,
463 NULL,
464 NULL,
465 },
466 {
467 "countries",
468 (getter)Database_countries,
469 NULL,
470 NULL,
471 NULL,
472 },
473 {
474 "created_at",
475 (getter)Database_get_created_at,
476 NULL,
477 NULL,
478 NULL,
479 },
480 {
481 "description",
482 (getter)Database_get_description,
483 NULL,
484 NULL,
485 NULL,
486 },
487 {
488 "license",
489 (getter)Database_get_license,
490 NULL,
491 NULL,
492 NULL,
493 },
494 {
495 "networks",
496 (getter)Database_networks,
497 NULL,
498 NULL,
499 NULL,
500 },
501 {
502 "networks_flattened",
503 (getter)Database_networks_flattened,
504 NULL,
505 NULL,
506 NULL,
507 },
508 {
509 "vendor",
510 (getter)Database_get_vendor,
511 NULL,
512 NULL,
513 NULL,
514 },
515 { NULL },
516 };
517
518 PyTypeObject DatabaseType = {
519 PyVarObject_HEAD_INIT(NULL, 0)
520 .tp_name = "location.Database",
521 .tp_basicsize = sizeof(DatabaseObject),
522 .tp_flags = Py_TPFLAGS_DEFAULT|Py_TPFLAGS_BASETYPE,
523 .tp_new = Database_new,
524 .tp_dealloc = (destructor)Database_dealloc,
525 .tp_init = (initproc)Database_init,
526 .tp_doc = "Database object",
527 .tp_methods = Database_methods,
528 .tp_getset = Database_getsetters,
529 .tp_repr = (reprfunc)Database_repr,
530 };
531
532 static PyObject* DatabaseEnumerator_new(PyTypeObject* type, PyObject* args, PyObject* kwds) {
533 DatabaseEnumeratorObject* self = (DatabaseEnumeratorObject*)type->tp_alloc(type, 0);
534
535 return (PyObject*)self;
536 }
537
538 static void DatabaseEnumerator_dealloc(DatabaseEnumeratorObject* self) {
539 loc_database_enumerator_unref(self->enumerator);
540
541 Py_TYPE(self)->tp_free((PyObject* )self);
542 }
543
544 static PyObject* DatabaseEnumerator_next(DatabaseEnumeratorObject* self) {
545 struct loc_network* network = NULL;
546
547 // Enumerate all networks
548 int r = loc_database_enumerator_next_network(self->enumerator, &network);
549 if (r) {
550 PyErr_SetFromErrno(PyExc_ValueError);
551 return NULL;
552 }
553
554 // A network was found
555 if (network) {
556 PyObject* obj = new_network(&NetworkType, network);
557 loc_network_unref(network);
558
559 return obj;
560 }
561
562 // Enumerate all ASes
563 struct loc_as* as = NULL;
564
565 r = loc_database_enumerator_next_as(self->enumerator, &as);
566 if (r) {
567 PyErr_SetFromErrno(PyExc_ValueError);
568 return NULL;
569 }
570
571 if (as) {
572 PyObject* obj = new_as(&ASType, as);
573 loc_as_unref(as);
574
575 return obj;
576 }
577
578 // Enumerate all countries
579 struct loc_country* country = NULL;
580
581 r = loc_database_enumerator_next_country(self->enumerator, &country);
582 if (r) {
583 PyErr_SetFromErrno(PyExc_ValueError);
584 return NULL;
585 }
586
587 if (country) {
588 PyObject* obj = new_country(&CountryType, country);
589 loc_country_unref(country);
590
591 return obj;
592 }
593
594 // Nothing found, that means the end
595 PyErr_SetNone(PyExc_StopIteration);
596 return NULL;
597 }
598
599 PyTypeObject DatabaseEnumeratorType = {
600 PyVarObject_HEAD_INIT(NULL, 0)
601 .tp_name = "location.DatabaseEnumerator",
602 .tp_basicsize = sizeof(DatabaseEnumeratorObject),
603 .tp_flags = Py_TPFLAGS_DEFAULT,
604 .tp_alloc = PyType_GenericAlloc,
605 .tp_new = DatabaseEnumerator_new,
606 .tp_dealloc = (destructor)DatabaseEnumerator_dealloc,
607 .tp_iter = PyObject_SelfIter,
608 .tp_iternext = (iternextfunc)DatabaseEnumerator_next,
609 };