]>
Commit | Line | Data |
---|---|---|
1 | from __future__ import print_function | |
2 | from datetime import datetime | |
3 | import os | |
4 | import requests | |
5 | import unittest | |
6 | import sqlite3 | |
7 | import subprocess | |
8 | import sys | |
9 | ||
10 | if sys.version_info[0] == 2: | |
11 | from urlparse import urljoin | |
12 | else: | |
13 | from urllib.parse import urljoin | |
14 | ||
# Runtime configuration; every setting can be overridden through the
# environment variable of the same name.

# Which daemon flavour is under test ('authoritative' or 'recursor').
DAEMON = os.getenv('DAEMON', 'authoritative')
# pdnsutil argv prefix, split on single spaces.
PDNSUTIL_CMD = os.getenv('PDNSUTIL_CMD', 'NOT_SET BUT_THIS MIGHT_BE_A_LIST').split(' ')
# SQLite database file of the authoritative backend.
SQLITE_DB = os.getenv('SQLITE_DB', 'pdns.sqlite3')
# sdig binary and the DNS port it is pointed at.
SDIG = os.getenv('SDIG', 'sdig')
DNSPORT = os.getenv('DNSPORT', '53')
20 | ||
class ApiTestCase(unittest.TestCase):
    """unittest base class with an HTTP session for the web server/API and
    helpers that assert on the status and Content-Type of responses."""

    def setUp(self):
        """Configure server address/port and an API session.

        Port, web password and API key come from the WEBPORT, WEBPASSWORD
        and APIKEY environment variables (with test defaults).
        """
        # TODO: config
        self.server_address = '127.0.0.1'
        self.webServerBasicAuthPassword = 'something'
        self.server_port = int(os.environ.get('WEBPORT', '5580'))
        self.server_url = 'http://%s:%s/' % (self.server_address, self.server_port)
        self.server_web_password = os.environ.get('WEBPASSWORD', 'MISSING')
        self.session = requests.Session()
        # Assigned (not updated): the session sends exactly these headers,
        # i.e. the API key plus an Origin matching the web server.
        self.session.headers = {'X-API-Key': os.environ.get('APIKEY', 'changeme-key'), 'Origin': 'http://%s:%s' % (self.server_address, self.server_port)}

    def url(self, relative_url):
        """Resolve ``relative_url`` against the server's base URL."""
        return urljoin(self.server_url, relative_url)

    def assert_success_json(self, result):
        """Assert ``result`` succeeded and its Content-Type is exactly
        'application/json'."""
        self.assert_success(result)
        self.assertEqual(result.headers['Content-Type'], 'application/json')

    def assert_error_json(self, result):
        """Assert ``result`` has a 4xx/5xx status and a JSON Content-Type."""
        self.assertTrue(400 <= result.status_code < 600, "Response has not an error code "+str(result.status_code))
        self.assertEqual(result.headers['Content-Type'], 'application/json', "Response status code "+str(result.status_code))

    def assert_success(self, result):
        """Assert ``result`` does not have an error status.

        If ``raise_for_status()`` fails, the response body is printed to
        aid debugging and the HTTPError is re-raised.
        """
        try:
            result.raise_for_status()
        except requests.HTTPError:
            print(result.content)
            raise
54 | ||
55 | ||
def unique_zone_name():
    """Build a throw-away zone name ``test-<stamp>.org.``.

    ``<stamp>`` is the current local time packed as day, hour, second,
    minute and microseconds (14 digits).
    """
    stamp = datetime.now().strftime('%d%H%S%M%f')
    return 'test-{0}.org.'.format(stamp)
58 | ||
def unique_tsigkey_name():
    """Build a throw-away TSIG key name ``test-<stamp>-key``.

    ``<stamp>`` is the current local time packed as day, hour, second,
    minute and microseconds (14 digits).
    """
    stamp = datetime.now().strftime('%d%H%S%M%f')
    return 'test-{0}-key'.format(stamp)
61 | ||
def is_auth():
    """Tell whether the configured DAEMON is the authoritative server."""
    return 'authoritative' == DAEMON
64 | ||
65 | ||
def is_recursor():
    """Tell whether the configured DAEMON is the recursor."""
    return 'recursor' == DAEMON
68 | ||
69 | ||
def get_auth_db():
    """Open a new sqlite3 connection to the authoritative backend DB.

    The database path comes from the module-level SQLITE_DB setting; the
    caller owns the returned connection.
    """
    db_path = SQLITE_DB
    return sqlite3.connect(db_path)
73 | ||
74 | ||
def get_db_records(zonename, qtype):
    """Return every record of type ``qtype`` stored for zone ``zonename``.

    Trailing dots are stripped from ``zonename`` before matching
    ``domains.name``. Each record is a dict with the keys 'name', 'type',
    'content', 'ttl' and 'ordername'. The list is also printed for
    debugging.
    """
    db = get_auth_db()
    # A sqlite3 connection used as a context manager only commits/rolls
    # back; it does not close the connection, so close it explicitly.
    try:
        rows = db.execute("""
            SELECT name, type, content, ttl, ordername
            FROM records
            WHERE type = ? AND domain_id = (
                SELECT id FROM domains WHERE name = ?
            )""", (qtype, zonename.rstrip('.'))).fetchall()
    finally:
        db.close()
    recs = [{'name': row[0], 'type': row[1], 'content': row[2], 'ttl': row[3], 'ordername': row[4]} for row in rows]
    print("DB Records:", recs)
    return recs
86 | ||
87 | ||
def pdnsutil(subcommand, *args):
    """Run ``PDNSUTIL_CMD subcommand args...`` and return its stdout.

    The captured stdout is decoded as ASCII. If the command exits non-zero,
    a RuntimeError carrying the captured stdout is raised instead.
    """
    argv = list(PDNSUTIL_CMD)
    argv.append(subcommand)
    argv.extend(args)
    try:
        raw = subprocess.check_output(argv, close_fds=True)
    except subprocess.CalledProcessError as err:
        detail = err.output.decode('ascii', errors='replace')
        raise RuntimeError("pdnsutil %s %s failed: %s" % (subcommand, args, detail))
    return raw.decode('ascii')
93 | ||
def pdnsutil_rectify(zonename):
    """Invoke ``pdnsutil rectify-zone <zonename>``.

    The command output is discarded; a failing run surfaces as the
    RuntimeError raised by pdnsutil().
    """
    subcommand = 'rectify-zone'
    pdnsutil(subcommand, zonename)
97 | ||
def sdig(*args):
    """Run sdig against 127.0.0.1 on DNSPORT with the extra ``args``.

    sdig's output is not captured (it goes to this process's stdout/stderr).
    Returns check_call's result (0) on success.

    Raises:
        RuntimeError: if sdig exits with a non-zero status.
    """
    cmd = [SDIG, '127.0.0.1', str(DNSPORT)] + list(args)
    try:
        return subprocess.check_call(cmd)
    except subprocess.CalledProcessError as except_inst:
        # check_call does not capture output (except_inst.output is None),
        # so report the failed command and exit status instead.
        raise RuntimeError("sdig %s failed: %s" % (args, except_inst))
103 | ||
def get_db_tsigkeys(keyname):
    """Return the TSIG keys named ``keyname`` from the backend DB.

    Each key is a dict with the keys 'name', 'algorithm' and 'secret'.
    The list is also printed for debugging.
    """
    db = get_auth_db()
    # A sqlite3 connection used as a context manager only commits/rolls
    # back; it does not close the connection, so close it explicitly.
    try:
        rows = db.execute("""
            SELECT name, algorithm, secret
            FROM tsigkeys
            WHERE name = ?""", (keyname, )).fetchall()
    finally:
        db.close()
    keys = [{'name': row[0], 'algorithm': row[1], 'secret': row[2]} for row in rows]
    print("DB TSIG keys:", keys)
    return keys
113 |