]> git.ipfire.org Git - location/libloc.git/blob - src/python/database.c
importer: Drop EDROP as it has been merged into DROP
[location/libloc.git] / src / python / database.c
1 /*
2 libloc - A library to determine the location of someone on the Internet
3
4 Copyright (C) 2017 IPFire Development Team <info@ipfire.org>
5
6 This library is free software; you can redistribute it and/or
7 modify it under the terms of the GNU Lesser General Public
8 License as published by the Free Software Foundation; either
9 version 2.1 of the License, or (at your option) any later version.
10
11 This library is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 Lesser General Public License for more details.
15 */
16
17 #include <Python.h>
18
19 #include <loc/libloc.h>
20 #include <loc/database.h>
21
22 #include "locationmodule.h"
23 #include "as.h"
24 #include "country.h"
25 #include "database.h"
26 #include "network.h"
27
28 static PyObject* Database_new(PyTypeObject* type, PyObject* args, PyObject* kwds) {
29 DatabaseObject* self = (DatabaseObject*)type->tp_alloc(type, 0);
30
31 return (PyObject*)self;
32 }
33
34 static void Database_dealloc(DatabaseObject* self) {
35 if (self->db)
36 loc_database_unref(self->db);
37
38 if (self->path)
39 free(self->path);
40
41 Py_TYPE(self)->tp_free((PyObject* )self);
42 }
43
44 static int Database_init(DatabaseObject* self, PyObject* args, PyObject* kwargs) {
45 const char* path = NULL;
46
47 if (!PyArg_ParseTuple(args, "s", &path))
48 return -1;
49
50 self->path = strdup(path);
51
52 // Open the file for reading
53 FILE* f = fopen(self->path, "r");
54 if (!f) {
55 PyErr_SetFromErrno(PyExc_IOError);
56 return -1;
57 }
58
59 // Load the database
60 int r = loc_database_new(loc_ctx, &self->db, f);
61 fclose(f);
62
63 // Return on any errors
64 if (r)
65 return -1;
66
67 return 0;
68 }
69
70 static PyObject* Database_repr(DatabaseObject* self) {
71 return PyUnicode_FromFormat("<Database %s>", self->path);
72 }
73
74 static PyObject* Database_get_description(DatabaseObject* self) {
75 const char* description = loc_database_get_description(self->db);
76
77 return PyUnicode_FromString(description);
78 }
79
80 static PyObject* Database_get_vendor(DatabaseObject* self) {
81 const char* vendor = loc_database_get_vendor(self->db);
82
83 return PyUnicode_FromString(vendor);
84 }
85
86 static PyObject* Database_get_license(DatabaseObject* self) {
87 const char* license = loc_database_get_license(self->db);
88
89 return PyUnicode_FromString(license);
90 }
91
92 static PyObject* Database_get_created_at(DatabaseObject* self) {
93 time_t created_at = loc_database_created_at(self->db);
94
95 return PyLong_FromLong(created_at);
96 }
97
98 static PyObject* Database_get_as(DatabaseObject* self, PyObject* args) {
99 struct loc_as* as = NULL;
100 uint32_t number = 0;
101
102 if (!PyArg_ParseTuple(args, "i", &number))
103 return NULL;
104
105 // Try to retrieve the AS
106 int r = loc_database_get_as(self->db, &as, number);
107
108 // We got an AS
109 if (r == 0) {
110 PyObject* obj = new_as(&ASType, as);
111 loc_as_unref(as);
112
113 return obj;
114
115 // Nothing found
116 } else if (r == 1) {
117 Py_RETURN_NONE;
118 }
119
120 // Unexpected error
121 return NULL;
122 }
123
124 static PyObject* Database_get_country(DatabaseObject* self, PyObject* args) {
125 const char* country_code = NULL;
126
127 if (!PyArg_ParseTuple(args, "s", &country_code))
128 return NULL;
129
130 struct loc_country* country;
131 int r = loc_database_get_country(self->db, &country, country_code);
132 if (r) {
133 Py_RETURN_NONE;
134 }
135
136 PyObject* obj = new_country(&CountryType, country);
137 loc_country_unref(country);
138
139 return obj;
140 }
141
142 static PyObject* Database_lookup(DatabaseObject* self, PyObject* args) {
143 struct loc_network* network = NULL;
144 const char* address = NULL;
145
146 if (!PyArg_ParseTuple(args, "s", &address))
147 return NULL;
148
149 // Try to retrieve a matching network
150 int r = loc_database_lookup_from_string(self->db, address, &network);
151
152 // We got a network
153 if (r == 0) {
154 PyObject* obj = new_network(&NetworkType, network);
155 loc_network_unref(network);
156
157 return obj;
158
159 // Nothing found
160 } else if (r == 1) {
161 Py_RETURN_NONE;
162
163 // Invalid input
164 } else if (r == -EINVAL) {
165 PyErr_Format(PyExc_ValueError, "Invalid IP address: %s", address);
166 return NULL;
167 }
168
169 // Unexpected error
170 return NULL;
171 }
172
173 static PyObject* new_database_enumerator(PyTypeObject* type, struct loc_database_enumerator* enumerator) {
174 DatabaseEnumeratorObject* self = (DatabaseEnumeratorObject*)type->tp_alloc(type, 0);
175 if (self) {
176 self->enumerator = loc_database_enumerator_ref(enumerator);
177 }
178
179 return (PyObject*)self;
180 }
181
182 static PyObject* Database_search_as(DatabaseObject* self, PyObject* args) {
183 const char* string = NULL;
184
185 if (!PyArg_ParseTuple(args, "s", &string))
186 return NULL;
187
188 struct loc_database_enumerator* enumerator;
189
190 int r = loc_database_enumerator_new(&enumerator, self->db, LOC_DB_ENUMERATE_ASES);
191 if (r) {
192 PyErr_SetFromErrno(PyExc_SystemError);
193 return NULL;
194 }
195
196 // Search string we are searching for
197 loc_database_enumerator_set_string(enumerator, string);
198
199 PyObject* obj = new_database_enumerator(&DatabaseEnumeratorType, enumerator);
200 loc_database_enumerator_unref(enumerator);
201
202 return obj;
203 }
204
205 static PyObject* Database_search_networks(DatabaseObject* self, PyObject* args, PyObject* kwargs) {
206 char* kwlist[] = { "country_code", "asn", "flags", NULL };
207 const char* country_code = NULL;
208 unsigned int asn = 0;
209 int flags = 0;
210
211 if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|sii", kwlist, &country_code, &asn, &flags))
212 return NULL;
213
214 struct loc_database_enumerator* enumerator;
215 int r = loc_database_enumerator_new(&enumerator, self->db, LOC_DB_ENUMERATE_NETWORKS);
216 if (r) {
217 PyErr_SetFromErrno(PyExc_SystemError);
218 return NULL;
219 }
220
221 // Set country code we are searching for
222 if (country_code) {
223 r = loc_database_enumerator_set_country_code(enumerator, country_code);
224
225 if (r) {
226 PyErr_SetFromErrno(PyExc_SystemError);
227 return NULL;
228 }
229 }
230
231 // Set the ASN we are searching for
232 if (asn) {
233 r = loc_database_enumerator_set_asn(enumerator, asn);
234
235 if (r) {
236 PyErr_SetFromErrno(PyExc_SystemError);
237 return NULL;
238 }
239 }
240
241 // Set the flags we are searching for
242 if (flags) {
243 r = loc_database_enumerator_set_flag(enumerator, flags);
244
245 if (r) {
246 PyErr_SetFromErrno(PyExc_SystemError);
247 return NULL;
248 }
249 }
250
251 PyObject* obj = new_database_enumerator(&DatabaseEnumeratorType, enumerator);
252 loc_database_enumerator_unref(enumerator);
253
254 return obj;
255 }
256
257 static struct PyMethodDef Database_methods[] = {
258 {
259 "get_as",
260 (PyCFunction)Database_get_as,
261 METH_VARARGS,
262 NULL,
263 },
264 {
265 "get_country",
266 (PyCFunction)Database_get_country,
267 METH_VARARGS,
268 NULL,
269 },
270 {
271 "lookup",
272 (PyCFunction)Database_lookup,
273 METH_VARARGS,
274 NULL,
275 },
276 {
277 "search_as",
278 (PyCFunction)Database_search_as,
279 METH_VARARGS,
280 NULL,
281 },
282 {
283 "search_networks",
284 (PyCFunction)Database_search_networks,
285 METH_VARARGS|METH_KEYWORDS,
286 NULL,
287 },
288 { NULL },
289 };
290
291 static struct PyGetSetDef Database_getsetters[] = {
292 {
293 "created_at",
294 (getter)Database_get_created_at,
295 NULL,
296 NULL,
297 NULL,
298 },
299 {
300 "description",
301 (getter)Database_get_description,
302 NULL,
303 NULL,
304 NULL,
305 },
306 {
307 "license",
308 (getter)Database_get_license,
309 NULL,
310 NULL,
311 NULL,
312 },
313 {
314 "vendor",
315 (getter)Database_get_vendor,
316 NULL,
317 NULL,
318 NULL,
319 },
320 { NULL },
321 };
322
323 PyTypeObject DatabaseType = {
324 PyVarObject_HEAD_INIT(NULL, 0)
325 .tp_name = "location.Database",
326 .tp_basicsize = sizeof(DatabaseObject),
327 .tp_flags = Py_TPFLAGS_DEFAULT|Py_TPFLAGS_BASETYPE,
328 .tp_new = Database_new,
329 .tp_dealloc = (destructor)Database_dealloc,
330 .tp_init = (initproc)Database_init,
331 .tp_doc = "Database object",
332 .tp_methods = Database_methods,
333 .tp_getset = Database_getsetters,
334 .tp_repr = (reprfunc)Database_repr,
335 };
336
337 static PyObject* DatabaseEnumerator_new(PyTypeObject* type, PyObject* args, PyObject* kwds) {
338 DatabaseEnumeratorObject* self = (DatabaseEnumeratorObject*)type->tp_alloc(type, 0);
339
340 return (PyObject*)self;
341 }
342
343 static void DatabaseEnumerator_dealloc(DatabaseEnumeratorObject* self) {
344 loc_database_enumerator_unref(self->enumerator);
345
346 Py_TYPE(self)->tp_free((PyObject* )self);
347 }
348
349 static PyObject* DatabaseEnumerator_next(DatabaseEnumeratorObject* self) {
350 struct loc_network* network = NULL;
351
352 // Enumerate all networks
353 int r = loc_database_enumerator_next_network(self->enumerator, &network);
354 if (r) {
355 PyErr_SetFromErrno(PyExc_ValueError);
356 return NULL;
357 }
358
359 // A network was found
360 if (network) {
361 PyObject* obj = new_network(&NetworkType, network);
362 loc_network_unref(network);
363
364 return obj;
365 }
366
367 // Enumerate all ASes
368 struct loc_as* as = NULL;
369
370 r = loc_database_enumerator_next_as(self->enumerator, &as);
371 if (r) {
372 PyErr_SetFromErrno(PyExc_ValueError);
373 return NULL;
374 }
375
376 if (as) {
377 PyObject* obj = new_as(&ASType, as);
378 loc_as_unref(as);
379
380 return obj;
381 }
382
383 // Nothing found, that means the end
384 PyErr_SetNone(PyExc_StopIteration);
385 return NULL;
386 }
387
388 PyTypeObject DatabaseEnumeratorType = {
389 PyVarObject_HEAD_INIT(NULL, 0)
390 .tp_name = "location.DatabaseEnumerator",
391 .tp_basicsize = sizeof(DatabaseEnumeratorObject),
392 .tp_flags = Py_TPFLAGS_DEFAULT,
393 .tp_alloc = PyType_GenericAlloc,
394 .tp_new = DatabaseEnumerator_new,
395 .tp_dealloc = (destructor)DatabaseEnumerator_dealloc,
396 .tp_iter = PyObject_SelfIter,
397 .tp_iternext = (iternextfunc)DatabaseEnumerator_next,
398 };