]> git.ipfire.org Git - ipfire.org.git/blob - src/backend/accounts.py
228810860bcb37b148cf44799aafa624ab53f7a1
[ipfire.org.git] / src / backend / accounts.py
1 #!/usr/bin/python
2 # encoding: utf-8
3
4 import PIL
5 import io
6 import ldap
7 import logging
8 import urllib.parse
9 import urllib.request
10
11 from .decorators import *
12 from .misc import Object
13
class Accounts(Object):
    """
        Access to the user accounts stored in LDAP and to the web login
        sessions stored in the database.
    """

    def __iter__(self):
        # Only return developers (group with ID 1000)
        accounts = self._search("(&(objectClass=posixAccount)(gidNumber=1000))")

        return iter(sorted(accounts))

    @lazy_property
    def ldap(self):
        """
            The shared LDAP connection, created on first use.

            When "ldap_bind_dn" is configured, the connection is bound with
            that DN and "ldap_bind_pw" (empty if unset); otherwise no bind
            is performed.
        """
        # Connect to LDAP server
        ldap_uri = self.settings.get("ldap_uri")
        conn = ldap.initialize(ldap_uri)

        # Bind with username and password.
        # Wait for the result (simple_bind_s): the asynchronous simple_bind()
        # returns immediately, so a failed bind would go unnoticed and
        # queries could be sent while the bind is still outstanding.
        bind_dn = self.settings.get("ldap_bind_dn")
        if bind_dn:
            bind_pw = self.settings.get("ldap_bind_pw", "")
            conn.simple_bind_s(bind_dn, bind_pw)

        return conn

    @staticmethod
    def _escape(value):
        """
            Escapes value for use as an assertion value inside an LDAP
            search filter, so that characters with a special meaning in
            filters (such as "*", "(" and ")") are matched literally.
        """
        from ldap.filter import escape_filter_chars

        return escape_filter_chars(str(value))

    def _query(self, query, attrlist=None, limit=0):
        """
            Runs a subtree search for the filter query below the configured
            "ldap_search_base" and returns python-ldap's list of
            (dn, attrs) tuples.

            attrlist restricts the returned attributes and limit is passed
            on as the search size limit (0 means no limit).

            On an LDAP error the cached connection is discarded, so that the
            next query reconnects, and the error is re-raised.
        """
        logging.debug("Performing LDAP query: %s", query)

        search_base = self.settings.get("ldap_search_base")

        # Fetch the connection outside of the try block: if connecting or
        # binding fails, nothing has been cached yet that could be reset.
        conn = self.ldap

        try:
            return conn.search_ext_s(search_base, ldap.SCOPE_SUBTREE,
                query, attrlist=attrlist, sizelimit=limit)
        except ldap.LDAPError:
            self._reset_connection(conn)
            raise

    def _reset_connection(self, conn):
        """
            Closes conn on a best-effort basis and drops the cached
            connection so that the next access to self.ldap reconnects.
        """
        # unbind_s() is python-ldap's documented call for closing a
        # connection. It may fail on an already broken connection, which
        # must not hide the error that is currently being handled.
        try:
            conn.unbind_s()
        except ldap.LDAPError:
            pass

        # Forget the cached connection
        del self.ldap

    def _search(self, query, attrlist=None, limit=0):
        """
            Like _query(), but returns a list of Account objects.
        """
        return [Account(self.backend, dn, attrs)
            for dn, attrs in self._query(query, attrlist=attrlist, limit=limit)]

    def search(self, query):
        """
            Searches accounts for query and returns them sorted.

            Exact matches on uid, mail, SIP user or any phone number take
            precedence; only when there are none are accounts returned whose
            common name contains query.
        """
        q = self._escape(query)

        # Search for exact matches
        accounts = self._search("(&(objectClass=posixAccount)"
            "(|(uid=%s)(mail=%s)(sipAuthenticationUser=%s)"
            "(telephoneNumber=%s)(homePhone=%s)(mobile=%s)))"
            % (q, q, q, q, q, q))

        # Find accounts by name
        if not accounts:
            accounts = self._search("(&(objectClass=posixAccount)(cn=*%s*))" % q)

        return sorted(accounts)

    def _search_one(self, query):
        """
            Returns the first Account matching query, or None.
        """
        # NOTE(review): with sizelimit=1 python-ldap may raise
        # ldap.SIZELIMIT_EXCEEDED when several entries match - confirm
        # that callers expect this.
        result = self._search(query, limit=1)
        assert len(result) <= 1

        if result:
            return result[0]

    def get_by_uid(self, uid):
        return self._search_one("(&(objectClass=posixAccount)(uid=%s))"
            % self._escape(uid))

    def get_by_mail(self, mail):
        return self._search_one("(&(objectClass=posixAccount)(mail=%s))"
            % self._escape(mail))

    find = get_by_uid

    def find_account(self, s):
        """
            Looks up an account by uid and falls back to its mail address.
        """
        account = self.get_by_uid(s)
        if account:
            return account

        return self.get_by_mail(s)

    def get_by_sip_id(self, sip_id):
        sip_id = self._escape(sip_id)

        return self._search_one("(|(&(objectClass=sipUser)(sipAuthenticationUser=%s))"
            "(&(objectClass=sipRoutingObject)(sipLocalAddress=%s)))" % (sip_id, sip_id))

    def get_by_phone_number(self, number):
        number = self._escape(number)

        return self._search_one("(&(objectClass=posixAccount)"
            "(|(sipAuthenticationUser=%s)(telephoneNumber=%s)(homePhone=%s)(mobile=%s)))"
            % (number, number, number, number))

    # Session stuff

    def _cleanup_expired_sessions(self):
        self.db.execute("DELETE FROM sessions WHERE time_expires <= NOW()")

    def create_session(self, account, host):
        """
            Creates a session for account on host.

            Returns (session_id, time_expires), or (None, None) if the
            database did not return a row.
        """
        self._cleanup_expired_sessions()

        res = self.db.get("INSERT INTO sessions(host, uid) VALUES(%s, %s) "
            "RETURNING session_id, time_expires", host, account.uid)

        # Session could not be created
        if not res:
            return None, None

        logging.info("Created session %s for %s which expires %s",
            res.session_id, account, res.time_expires)

        return res.session_id, res.time_expires

    def destroy_session(self, session_id, host):
        """
            Deletes the given session of host and purges expired sessions.
        """
        logging.info("Destroying session %s", session_id)

        self.db.execute("DELETE FROM sessions "
            "WHERE session_id = %s AND host = %s", session_id, host)
        self._cleanup_expired_sessions()

    def get_by_session(self, session_id, host):
        """
            Returns the Account owning a currently valid session, extending
            the session's expiry to 14 days from now, or None.
        """
        logging.debug("Looking up session %s", session_id)

        res = self.db.get("SELECT uid FROM sessions WHERE session_id = %s "
            "AND host = %s AND NOW() BETWEEN time_created AND time_expires",
            session_id, host)

        # Session does not exist or has expired
        if not res:
            return

        # Update the session expiration time
        self.db.execute("UPDATE sessions SET time_expires = NOW() + INTERVAL '14 days' "
            "WHERE session_id = %s AND host = %s", session_id, host)

        return self.get_by_uid(res.uid)
148
149
class Account(Object):
    """
        A single LDAP entry, identified by its DN, together with the raw
        attributes returned for it (values are lists of bytes).
    """

    def __init__(self, backend, dn, attrs=None):
        Object.__init__(self, backend)
        self.dn = dn

        self.__attrs = attrs or {}

    def __str__(self):
        return self.name

    def __repr__(self):
        return "<%s %s>" % (self.__class__.__name__, self.dn)

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.dn == other.dn

        return NotImplemented

    def __hash__(self):
        # Must agree with __eq__, which compares the DN only
        return hash(self.dn)

    def __lt__(self, other):
        if isinstance(other, self.__class__):
            return self.name < other.name

        return NotImplemented

    @property
    def ldap(self):
        return self.accounts.ldap

    @property
    def attributes(self):
        return self.__attrs

    def _get_first_attribute(self, attr, default=None):
        """
            Returns the first value of attr, or default if the attribute is
            missing or has no values.
        """
        values = self.attributes.get(attr)
        if values:
            return values[0]

        return default

    def get(self, key):
        """
            Returns the value of attribute key (a single value if there is
            only one, otherwise the list); raises AttributeError if missing.
        """
        try:
            attribute = self.attributes[key]
        except KeyError:
            raise AttributeError(key)

        if len(attribute) == 1:
            return attribute[0]

        return attribute

    def check_password(self, password):
        """
            Bind to the server with given credentials and return
            true if password is correct and false if not.

            Raises exceptions from the server on any other errors.
        """
        # An empty password turns a simple bind into an "unauthenticated"
        # bind (RFC 4513, section 5.1.2), which some servers accept without
        # checking anything. Never treat it as a valid login.
        if not password:
            return False

        logging.debug("Checking credentials for %s", self.dn)

        # Authenticate on a dedicated connection: binding the shared
        # connection would change the identity that all later directory
        # queries are run with.
        conn = ldap.initialize(self.accounts.settings.get("ldap_uri"))

        try:
            conn.simple_bind_s(self.dn, password.encode("utf-8"))
        except ldap.INVALID_CREDENTIALS:
            logging.debug("Account credentials are invalid.")
            return False
        finally:
            # Best-effort close; must not hide the outcome of the bind
            try:
                conn.unbind_s()
            except ldap.LDAPError:
                pass

        logging.debug("Successfully authenticated.")
        return True

    def is_admin(self):
        return "wheel" in self.groups

    def is_talk_enabled(self):
        return "sipUser" in self.classes or "sipRoutingObject" in self.classes \
            or self.telephone_numbers or self.address

    @property
    def classes(self):
        """
            The entry's objectClass values as a list of str.
        """
        return [x.decode() for x in self.attributes.get("objectClass", [])]

    @property
    def uid(self):
        return self._get_first_attribute("uid").decode()

    @property
    def name(self):
        return self._get_first_attribute("cn").decode()

    @property
    def first_name(self):
        return self._get_first_attribute("givenName").decode()

    @lazy_property
    def groups(self):
        """
            Names (cn) of all posixGroups that list this account's uid as a
            memberUid.
        """
        from ldap.filter import escape_filter_chars

        res = self.accounts._query("(&(objectClass=posixGroup)(memberUid=%s))"
            % escape_filter_chars(self.uid), ["cn"])

        groups = []
        for dn, attrs in res:
            cns = attrs.get("cn")
            if cns:
                groups.append(cns[0].decode())

        return groups

    @property
    def address(self):
        address = self._get_first_attribute("homePostalAddress", b"").decode()
        address = address.replace(", ", "\n")

        return address

    @property
    def email(self):
        """
            The account's preferred @ipfire.org address as str.

            This is the first "mail" value that starts with
            "<name>@ipfire.org", where <name> is the lower-cased common name
            with spaces replaced by dots and umlauts transliterated.
            Falls back to "<uid>@ipfire.org".
        """
        # The name is lower-cased first, so only lower-case umlauts remain
        name = self.name.lower().replace(" ", ".")
        for umlaut, replacement in (("ä", "ae"), ("ö", "oe"), ("ü", "ue")):
            name = name.replace(umlaut, replacement)

        prefix = "%s@ipfire.org" % name

        for mail in self.attributes.get("mail", []):
            # Decode so that both return paths yield str
            mail = mail.decode()

            if mail.startswith(prefix):
                return mail

        # If everything else fails, we will go with the UID
        return "%s@ipfire.org" % self.uid

    @property
    def sip_id(self):
        if "sipUser" in self.classes:
            return self._get_first_attribute("sipAuthenticationUser").decode()

        if "sipRoutingObject" in self.classes:
            return self._get_first_attribute("sipLocalAddress").decode()

    @property
    def sip_password(self):
        return self._get_first_attribute("sipPassword").decode()

    @property
    def sip_url(self):
        return "%s@ipfire.org" % self.sip_id

    def uses_sip_forwarding(self):
        return bool(self.sip_routing_url)

    @property
    def sip_routing_url(self):
        if "sipRoutingObject" in self.classes:
            return self._get_first_attribute("sipRoutingAddress").decode()

    @lazy_property
    def sip_registrations(self):
        sip_registrations = []

        for reg in self.backend.talk.freeswitch.get_sip_registrations(self.sip_url):
            reg.account = self

            sip_registrations.append(reg)

        return sip_registrations

    def get_cdr(self, limit=None):
        return self.backend.talk.freeswitch.get_cdr_by_account(self, limit=limit)

    @property
    def telephone_numbers(self):
        return self._telephone_numbers + self.mobile_telephone_numbers \
            + self.home_telephone_numbers

    @property
    def _telephone_numbers(self):
        return self.attributes.get("telephoneNumber") or []

    @property
    def home_telephone_numbers(self):
        return self.attributes.get("homePhone") or []

    @property
    def mobile_telephone_numbers(self):
        return self.attributes.get("mobile") or []

    def avatar_url(self, size=None):
        if self.backend.debug:
            hostname = "http://people.dev.ipfire.org"
        else:
            hostname = "https://people.ipfire.org"

        url = "%s/users/%s.jpg" % (hostname, self.uid)

        if size:
            url += "?size=%s" % size

        return url

    def get_avatar(self, size=None):
        """
            Returns the jpegPhoto attribute (bytes), optionally scaled to fit
            into size x size, or None if there is no photo.
        """
        avatar = self._get_first_attribute("jpegPhoto")
        if not avatar:
            return

        if not size:
            return avatar

        return self._resize_avatar(avatar, size)

    def _resize_avatar(self, image, size):
        """
            Scales the image given as bytes to fit into size x size (keeping
            the aspect ratio) and returns it re-encoded as JPEG bytes.
        """
        # "import PIL" alone does not load the PIL.Image submodule
        from PIL import Image

        image = Image.open(io.BytesIO(image))

        # JPEG cannot store alpha channels or palettes, so convert anything
        # that is not plain greyscale or RGB into RGB
        if image.mode not in ("L", "RGB"):
            image = image.convert("RGB")

        # Resize the image to the desired resolution. LANCZOS is the filter
        # that the ANTIALIAS alias (removed in Pillow 10) referred to.
        image.thumbnail((size, size), Image.LANCZOS)

        with io.BytesIO() as f:
            # If writing out the image does not work with optimization,
            # we try to write it out without any optimization.
            try:
                image.save(f, "JPEG", optimize=True, quality=98)
            except OSError:
                # Discard whatever the failed attempt already wrote
                f.seek(0)
                f.truncate()

                image.save(f, "JPEG", quality=98)

            return f.getvalue()
378
379
if __name__ == "__main__":
    # NOTE(review): Accounts is constructed without a backend here, while
    # Account passes one to Object.__init__ - confirm how to obtain a
    # backend for this script.
    a = Accounts()

    # Accounts has no list() method; iterating it yields the sorted
    # developer accounts.
    print(list(a))