]> git.ipfire.org Git - location/libloc.git/blob - src/python/database.c
database: Pass flag to enumerator to flatten output
[location/libloc.git] / src / python / database.c
1 /*
2 libloc - A library to determine the location of someone on the Internet
3
4 Copyright (C) 2017 IPFire Development Team <info@ipfire.org>
5
6 This library is free software; you can redistribute it and/or
7 modify it under the terms of the GNU Lesser General Public
8 License as published by the Free Software Foundation; either
9 version 2.1 of the License, or (at your option) any later version.
10
11 This library is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 Lesser General Public License for more details.
15 */
16
#include <Python.h>

#include <errno.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <unistd.h>

#include <loc/libloc.h>
#include <loc/database.h>

#include "locationmodule.h"
#include "as.h"
#include "country.h"
#include "database.h"
#include "network.h"
27
28 static PyObject* Database_new(PyTypeObject* type, PyObject* args, PyObject* kwds) {
29 DatabaseObject* self = (DatabaseObject*)type->tp_alloc(type, 0);
30
31 return (PyObject*)self;
32 }
33
34 static void Database_dealloc(DatabaseObject* self) {
35 if (self->db)
36 loc_database_unref(self->db);
37
38 if (self->path)
39 free(self->path);
40
41 Py_TYPE(self)->tp_free((PyObject* )self);
42 }
43
44 static int Database_init(DatabaseObject* self, PyObject* args, PyObject* kwargs) {
45 const char* path = NULL;
46
47 if (!PyArg_ParseTuple(args, "s", &path))
48 return -1;
49
50 self->path = strdup(path);
51
52 // Open the file for reading
53 FILE* f = fopen(self->path, "r");
54 if (!f) {
55 PyErr_SetFromErrno(PyExc_IOError);
56 return -1;
57 }
58
59 // Load the database
60 int r = loc_database_new(loc_ctx, &self->db, f);
61 fclose(f);
62
63 // Return on any errors
64 if (r)
65 return -1;
66
67 return 0;
68 }
69
70 static PyObject* Database_repr(DatabaseObject* self) {
71 return PyUnicode_FromFormat("<Database %s>", self->path);
72 }
73
74 static PyObject* Database_verify(DatabaseObject* self, PyObject* args) {
75 PyObject* public_key = NULL;
76 FILE* f = NULL;
77
78 // Parse arguments
79 if (!PyArg_ParseTuple(args, "O", &public_key))
80 return NULL;
81
82 // Convert into FILE*
83 int fd = PyObject_AsFileDescriptor(public_key);
84 if (fd < 0)
85 return NULL;
86
87 // Re-open file descriptor
88 f = fdopen(fd, "r");
89 if (!f) {
90 PyErr_SetFromErrno(PyExc_IOError);
91 return NULL;
92 }
93
94 int r = loc_database_verify(self->db, f);
95
96 if (r == 0)
97 Py_RETURN_TRUE;
98
99 Py_RETURN_FALSE;
100 }
101
102 static PyObject* Database_get_description(DatabaseObject* self) {
103 const char* description = loc_database_get_description(self->db);
104
105 return PyUnicode_FromString(description);
106 }
107
108 static PyObject* Database_get_vendor(DatabaseObject* self) {
109 const char* vendor = loc_database_get_vendor(self->db);
110
111 return PyUnicode_FromString(vendor);
112 }
113
114 static PyObject* Database_get_license(DatabaseObject* self) {
115 const char* license = loc_database_get_license(self->db);
116
117 return PyUnicode_FromString(license);
118 }
119
120 static PyObject* Database_get_created_at(DatabaseObject* self) {
121 time_t created_at = loc_database_created_at(self->db);
122
123 return PyLong_FromLong(created_at);
124 }
125
126 static PyObject* Database_get_as(DatabaseObject* self, PyObject* args) {
127 struct loc_as* as = NULL;
128 uint32_t number = 0;
129
130 if (!PyArg_ParseTuple(args, "i", &number))
131 return NULL;
132
133 // Try to retrieve the AS
134 int r = loc_database_get_as(self->db, &as, number);
135
136 // We got an AS
137 if (r == 0) {
138 PyObject* obj = new_as(&ASType, as);
139 loc_as_unref(as);
140
141 return obj;
142
143 // Nothing found
144 } else if (r == 1) {
145 Py_RETURN_NONE;
146 }
147
148 // Unexpected error
149 return NULL;
150 }
151
152 static PyObject* Database_get_country(DatabaseObject* self, PyObject* args) {
153 const char* country_code = NULL;
154
155 if (!PyArg_ParseTuple(args, "s", &country_code))
156 return NULL;
157
158 struct loc_country* country;
159 int r = loc_database_get_country(self->db, &country, country_code);
160 if (r) {
161 Py_RETURN_NONE;
162 }
163
164 PyObject* obj = new_country(&CountryType, country);
165 loc_country_unref(country);
166
167 return obj;
168 }
169
170 static PyObject* Database_lookup(DatabaseObject* self, PyObject* args) {
171 struct loc_network* network = NULL;
172 const char* address = NULL;
173
174 if (!PyArg_ParseTuple(args, "s", &address))
175 return NULL;
176
177 // Try to retrieve a matching network
178 int r = loc_database_lookup_from_string(self->db, address, &network);
179
180 // We got a network
181 if (r == 0) {
182 PyObject* obj = new_network(&NetworkType, network);
183 loc_network_unref(network);
184
185 return obj;
186
187 // Nothing found
188 } else if (r == 1) {
189 Py_RETURN_NONE;
190
191 // Invalid input
192 } else if (r == -EINVAL) {
193 PyErr_Format(PyExc_ValueError, "Invalid IP address: %s", address);
194 return NULL;
195 }
196
197 // Unexpected error
198 return NULL;
199 }
200
201 static PyObject* new_database_enumerator(PyTypeObject* type, struct loc_database_enumerator* enumerator) {
202 DatabaseEnumeratorObject* self = (DatabaseEnumeratorObject*)type->tp_alloc(type, 0);
203 if (self) {
204 self->enumerator = loc_database_enumerator_ref(enumerator);
205 }
206
207 return (PyObject*)self;
208 }
209
210 static PyObject* Database_iterate_all(DatabaseObject* self, enum loc_database_enumerator_mode what, int flags) {
211 struct loc_database_enumerator* enumerator;
212
213 int r = loc_database_enumerator_new(&enumerator, self->db, what, flags);
214 if (r) {
215 PyErr_SetFromErrno(PyExc_SystemError);
216 return NULL;
217 }
218
219 PyObject* obj = new_database_enumerator(&DatabaseEnumeratorType, enumerator);
220 loc_database_enumerator_unref(enumerator);
221
222 return obj;
223 }
224
225 static PyObject* Database_ases(DatabaseObject* self) {
226 return Database_iterate_all(self, LOC_DB_ENUMERATE_ASES, 0);
227 }
228
229 static PyObject* Database_search_as(DatabaseObject* self, PyObject* args) {
230 const char* string = NULL;
231
232 if (!PyArg_ParseTuple(args, "s", &string))
233 return NULL;
234
235 struct loc_database_enumerator* enumerator;
236
237 int r = loc_database_enumerator_new(&enumerator, self->db, LOC_DB_ENUMERATE_ASES, 0);
238 if (r) {
239 PyErr_SetFromErrno(PyExc_SystemError);
240 return NULL;
241 }
242
243 // Search string we are searching for
244 loc_database_enumerator_set_string(enumerator, string);
245
246 PyObject* obj = new_database_enumerator(&DatabaseEnumeratorType, enumerator);
247 loc_database_enumerator_unref(enumerator);
248
249 return obj;
250 }
251
252 static PyObject* Database_networks(DatabaseObject* self) {
253 return Database_iterate_all(self, LOC_DB_ENUMERATE_NETWORKS, 0);
254 }
255
256 static PyObject* Database_networks_flattened(DatabaseObject *self) {
257 return Database_iterate_all(self, LOC_DB_ENUMERATE_NETWORKS, LOC_DB_ENUMERATOR_FLAGS_FLATTEN);
258 }
259
260 static PyObject* Database_search_networks(DatabaseObject* self, PyObject* args, PyObject* kwargs) {
261 char* kwlist[] = { "country_code", "asn", "flags", "family", NULL };
262 const char* country_code = NULL;
263 unsigned int asn = 0;
264 int flags = 0;
265 int family = 0;
266
267 if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|siii", kwlist, &country_code, &asn, &flags, &family))
268 return NULL;
269
270 struct loc_database_enumerator* enumerator;
271 int r = loc_database_enumerator_new(&enumerator, self->db, LOC_DB_ENUMERATE_NETWORKS, 0);
272 if (r) {
273 PyErr_SetFromErrno(PyExc_SystemError);
274 return NULL;
275 }
276
277 // Set country code we are searching for
278 if (country_code) {
279 r = loc_database_enumerator_set_country_code(enumerator, country_code);
280
281 if (r) {
282 PyErr_SetFromErrno(PyExc_SystemError);
283 return NULL;
284 }
285 }
286
287 // Set the ASN we are searching for
288 if (asn) {
289 r = loc_database_enumerator_set_asn(enumerator, asn);
290
291 if (r) {
292 PyErr_SetFromErrno(PyExc_SystemError);
293 return NULL;
294 }
295 }
296
297 // Set the flags we are searching for
298 if (flags) {
299 r = loc_database_enumerator_set_flag(enumerator, flags);
300
301 if (r) {
302 PyErr_SetFromErrno(PyExc_SystemError);
303 return NULL;
304 }
305 }
306
307 // Set the family we are searching for
308 if (family) {
309 r = loc_database_enumerator_set_family(enumerator, family);
310
311 if (r) {
312 PyErr_SetFromErrno(PyExc_SystemError);
313 return NULL;
314 }
315 }
316
317 PyObject* obj = new_database_enumerator(&DatabaseEnumeratorType, enumerator);
318 loc_database_enumerator_unref(enumerator);
319
320 return obj;
321 }
322
323 static PyObject* Database_countries(DatabaseObject* self) {
324 return Database_iterate_all(self, LOC_DB_ENUMERATE_COUNTRIES, 0);
325 }
326
327 static struct PyMethodDef Database_methods[] = {
328 {
329 "get_as",
330 (PyCFunction)Database_get_as,
331 METH_VARARGS,
332 NULL,
333 },
334 {
335 "get_country",
336 (PyCFunction)Database_get_country,
337 METH_VARARGS,
338 NULL,
339 },
340 {
341 "lookup",
342 (PyCFunction)Database_lookup,
343 METH_VARARGS,
344 NULL,
345 },
346 {
347 "search_as",
348 (PyCFunction)Database_search_as,
349 METH_VARARGS,
350 NULL,
351 },
352 {
353 "search_networks",
354 (PyCFunction)Database_search_networks,
355 METH_VARARGS|METH_KEYWORDS,
356 NULL,
357 },
358 {
359 "verify",
360 (PyCFunction)Database_verify,
361 METH_VARARGS,
362 NULL,
363 },
364 { NULL },
365 };
366
367 static struct PyGetSetDef Database_getsetters[] = {
368 {
369 "ases",
370 (getter)Database_ases,
371 NULL,
372 NULL,
373 NULL,
374 },
375 {
376 "countries",
377 (getter)Database_countries,
378 NULL,
379 NULL,
380 NULL,
381 },
382 {
383 "created_at",
384 (getter)Database_get_created_at,
385 NULL,
386 NULL,
387 NULL,
388 },
389 {
390 "description",
391 (getter)Database_get_description,
392 NULL,
393 NULL,
394 NULL,
395 },
396 {
397 "license",
398 (getter)Database_get_license,
399 NULL,
400 NULL,
401 NULL,
402 },
403 {
404 "networks",
405 (getter)Database_networks,
406 NULL,
407 NULL,
408 NULL,
409 },
410 {
411 "networks_flattened",
412 (getter)Database_networks_flattened,
413 NULL,
414 NULL,
415 NULL,
416 },
417 {
418 "vendor",
419 (getter)Database_get_vendor,
420 NULL,
421 NULL,
422 NULL,
423 },
424 { NULL },
425 };
426
427 PyTypeObject DatabaseType = {
428 PyVarObject_HEAD_INIT(NULL, 0)
429 .tp_name = "location.Database",
430 .tp_basicsize = sizeof(DatabaseObject),
431 .tp_flags = Py_TPFLAGS_DEFAULT|Py_TPFLAGS_BASETYPE,
432 .tp_new = Database_new,
433 .tp_dealloc = (destructor)Database_dealloc,
434 .tp_init = (initproc)Database_init,
435 .tp_doc = "Database object",
436 .tp_methods = Database_methods,
437 .tp_getset = Database_getsetters,
438 .tp_repr = (reprfunc)Database_repr,
439 };
440
441 static PyObject* DatabaseEnumerator_new(PyTypeObject* type, PyObject* args, PyObject* kwds) {
442 DatabaseEnumeratorObject* self = (DatabaseEnumeratorObject*)type->tp_alloc(type, 0);
443
444 return (PyObject*)self;
445 }
446
447 static void DatabaseEnumerator_dealloc(DatabaseEnumeratorObject* self) {
448 loc_database_enumerator_unref(self->enumerator);
449
450 Py_TYPE(self)->tp_free((PyObject* )self);
451 }
452
453 static PyObject* DatabaseEnumerator_next(DatabaseEnumeratorObject* self) {
454 struct loc_network* network = NULL;
455
456 // Enumerate all networks
457 int r = loc_database_enumerator_next_network(self->enumerator, &network);
458 if (r) {
459 PyErr_SetFromErrno(PyExc_ValueError);
460 return NULL;
461 }
462
463 // A network was found
464 if (network) {
465 PyObject* obj = new_network(&NetworkType, network);
466 loc_network_unref(network);
467
468 return obj;
469 }
470
471 // Enumerate all ASes
472 struct loc_as* as = NULL;
473
474 r = loc_database_enumerator_next_as(self->enumerator, &as);
475 if (r) {
476 PyErr_SetFromErrno(PyExc_ValueError);
477 return NULL;
478 }
479
480 if (as) {
481 PyObject* obj = new_as(&ASType, as);
482 loc_as_unref(as);
483
484 return obj;
485 }
486
487 // Enumerate all countries
488 struct loc_country* country = NULL;
489
490 r = loc_database_enumerator_next_country(self->enumerator, &country);
491 if (r) {
492 PyErr_SetFromErrno(PyExc_ValueError);
493 return NULL;
494 }
495
496 if (country) {
497 PyObject* obj = new_country(&CountryType, country);
498 loc_country_unref(country);
499
500 return obj;
501 }
502
503 // Nothing found, that means the end
504 PyErr_SetNone(PyExc_StopIteration);
505 return NULL;
506 }
507
508 PyTypeObject DatabaseEnumeratorType = {
509 PyVarObject_HEAD_INIT(NULL, 0)
510 .tp_name = "location.DatabaseEnumerator",
511 .tp_basicsize = sizeof(DatabaseEnumeratorObject),
512 .tp_flags = Py_TPFLAGS_DEFAULT,
513 .tp_alloc = PyType_GenericAlloc,
514 .tp_new = DatabaseEnumerator_new,
515 .tp_dealloc = (destructor)DatabaseEnumerator_dealloc,
516 .tp_iter = PyObject_SelfIter,
517 .tp_iternext = (iternextfunc)DatabaseEnumerator_next,
518 };