]> git.ipfire.org Git - people/ms/libloc.git/blob - src/python/database.c
python: Raise error when a network/AS could not be read
[people/ms/libloc.git] / src / python / database.c
1 /*
2 libloc - A library to determine the location of someone on the Internet
3
4 Copyright (C) 2017 IPFire Development Team <info@ipfire.org>
5
6 This library is free software; you can redistribute it and/or
7 modify it under the terms of the GNU Lesser General Public
8 License as published by the Free Software Foundation; either
9 version 2.1 of the License, or (at your option) any later version.
10
11 This library is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 Lesser General Public License for more details.
15 */
16
17 #include <Python.h>
18
19 #include <loc/libloc.h>
20 #include <loc/database.h>
21
22 #include "locationmodule.h"
23 #include "as.h"
24 #include "country.h"
25 #include "database.h"
26 #include "network.h"
27
28 static PyObject* Database_new(PyTypeObject* type, PyObject* args, PyObject* kwds) {
29 DatabaseObject* self = (DatabaseObject*)type->tp_alloc(type, 0);
30
31 return (PyObject*)self;
32 }
33
34 static void Database_dealloc(DatabaseObject* self) {
35 if (self->db)
36 loc_database_unref(self->db);
37
38 if (self->path)
39 free(self->path);
40
41 Py_TYPE(self)->tp_free((PyObject* )self);
42 }
43
44 static int Database_init(DatabaseObject* self, PyObject* args, PyObject* kwargs) {
45 const char* path = NULL;
46
47 if (!PyArg_ParseTuple(args, "s", &path))
48 return -1;
49
50 self->path = strdup(path);
51
52 // Open the file for reading
53 FILE* f = fopen(self->path, "r");
54 if (!f) {
55 PyErr_SetFromErrno(PyExc_IOError);
56 return -1;
57 }
58
59 // Load the database
60 int r = loc_database_new(loc_ctx, &self->db, f);
61 fclose(f);
62
63 // Return on any errors
64 if (r)
65 return -1;
66
67 return 0;
68 }
69
70 static PyObject* Database_repr(DatabaseObject* self) {
71 return PyUnicode_FromFormat("<Database %s>", self->path);
72 }
73
74 static PyObject* Database_get_description(DatabaseObject* self) {
75 const char* description = loc_database_get_description(self->db);
76
77 return PyUnicode_FromString(description);
78 }
79
80 static PyObject* Database_get_vendor(DatabaseObject* self) {
81 const char* vendor = loc_database_get_vendor(self->db);
82
83 return PyUnicode_FromString(vendor);
84 }
85
86 static PyObject* Database_get_license(DatabaseObject* self) {
87 const char* license = loc_database_get_license(self->db);
88
89 return PyUnicode_FromString(license);
90 }
91
92 static PyObject* Database_get_created_at(DatabaseObject* self) {
93 time_t created_at = loc_database_created_at(self->db);
94
95 return PyLong_FromLong(created_at);
96 }
97
98 static PyObject* Database_get_as(DatabaseObject* self, PyObject* args) {
99 struct loc_as* as = NULL;
100 uint32_t number = 0;
101
102 if (!PyArg_ParseTuple(args, "i", &number))
103 return NULL;
104
105 // Try to retrieve the AS
106 int r = loc_database_get_as(self->db, &as, number);
107
108 // We got an AS
109 if (r == 0) {
110 PyObject* obj = new_as(&ASType, as);
111 loc_as_unref(as);
112
113 return obj;
114
115 // Nothing found
116 } else if (r == 1) {
117 Py_RETURN_NONE;
118 }
119
120 // Unexpected error
121 return NULL;
122 }
123
124 static PyObject* Database_get_country(DatabaseObject* self, PyObject* args) {
125 const char* country_code = NULL;
126
127 if (!PyArg_ParseTuple(args, "s", &country_code))
128 return NULL;
129
130 struct loc_country* country;
131 int r = loc_database_get_country(self->db, &country, country_code);
132 if (r) {
133 Py_RETURN_NONE;
134 }
135
136 PyObject* obj = new_country(&CountryType, country);
137 loc_country_unref(country);
138
139 return obj;
140 }
141
142 static PyObject* Database_lookup(DatabaseObject* self, PyObject* args) {
143 struct loc_network* network = NULL;
144 const char* address = NULL;
145
146 if (!PyArg_ParseTuple(args, "s", &address))
147 return NULL;
148
149 // Try to retrieve a matching network
150 int r = loc_database_lookup_from_string(self->db, address, &network);
151
152 // We got a network
153 if (r == 0) {
154 PyObject* obj = new_network(&NetworkType, network);
155 loc_network_unref(network);
156
157 return obj;
158
159 // Nothing found
160 } else if (r == 1) {
161 Py_RETURN_NONE;
162
163 // Invalid input
164 } else if (r == -EINVAL) {
165 PyErr_Format(PyExc_ValueError, "Invalid IP address: %s", address);
166 return NULL;
167 }
168
169 // Unexpected error
170 return NULL;
171 }
172
173 static PyObject* new_database_enumerator(PyTypeObject* type, struct loc_database_enumerator* enumerator) {
174 DatabaseEnumeratorObject* self = (DatabaseEnumeratorObject*)type->tp_alloc(type, 0);
175 if (self) {
176 self->enumerator = loc_database_enumerator_ref(enumerator);
177 }
178
179 return (PyObject*)self;
180 }
181
182 static PyObject* Database_search_as(DatabaseObject* self, PyObject* args) {
183 const char* string = NULL;
184
185 if (!PyArg_ParseTuple(args, "s", &string))
186 return NULL;
187
188 struct loc_database_enumerator* enumerator;
189
190 int r = loc_database_enumerator_new(&enumerator, self->db, LOC_DB_ENUMERATE_ASES);
191 if (r) {
192 PyErr_SetFromErrno(PyExc_SystemError);
193 return NULL;
194 }
195
196 // Search string we are searching for
197 loc_database_enumerator_set_string(enumerator, string);
198
199 PyObject* obj = new_database_enumerator(&DatabaseEnumeratorType, enumerator);
200 loc_database_enumerator_unref(enumerator);
201
202 return obj;
203 }
204
205 static PyObject* Database_search_networks(DatabaseObject* self, PyObject* args, PyObject* kwargs) {
206 char* kwlist[] = { "country_code", "asn", NULL };
207 const char* country_code = NULL;
208 unsigned int asn = 0;
209
210 if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|si", kwlist, &country_code, &asn))
211 return NULL;
212
213 struct loc_database_enumerator* enumerator;
214 int r = loc_database_enumerator_new(&enumerator, self->db, LOC_DB_ENUMERATE_NETWORKS);
215 if (r) {
216 PyErr_SetFromErrno(PyExc_SystemError);
217 return NULL;
218 }
219
220 // Set country code we are searching for
221 if (country_code) {
222 r = loc_database_enumerator_set_country_code(enumerator, country_code);
223
224 if (r) {
225 PyErr_SetFromErrno(PyExc_SystemError);
226 return NULL;
227 }
228 }
229
230 // Set the ASN we are searching for
231 if (asn) {
232 r = loc_database_enumerator_set_asn(enumerator, asn);
233
234 if (r) {
235 PyErr_SetFromErrno(PyExc_SystemError);
236 return NULL;
237 }
238 }
239
240 PyObject* obj = new_database_enumerator(&DatabaseEnumeratorType, enumerator);
241 loc_database_enumerator_unref(enumerator);
242
243 return obj;
244 }
245
246 static struct PyMethodDef Database_methods[] = {
247 {
248 "get_as",
249 (PyCFunction)Database_get_as,
250 METH_VARARGS,
251 NULL,
252 },
253 {
254 "get_country",
255 (PyCFunction)Database_get_country,
256 METH_VARARGS,
257 NULL,
258 },
259 {
260 "lookup",
261 (PyCFunction)Database_lookup,
262 METH_VARARGS,
263 NULL,
264 },
265 {
266 "search_as",
267 (PyCFunction)Database_search_as,
268 METH_VARARGS,
269 NULL,
270 },
271 {
272 "search_networks",
273 (PyCFunction)Database_search_networks,
274 METH_VARARGS|METH_KEYWORDS,
275 NULL,
276 },
277 { NULL },
278 };
279
280 static struct PyGetSetDef Database_getsetters[] = {
281 {
282 "created_at",
283 (getter)Database_get_created_at,
284 NULL,
285 NULL,
286 NULL,
287 },
288 {
289 "description",
290 (getter)Database_get_description,
291 NULL,
292 NULL,
293 NULL,
294 },
295 {
296 "license",
297 (getter)Database_get_license,
298 NULL,
299 NULL,
300 NULL,
301 },
302 {
303 "vendor",
304 (getter)Database_get_vendor,
305 NULL,
306 NULL,
307 NULL,
308 },
309 { NULL },
310 };
311
312 PyTypeObject DatabaseType = {
313 PyVarObject_HEAD_INIT(NULL, 0)
314 .tp_name = "location.Database",
315 .tp_basicsize = sizeof(DatabaseObject),
316 .tp_flags = Py_TPFLAGS_DEFAULT|Py_TPFLAGS_BASETYPE,
317 .tp_new = Database_new,
318 .tp_dealloc = (destructor)Database_dealloc,
319 .tp_init = (initproc)Database_init,
320 .tp_doc = "Database object",
321 .tp_methods = Database_methods,
322 .tp_getset = Database_getsetters,
323 .tp_repr = (reprfunc)Database_repr,
324 };
325
326 static PyObject* DatabaseEnumerator_new(PyTypeObject* type, PyObject* args, PyObject* kwds) {
327 DatabaseEnumeratorObject* self = (DatabaseEnumeratorObject*)type->tp_alloc(type, 0);
328
329 return (PyObject*)self;
330 }
331
332 static void DatabaseEnumerator_dealloc(DatabaseEnumeratorObject* self) {
333 loc_database_enumerator_unref(self->enumerator);
334
335 Py_TYPE(self)->tp_free((PyObject* )self);
336 }
337
338 static PyObject* DatabaseEnumerator_next(DatabaseEnumeratorObject* self) {
339 struct loc_network* network = NULL;
340
341 // Enumerate all networks
342 int r = loc_database_enumerator_next_network(self->enumerator, &network);
343 if (r) {
344 PyErr_SetFromErrno(PyExc_ValueError);
345 return NULL;
346 }
347
348 // A network was found
349 if (network) {
350 PyObject* obj = new_network(&NetworkType, network);
351 loc_network_unref(network);
352
353 return obj;
354 }
355
356 // Enumerate all ASes
357 struct loc_as* as = NULL;
358
359 r = loc_database_enumerator_next_as(self->enumerator, &as);
360 if (r) {
361 PyErr_SetFromErrno(PyExc_ValueError);
362 return NULL;
363 }
364
365 if (as) {
366 PyObject* obj = new_as(&ASType, as);
367 loc_as_unref(as);
368
369 return obj;
370 }
371
372 // Nothing found, that means the end
373 PyErr_SetNone(PyExc_StopIteration);
374 return NULL;
375 }
376
377 PyTypeObject DatabaseEnumeratorType = {
378 PyVarObject_HEAD_INIT(NULL, 0)
379 .tp_name = "location.DatabaseEnumerator",
380 .tp_basicsize = sizeof(DatabaseEnumeratorObject),
381 .tp_flags = Py_TPFLAGS_DEFAULT,
382 .tp_alloc = PyType_GenericAlloc,
383 .tp_new = DatabaseEnumerator_new,
384 .tp_dealloc = (destructor)DatabaseEnumerator_dealloc,
385 .tp_iter = PyObject_SelfIter,
386 .tp_iternext = (iternextfunc)DatabaseEnumerator_next,
387 };