]> git.ipfire.org Git - people/pmueller/ipfire-2.x.git/blob - src/scripts/openvpn-metrics
Early spring clean: Remove trailing whitespaces, and correct licence headers
[people/pmueller/ipfire-2.x.git] / src / scripts / openvpn-metrics
1 #!/usr/bin/python3
2 ###############################################################################
3 # #
4 # IPFire.org - A linux based firewall #
5 # Copyright (C) 2007-2022 IPFire Team <info@ipfire.org> #
6 # #
7 # This program is free software: you can redistribute it and/or modify #
8 # it under the terms of the GNU General Public License as published by #
9 # the Free Software Foundation, either version 3 of the License, or #
10 # (at your option) any later version. #
11 # #
12 # This program is distributed in the hope that it will be useful, #
13 # but WITHOUT ANY WARRANTY; without even the implied warranty of #
14 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
15 # GNU General Public License for more details. #
16 # #
17 # You should have received a copy of the GNU General Public License #
18 # along with this program. If not, see <http://www.gnu.org/licenses/>. #
19 # #
20 ###############################################################################
21
22 import argparse
23 import logging
24 import logging.handlers
25 import os
26 import sqlite3
27 import sys
28
def _(x):
    """Wrap a user-facing string (e.g. CLI help text); returns it unchanged."""
    return x


# Location of the SQLite database that holds the session records
DEFAULT_DATABASE_PATH = "/var/ipfire/ovpn/clients.db"
32
def setup_logging(level=logging.INFO):
    """
    Configure and return the "openvpn-metrics" logger.

    Two handlers are attached, in this order:
      * a console StreamHandler that passes every record the logger lets
        through (handler level DEBUG) and uses no custom formatter;
      * a SysLogHandler on /dev/log with facility LOG_DAEMON, restricted
        to INFO and above, whose messages are prefixed with
        "openvpn-metrics[<pid>]: ".

    Note: each call attaches a fresh pair of handlers to the same logger.
    """
    logger = logging.getLogger("openvpn-metrics")
    logger.setLevel(level)

    # Console output; the logger's own level does the filtering
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)

    # Syslog output, tagged with the program name and process ID
    syslog = logging.handlers.SysLogHandler(
        address="/dev/log",
        facility=logging.handlers.SysLogHandler.LOG_DAEMON,
    )
    syslog.setLevel(logging.INFO)
    syslog.setFormatter(
        logging.Formatter("openvpn-metrics[%(process)d]: %(message)s")
    )

    for handler in (console, syslog):
        logger.addHandler(handler)

    return logger
53
# Module-wide logger used by every command (console + syslog, INFO level)
log = setup_logging(level=logging.INFO)
56
class OpenVPNMetrics(object):
    """
    Command-line tool that records OpenVPN client sessions in SQLite.

    Constructing an instance opens the session database. Calling the
    instance parses the command line, runs the selected sub-command
    ("client-connect" or "client-disconnect") and exits the process.

    The sub-commands take their input from environment variables
    (common_name, time_unix, time_duration, bytes_received, ...).
    Presumably OpenVPN exports these when it runs this script as a
    client-connect/client-disconnect hook -- confirm against the server
    configuration.
    """

    def __init__(self):
        # The database is opened, and its schema created, as soon as the
        # object is constructed (before the command line is parsed).
        self.db = self._open_database()

    def parse_cli(self):
        """
        Parse the command line and return the argparse namespace.

        Every sub-command stores its handler method in ``args.func``.
        If no sub-command was given, the usage line is printed and the
        process exits with status 2.
        """
        parser = argparse.ArgumentParser(
            description=_("Tool that collects metrics of OpenVPN Clients"),
        )
        subparsers = parser.add_subparsers()

        # client-connect
        client_connect = subparsers.add_parser(
            "client-connect",
            help=_("Called when a client connects"),
        )
        # The optional file argument is accepted but not read by the handler
        client_connect.add_argument(
            "file", nargs="?",
            help=_("Configuration file"),
        )
        client_connect.set_defaults(func=self.client_connect)

        # client-disconnect
        client_disconnect = subparsers.add_parser(
            "client-disconnect",
            help=_("Called when a client disconnects"),
        )
        client_disconnect.add_argument(
            "file", nargs="?",
            help=_("Configuration file"),
        )
        client_disconnect.set_defaults(func=self.client_disconnect)

        args = parser.parse_args()

        # Sub-commands are optional, so without one "func" is never set
        if "func" not in args:
            parser.print_usage()
            sys.exit(2)

        return args

    def __call__(self):
        """
        Run the selected sub-command and exit the process.

        The exit status is the handler's return value, or 0 if it returned
        a falsy value such as None. If the handler raises an Exception, the
        error is logged at CRITICAL level and the process exits with
        status 1. SystemExit (e.g. from _get_environ) is not an Exception
        and propagates unchanged.
        """
        args = self.parse_cli()

        try:
            ret = args.func(args)
        except Exception as e:
            log.critical(e)
            # ret must be bound here, otherwise sys.exit() below raises
            # UnboundLocalError and buries the logged error under a traceback.
            ret = 1

        sys.exit(ret or 0)

    def _open_database(self, path=DEFAULT_DATABASE_PATH):
        """
        Connect to the SQLite database at *path* and return the connection.

        The ``sessions`` table and its index on common_name are created
        if they do not exist yet.
        """
        db = sqlite3.connect(path)

        # Create schema if it doesn't exist already
        db.executescript("""
            CREATE TABLE IF NOT EXISTS sessions(
                common_name TEXT NOT NULL,
                connected_at TEXT NOT NULL,
                disconnected_at TEXT,
                bytes_received INTEGER,
                bytes_sent INTEGER
            );

            -- Create index for speeding up searches
            CREATE INDEX IF NOT EXISTS sessions_common_name ON sessions(common_name);
        """)

        return db

    def _get_environ(self, key):
        """
        Return the value of environment variable *key*.

        If it is not set, an error is written to stderr and SystemExit(1)
        is raised, terminating the process with status 1.
        """
        try:
            return os.environ[key]
        except KeyError:
            sys.stderr.write("%s missing from environment\n" % key)
            raise SystemExit(1) from None

    def client_connect(self, args):
        """
        Insert a new, still-open session row for the connecting client.

        Reads common_name, time_ascii and time_unix from the environment.
        connected_at is stored as SQLite DATETIME(time_unix, 'unixepoch'),
        i.e. "YYYY-MM-DD HH:MM:SS" in UTC; time_ascii is only logged.
        *args* is not used.
        """
        common_name = self._get_environ("common_name")

        # Time
        time_ascii = self._get_environ("time_ascii")
        time_unix = self._get_environ("time_unix")

        log.info("Opening session for %s at %s", common_name, time_ascii)

        c = self.db.cursor()
        c.execute("INSERT INTO sessions(common_name, connected_at) \
            VALUES(?, DATETIME(?, 'unixepoch'))", (common_name, time_unix))
        self.db.commit()

    def client_disconnect(self, args):
        """
        Close the open session(s) of the disconnecting client.

        Reads common_name, time_duration, bytes_received and bytes_sent from
        the environment. Every row for common_name whose disconnected_at is
        still NULL gets disconnected_at = its own connected_at plus
        time_duration seconds, and the two byte counters are stored.
        *args* is not used.
        """
        common_name = self._get_environ("common_name")
        duration = self._get_environ("time_duration")

        # Collect some usage statistics
        bytes_received = self._get_environ("bytes_received")
        bytes_sent = self._get_environ("bytes_sent")

        log.info(
            "Closing session for %s after %ss and receiving/sending %s/%s bytes",
            common_name, duration, bytes_received, bytes_sent,
        )

        # NOTE(review): if several sessions of this common_name are open
        # (e.g. rows never closed, or concurrent connections), all of them
        # receive the same duration and byte counters -- confirm intended.
        c = self.db.cursor()
        c.execute("UPDATE sessions SET disconnected_at = DATETIME(connected_at, '+' || ? || ' seconds'), \
            bytes_received = ?, bytes_sent = ? \
            WHERE common_name = ? AND disconnected_at IS NULL",
            (duration, bytes_received, bytes_sent, common_name))
        self.db.commit()
165
def main():
    """Entry point: build the metrics tool and run the requested sub-command."""
    m = OpenVPNMetrics()
    m()


# Only run when executed as a script, so that importing this file does not
# open the database or terminate the interpreter.
if __name__ == "__main__":
    main()