]> git.ipfire.org Git - location/libloc.git/blob - src/python/database.c
e6f6f37e2bc437a55c5b3c84ef5672a0faa7353f
[location/libloc.git] / src / python / database.c
1 /*
2 libloc - A library to determine the location of someone on the Internet
3
4 Copyright (C) 2017 IPFire Development Team <info@ipfire.org>
5
6 This library is free software; you can redistribute it and/or
7 modify it under the terms of the GNU Lesser General Public
8 License as published by the Free Software Foundation; either
9 version 2.1 of the License, or (at your option) any later version.
10
11 This library is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 Lesser General Public License for more details.
15 */
16
#include <Python.h>

#include <errno.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <loc/libloc.h>
#include <loc/database.h>

#include "locationmodule.h"
#include "as.h"
#include "country.h"
#include "database.h"
#include "network.h"
27
28 static PyObject* Database_new(PyTypeObject* type, PyObject* args, PyObject* kwds) {
29 DatabaseObject* self = (DatabaseObject*)type->tp_alloc(type, 0);
30
31 return (PyObject*)self;
32 }
33
34 static void Database_dealloc(DatabaseObject* self) {
35 if (self->db)
36 loc_database_unref(self->db);
37
38 if (self->path)
39 free(self->path);
40
41 Py_TYPE(self)->tp_free((PyObject* )self);
42 }
43
44 static int Database_init(DatabaseObject* self, PyObject* args, PyObject* kwargs) {
45 const char* path = NULL;
46
47 if (!PyArg_ParseTuple(args, "s", &path))
48 return -1;
49
50 self->path = strdup(path);
51
52 // Open the file for reading
53 FILE* f = fopen(self->path, "r");
54 if (!f) {
55 PyErr_SetFromErrno(PyExc_IOError);
56 return -1;
57 }
58
59 // Load the database
60 int r = loc_database_new(loc_ctx, &self->db, f);
61 fclose(f);
62
63 // Return on any errors
64 if (r)
65 return -1;
66
67 return 0;
68 }
69
70 static PyObject* Database_repr(DatabaseObject* self) {
71 return PyUnicode_FromFormat("<Database %s>", self->path);
72 }
73
74 static PyObject* Database_verify(DatabaseObject* self, PyObject* args) {
75 PyObject* public_key = NULL;
76 FILE* f = NULL;
77
78 // Parse arguments
79 if (!PyArg_ParseTuple(args, "O", &public_key))
80 return NULL;
81
82 // Convert into FILE*
83 int fd = PyObject_AsFileDescriptor(public_key);
84 if (fd < 0)
85 return NULL;
86
87 // Re-open file descriptor
88 f = fdopen(fd, "r");
89 if (!f) {
90 PyErr_SetFromErrno(PyExc_IOError);
91 return NULL;
92 }
93
94 int r = loc_database_verify(self->db, f);
95
96 if (r == 0)
97 Py_RETURN_TRUE;
98
99 Py_RETURN_FALSE;
100 }
101
102 static PyObject* Database_get_description(DatabaseObject* self) {
103 const char* description = loc_database_get_description(self->db);
104
105 return PyUnicode_FromString(description);
106 }
107
108 static PyObject* Database_get_vendor(DatabaseObject* self) {
109 const char* vendor = loc_database_get_vendor(self->db);
110
111 return PyUnicode_FromString(vendor);
112 }
113
114 static PyObject* Database_get_license(DatabaseObject* self) {
115 const char* license = loc_database_get_license(self->db);
116
117 return PyUnicode_FromString(license);
118 }
119
120 static PyObject* Database_get_created_at(DatabaseObject* self) {
121 time_t created_at = loc_database_created_at(self->db);
122
123 return PyLong_FromLong(created_at);
124 }
125
126 static PyObject* Database_get_as(DatabaseObject* self, PyObject* args) {
127 struct loc_as* as = NULL;
128 uint32_t number = 0;
129
130 if (!PyArg_ParseTuple(args, "i", &number))
131 return NULL;
132
133 // Try to retrieve the AS
134 int r = loc_database_get_as(self->db, &as, number);
135
136 // We got an AS
137 if (r == 0) {
138 PyObject* obj = new_as(&ASType, as);
139 loc_as_unref(as);
140
141 return obj;
142
143 // Nothing found
144 } else if (r == 1) {
145 Py_RETURN_NONE;
146 }
147
148 // Unexpected error
149 return NULL;
150 }
151
152 static PyObject* Database_get_country(DatabaseObject* self, PyObject* args) {
153 const char* country_code = NULL;
154
155 if (!PyArg_ParseTuple(args, "s", &country_code))
156 return NULL;
157
158 struct loc_country* country;
159 int r = loc_database_get_country(self->db, &country, country_code);
160 if (r) {
161 Py_RETURN_NONE;
162 }
163
164 PyObject* obj = new_country(&CountryType, country);
165 loc_country_unref(country);
166
167 return obj;
168 }
169
170 static PyObject* Database_lookup(DatabaseObject* self, PyObject* args) {
171 struct loc_network* network = NULL;
172 const char* address = NULL;
173
174 if (!PyArg_ParseTuple(args, "s", &address))
175 return NULL;
176
177 // Try to retrieve a matching network
178 int r = loc_database_lookup_from_string(self->db, address, &network);
179
180 // We got a network
181 if (r == 0) {
182 PyObject* obj = new_network(&NetworkType, network);
183 loc_network_unref(network);
184
185 return obj;
186
187 // Nothing found
188 } else if (r == 1) {
189 Py_RETURN_NONE;
190
191 // Invalid input
192 } else if (r == -EINVAL) {
193 PyErr_Format(PyExc_ValueError, "Invalid IP address: %s", address);
194 return NULL;
195 }
196
197 // Unexpected error
198 return NULL;
199 }
200
201 static PyObject* new_database_enumerator(PyTypeObject* type, struct loc_database_enumerator* enumerator) {
202 DatabaseEnumeratorObject* self = (DatabaseEnumeratorObject*)type->tp_alloc(type, 0);
203 if (self) {
204 self->enumerator = loc_database_enumerator_ref(enumerator);
205 }
206
207 return (PyObject*)self;
208 }
209
210 static PyObject* Database_iterate_all(DatabaseObject* self, enum loc_database_enumerator_mode what, int flags) {
211 struct loc_database_enumerator* enumerator;
212
213 int r = loc_database_enumerator_new(&enumerator, self->db, what, flags);
214 if (r) {
215 PyErr_SetFromErrno(PyExc_SystemError);
216 return NULL;
217 }
218
219 PyObject* obj = new_database_enumerator(&DatabaseEnumeratorType, enumerator);
220 loc_database_enumerator_unref(enumerator);
221
222 return obj;
223 }
224
225 static PyObject* Database_ases(DatabaseObject* self) {
226 return Database_iterate_all(self, LOC_DB_ENUMERATE_ASES, 0);
227 }
228
229 static PyObject* Database_search_as(DatabaseObject* self, PyObject* args) {
230 const char* string = NULL;
231
232 if (!PyArg_ParseTuple(args, "s", &string))
233 return NULL;
234
235 struct loc_database_enumerator* enumerator;
236
237 int r = loc_database_enumerator_new(&enumerator, self->db, LOC_DB_ENUMERATE_ASES, 0);
238 if (r) {
239 PyErr_SetFromErrno(PyExc_SystemError);
240 return NULL;
241 }
242
243 // Search string we are searching for
244 loc_database_enumerator_set_string(enumerator, string);
245
246 PyObject* obj = new_database_enumerator(&DatabaseEnumeratorType, enumerator);
247 loc_database_enumerator_unref(enumerator);
248
249 return obj;
250 }
251
252 static PyObject* Database_networks(DatabaseObject* self) {
253 return Database_iterate_all(self, LOC_DB_ENUMERATE_NETWORKS, 0);
254 }
255
256 static PyObject* Database_networks_flattened(DatabaseObject *self) {
257 return Database_iterate_all(self, LOC_DB_ENUMERATE_NETWORKS, LOC_DB_ENUMERATOR_FLAGS_FLATTEN);
258 }
259
260 static PyObject* Database_search_networks(DatabaseObject* self, PyObject* args, PyObject* kwargs) {
261 char* kwlist[] = { "country_codes", "asn", "flags", "family", "flatten", NULL };
262 PyObject* country_codes = NULL;
263 unsigned int asn = 0;
264 int flags = 0;
265 int family = 0;
266 int flatten = 0;
267
268 if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O!iiip", kwlist,
269 &PyList_Type, &country_codes, &asn, &flags, &family, &flatten))
270 return NULL;
271
272 struct loc_database_enumerator* enumerator;
273 int r = loc_database_enumerator_new(&enumerator, self->db, LOC_DB_ENUMERATE_NETWORKS,
274 (flatten) ? LOC_DB_ENUMERATOR_FLAGS_FLATTEN : 0);
275 if (r) {
276 PyErr_SetFromErrno(PyExc_SystemError);
277 return NULL;
278 }
279
280 // Set country code we are searching for
281 if (country_codes) {
282 struct loc_country_list* countries;
283 r = loc_country_list_new(loc_ctx, &countries);
284 if (r) {
285 PyErr_SetString(PyExc_SystemError, "Could not create country list");
286 return NULL;
287 }
288
289 for (unsigned int i = 0; i < PyList_Size(country_codes); i++) {
290 PyObject* item = PyList_GetItem(country_codes, i);
291
292 if (!PyUnicode_Check(item)) {
293 PyErr_SetString(PyExc_TypeError, "Country codes must be strings");
294 loc_country_list_unref(countries);
295 return NULL;
296 }
297
298 const char* country_code = PyUnicode_AsUTF8(item);
299
300 struct loc_country* country;
301 r = loc_country_new(loc_ctx, &country, country_code);
302 if (r) {
303 if (r == -EINVAL) {
304 PyErr_Format(PyExc_ValueError, "Invalid country code: %s", country_code);
305 } else {
306 PyErr_SetString(PyExc_SystemError, "Could not create country");
307 }
308
309 loc_country_list_unref(countries);
310 return NULL;
311 }
312
313 // Append it to the list
314 r = loc_country_list_append(countries, country);
315 if (r) {
316 PyErr_SetString(PyExc_SystemError, "Could not append country to the list");
317
318 loc_country_list_unref(countries);
319 loc_country_unref(country);
320 return NULL;
321 }
322
323 loc_country_unref(country);
324 }
325
326 loc_database_enumerator_set_countries(enumerator, countries);
327
328 Py_DECREF(country_codes);
329 loc_country_list_unref(countries);
330 }
331
332 // Set the ASN we are searching for
333 if (asn) {
334 r = loc_database_enumerator_set_asn(enumerator, asn);
335
336 if (r) {
337 PyErr_SetFromErrno(PyExc_SystemError);
338 return NULL;
339 }
340 }
341
342 // Set the flags we are searching for
343 if (flags) {
344 r = loc_database_enumerator_set_flag(enumerator, flags);
345
346 if (r) {
347 PyErr_SetFromErrno(PyExc_SystemError);
348 return NULL;
349 }
350 }
351
352 // Set the family we are searching for
353 if (family) {
354 r = loc_database_enumerator_set_family(enumerator, family);
355
356 if (r) {
357 PyErr_SetFromErrno(PyExc_SystemError);
358 return NULL;
359 }
360 }
361
362 PyObject* obj = new_database_enumerator(&DatabaseEnumeratorType, enumerator);
363 loc_database_enumerator_unref(enumerator);
364
365 return obj;
366 }
367
368 static PyObject* Database_countries(DatabaseObject* self) {
369 return Database_iterate_all(self, LOC_DB_ENUMERATE_COUNTRIES, 0);
370 }
371
372 static struct PyMethodDef Database_methods[] = {
373 {
374 "get_as",
375 (PyCFunction)Database_get_as,
376 METH_VARARGS,
377 NULL,
378 },
379 {
380 "get_country",
381 (PyCFunction)Database_get_country,
382 METH_VARARGS,
383 NULL,
384 },
385 {
386 "lookup",
387 (PyCFunction)Database_lookup,
388 METH_VARARGS,
389 NULL,
390 },
391 {
392 "search_as",
393 (PyCFunction)Database_search_as,
394 METH_VARARGS,
395 NULL,
396 },
397 {
398 "search_networks",
399 (PyCFunction)Database_search_networks,
400 METH_VARARGS|METH_KEYWORDS,
401 NULL,
402 },
403 {
404 "verify",
405 (PyCFunction)Database_verify,
406 METH_VARARGS,
407 NULL,
408 },
409 { NULL },
410 };
411
412 static struct PyGetSetDef Database_getsetters[] = {
413 {
414 "ases",
415 (getter)Database_ases,
416 NULL,
417 NULL,
418 NULL,
419 },
420 {
421 "countries",
422 (getter)Database_countries,
423 NULL,
424 NULL,
425 NULL,
426 },
427 {
428 "created_at",
429 (getter)Database_get_created_at,
430 NULL,
431 NULL,
432 NULL,
433 },
434 {
435 "description",
436 (getter)Database_get_description,
437 NULL,
438 NULL,
439 NULL,
440 },
441 {
442 "license",
443 (getter)Database_get_license,
444 NULL,
445 NULL,
446 NULL,
447 },
448 {
449 "networks",
450 (getter)Database_networks,
451 NULL,
452 NULL,
453 NULL,
454 },
455 {
456 "networks_flattened",
457 (getter)Database_networks_flattened,
458 NULL,
459 NULL,
460 NULL,
461 },
462 {
463 "vendor",
464 (getter)Database_get_vendor,
465 NULL,
466 NULL,
467 NULL,
468 },
469 { NULL },
470 };
471
472 PyTypeObject DatabaseType = {
473 PyVarObject_HEAD_INIT(NULL, 0)
474 .tp_name = "location.Database",
475 .tp_basicsize = sizeof(DatabaseObject),
476 .tp_flags = Py_TPFLAGS_DEFAULT|Py_TPFLAGS_BASETYPE,
477 .tp_new = Database_new,
478 .tp_dealloc = (destructor)Database_dealloc,
479 .tp_init = (initproc)Database_init,
480 .tp_doc = "Database object",
481 .tp_methods = Database_methods,
482 .tp_getset = Database_getsetters,
483 .tp_repr = (reprfunc)Database_repr,
484 };
485
486 static PyObject* DatabaseEnumerator_new(PyTypeObject* type, PyObject* args, PyObject* kwds) {
487 DatabaseEnumeratorObject* self = (DatabaseEnumeratorObject*)type->tp_alloc(type, 0);
488
489 return (PyObject*)self;
490 }
491
492 static void DatabaseEnumerator_dealloc(DatabaseEnumeratorObject* self) {
493 loc_database_enumerator_unref(self->enumerator);
494
495 Py_TYPE(self)->tp_free((PyObject* )self);
496 }
497
498 static PyObject* DatabaseEnumerator_next(DatabaseEnumeratorObject* self) {
499 struct loc_network* network = NULL;
500
501 // Enumerate all networks
502 int r = loc_database_enumerator_next_network(self->enumerator, &network);
503 if (r) {
504 PyErr_SetFromErrno(PyExc_ValueError);
505 return NULL;
506 }
507
508 // A network was found
509 if (network) {
510 PyObject* obj = new_network(&NetworkType, network);
511 loc_network_unref(network);
512
513 return obj;
514 }
515
516 // Enumerate all ASes
517 struct loc_as* as = NULL;
518
519 r = loc_database_enumerator_next_as(self->enumerator, &as);
520 if (r) {
521 PyErr_SetFromErrno(PyExc_ValueError);
522 return NULL;
523 }
524
525 if (as) {
526 PyObject* obj = new_as(&ASType, as);
527 loc_as_unref(as);
528
529 return obj;
530 }
531
532 // Enumerate all countries
533 struct loc_country* country = NULL;
534
535 r = loc_database_enumerator_next_country(self->enumerator, &country);
536 if (r) {
537 PyErr_SetFromErrno(PyExc_ValueError);
538 return NULL;
539 }
540
541 if (country) {
542 PyObject* obj = new_country(&CountryType, country);
543 loc_country_unref(country);
544
545 return obj;
546 }
547
548 // Nothing found, that means the end
549 PyErr_SetNone(PyExc_StopIteration);
550 return NULL;
551 }
552
553 PyTypeObject DatabaseEnumeratorType = {
554 PyVarObject_HEAD_INIT(NULL, 0)
555 .tp_name = "location.DatabaseEnumerator",
556 .tp_basicsize = sizeof(DatabaseEnumeratorObject),
557 .tp_flags = Py_TPFLAGS_DEFAULT,
558 .tp_alloc = PyType_GenericAlloc,
559 .tp_new = DatabaseEnumerator_new,
560 .tp_dealloc = (destructor)DatabaseEnumerator_dealloc,
561 .tp_iter = PyObject_SelfIter,
562 .tp_iternext = (iternextfunc)DatabaseEnumerator_next,
563 };