]> git.ipfire.org Git - people/ms/libloc.git/blob - src/python/database.c
Change how errors are handled by iterators
[people/ms/libloc.git] / src / python / database.c
1 /*
2 libloc - A library to determine the location of someone on the Internet
3
4 Copyright (C) 2017 IPFire Development Team <info@ipfire.org>
5
6 This library is free software; you can redistribute it and/or
7 modify it under the terms of the GNU Lesser General Public
8 License as published by the Free Software Foundation; either
9 version 2.1 of the License, or (at your option) any later version.
10
11 This library is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 Lesser General Public License for more details.
15 */
16
17 #include <Python.h>
18
19 #include <loc/libloc.h>
20 #include <loc/database.h>
21
22 #include "locationmodule.h"
23 #include "as.h"
24 #include "database.h"
25 #include "network.h"
26
27 static PyObject* Database_new(PyTypeObject* type, PyObject* args, PyObject* kwds) {
28 DatabaseObject* self = (DatabaseObject*)type->tp_alloc(type, 0);
29
30 return (PyObject*)self;
31 }
32
33 static void Database_dealloc(DatabaseObject* self) {
34 if (self->db)
35 loc_database_unref(self->db);
36
37 if (self->path)
38 free(self->path);
39
40 Py_TYPE(self)->tp_free((PyObject* )self);
41 }
42
43 static int Database_init(DatabaseObject* self, PyObject* args, PyObject* kwargs) {
44 const char* path = NULL;
45
46 if (!PyArg_ParseTuple(args, "s", &path))
47 return -1;
48
49 self->path = strdup(path);
50
51 // Open the file for reading
52 FILE* f = fopen(self->path, "r");
53 if (!f) {
54 PyErr_SetFromErrno(PyExc_IOError);
55 return -1;
56 }
57
58 // Load the database
59 int r = loc_database_new(loc_ctx, &self->db, f);
60 fclose(f);
61
62 // Return on any errors
63 if (r)
64 return -1;
65
66 return 0;
67 }
68
69 static PyObject* Database_repr(DatabaseObject* self) {
70 return PyUnicode_FromFormat("<Database %s>", self->path);
71 }
72
73 static PyObject* Database_get_description(DatabaseObject* self) {
74 const char* description = loc_database_get_description(self->db);
75
76 return PyUnicode_FromString(description);
77 }
78
79 static PyObject* Database_get_vendor(DatabaseObject* self) {
80 const char* vendor = loc_database_get_vendor(self->db);
81
82 return PyUnicode_FromString(vendor);
83 }
84
85 static PyObject* Database_get_license(DatabaseObject* self) {
86 const char* license = loc_database_get_license(self->db);
87
88 return PyUnicode_FromString(license);
89 }
90
91 static PyObject* Database_get_created_at(DatabaseObject* self) {
92 time_t created_at = loc_database_created_at(self->db);
93
94 return PyLong_FromLong(created_at);
95 }
96
97 static PyObject* Database_get_as(DatabaseObject* self, PyObject* args) {
98 struct loc_as* as = NULL;
99 uint32_t number = 0;
100
101 if (!PyArg_ParseTuple(args, "i", &number))
102 return NULL;
103
104 // Try to retrieve the AS
105 int r = loc_database_get_as(self->db, &as, number);
106
107 // We got an AS
108 if (r == 0) {
109 PyObject* obj = new_as(&ASType, as);
110 loc_as_unref(as);
111
112 return obj;
113
114 // Nothing found
115 } else if (r == 1) {
116 Py_RETURN_NONE;
117 }
118
119 // Unexpected error
120 return NULL;
121 }
122
123 static PyObject* Database_lookup(DatabaseObject* self, PyObject* args) {
124 struct loc_network* network = NULL;
125 const char* address = NULL;
126
127 if (!PyArg_ParseTuple(args, "s", &address))
128 return NULL;
129
130 // Try to retrieve a matching network
131 int r = loc_database_lookup_from_string(self->db, address, &network);
132
133 // We got a network
134 if (r == 0) {
135 PyObject* obj = new_network(&NetworkType, network);
136 loc_network_unref(network);
137
138 return obj;
139
140 // Nothing found
141 } else if (r == 1) {
142 Py_RETURN_NONE;
143
144 // Invalid input
145 } else if (r == -EINVAL) {
146 PyErr_Format(PyExc_ValueError, "Invalid IP address: %s", address);
147 return NULL;
148 }
149
150 // Unexpected error
151 return NULL;
152 }
153
154 static PyObject* new_database_enumerator(PyTypeObject* type, struct loc_database_enumerator* enumerator) {
155 DatabaseEnumeratorObject* self = (DatabaseEnumeratorObject*)type->tp_alloc(type, 0);
156 if (self) {
157 self->enumerator = loc_database_enumerator_ref(enumerator);
158 }
159
160 return (PyObject*)self;
161 }
162
163 static PyObject* Database_search_as(DatabaseObject* self, PyObject* args) {
164 const char* string = NULL;
165
166 if (!PyArg_ParseTuple(args, "s", &string))
167 return NULL;
168
169 struct loc_database_enumerator* enumerator;
170
171 int r = loc_database_enumerator_new(&enumerator, self->db, LOC_DB_ENUMERATE_ASES);
172 if (r) {
173 PyErr_SetFromErrno(PyExc_SystemError);
174 return NULL;
175 }
176
177 // Search string we are searching for
178 loc_database_enumerator_set_string(enumerator, string);
179
180 PyObject* obj = new_database_enumerator(&DatabaseEnumeratorType, enumerator);
181 loc_database_enumerator_unref(enumerator);
182
183 return obj;
184 }
185
186 static PyObject* Database_search_networks(DatabaseObject* self, PyObject* args, PyObject* kwargs) {
187 char* kwlist[] = { "country_code", "asn", NULL };
188 const char* country_code = NULL;
189 unsigned int asn = 0;
190
191 if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|si", kwlist, &country_code, &asn))
192 return NULL;
193
194 struct loc_database_enumerator* enumerator;
195 int r = loc_database_enumerator_new(&enumerator, self->db, LOC_DB_ENUMERATE_NETWORKS);
196 if (r) {
197 PyErr_SetFromErrno(PyExc_SystemError);
198 return NULL;
199 }
200
201 // Set country code we are searching for
202 if (country_code) {
203 r = loc_database_enumerator_set_country_code(enumerator, country_code);
204
205 if (r) {
206 PyErr_SetFromErrno(PyExc_SystemError);
207 return NULL;
208 }
209 }
210
211 // Set the ASN we are searching for
212 if (asn) {
213 r = loc_database_enumerator_set_asn(enumerator, asn);
214
215 if (r) {
216 PyErr_SetFromErrno(PyExc_SystemError);
217 return NULL;
218 }
219 }
220
221 PyObject* obj = new_database_enumerator(&DatabaseEnumeratorType, enumerator);
222 loc_database_enumerator_unref(enumerator);
223
224 return obj;
225 }
226
227 static struct PyMethodDef Database_methods[] = {
228 {
229 "get_as",
230 (PyCFunction)Database_get_as,
231 METH_VARARGS,
232 NULL,
233 },
234 {
235 "lookup",
236 (PyCFunction)Database_lookup,
237 METH_VARARGS,
238 NULL,
239 },
240 {
241 "search_as",
242 (PyCFunction)Database_search_as,
243 METH_VARARGS,
244 NULL,
245 },
246 {
247 "search_networks",
248 (PyCFunction)Database_search_networks,
249 METH_VARARGS|METH_KEYWORDS,
250 NULL,
251 },
252 { NULL },
253 };
254
255 static struct PyGetSetDef Database_getsetters[] = {
256 {
257 "created_at",
258 (getter)Database_get_created_at,
259 NULL,
260 NULL,
261 NULL,
262 },
263 {
264 "description",
265 (getter)Database_get_description,
266 NULL,
267 NULL,
268 NULL,
269 },
270 {
271 "license",
272 (getter)Database_get_license,
273 NULL,
274 NULL,
275 NULL,
276 },
277 {
278 "vendor",
279 (getter)Database_get_vendor,
280 NULL,
281 NULL,
282 NULL,
283 },
284 { NULL },
285 };
286
287 PyTypeObject DatabaseType = {
288 PyVarObject_HEAD_INIT(NULL, 0)
289 .tp_name = "location.Database",
290 .tp_basicsize = sizeof(DatabaseObject),
291 .tp_flags = Py_TPFLAGS_DEFAULT|Py_TPFLAGS_BASETYPE,
292 .tp_new = Database_new,
293 .tp_dealloc = (destructor)Database_dealloc,
294 .tp_init = (initproc)Database_init,
295 .tp_doc = "Database object",
296 .tp_methods = Database_methods,
297 .tp_getset = Database_getsetters,
298 .tp_repr = (reprfunc)Database_repr,
299 };
300
301 static PyObject* DatabaseEnumerator_new(PyTypeObject* type, PyObject* args, PyObject* kwds) {
302 DatabaseEnumeratorObject* self = (DatabaseEnumeratorObject*)type->tp_alloc(type, 0);
303
304 return (PyObject*)self;
305 }
306
307 static void DatabaseEnumerator_dealloc(DatabaseEnumeratorObject* self) {
308 loc_database_enumerator_unref(self->enumerator);
309
310 Py_TYPE(self)->tp_free((PyObject* )self);
311 }
312
313 static PyObject* DatabaseEnumerator_next(DatabaseEnumeratorObject* self) {
314 struct loc_network* network = NULL;
315
316 // Enumerate all networks
317 int r = loc_database_enumerator_next_network(self->enumerator, &network);
318 if (r) {
319 return NULL;
320 }
321
322 // A network was found
323 if (network) {
324 PyObject* obj = new_network(&NetworkType, network);
325 loc_network_unref(network);
326
327 return obj;
328 }
329
330 // Enumerate all ASes
331 struct loc_as* as = NULL;
332
333 r = loc_database_enumerator_next_as(self->enumerator, &as);
334 if (r) {
335 return NULL;
336 }
337
338 if (as) {
339 PyObject* obj = new_as(&ASType, as);
340 loc_as_unref(as);
341
342 return obj;
343 }
344
345 // Nothing found, that means the end
346 PyErr_SetNone(PyExc_StopIteration);
347 return NULL;
348 }
349
350 PyTypeObject DatabaseEnumeratorType = {
351 PyVarObject_HEAD_INIT(NULL, 0)
352 .tp_name = "location.DatabaseEnumerator",
353 .tp_basicsize = sizeof(DatabaseEnumeratorObject),
354 .tp_flags = Py_TPFLAGS_DEFAULT,
355 .tp_alloc = PyType_GenericAlloc,
356 .tp_new = DatabaseEnumerator_new,
357 .tp_dealloc = (destructor)DatabaseEnumerator_dealloc,
358 .tp_iter = PyObject_SelfIter,
359 .tp_iternext = (iternextfunc)DatabaseEnumerator_next,
360 };