]> git.ipfire.org Git - ipfire.org.git/blobdiff - src/backend/accounts.py
CSS: Add CSS for file listings
[ipfire.org.git] / src / backend / accounts.py
index ffbf2a0cabb0c86763d0288e4e16cfce4c58ccbd..215bfede6dfd09b95e3af0336c40ee3fd9e2e412 100644 (file)
@@ -2,74 +2,94 @@
 # encoding: utf-8
 
 import PIL
-import StringIO
-import hashlib
+import PIL.ImageOps
+import datetime
+import io
 import ldap
+import ldap.modlist
 import logging
-import urllib
+import phonenumbers
+import sshpubkeys
+import urllib.parse
+import urllib.request
+import zxcvbn
 
-from misc import Object
+from . import util
+from .decorators import *
+from .misc import Object
 
 class Accounts(Object):
-       @property
+       def __iter__(self):
+               # Only return developers (group with ID 1000)
+               accounts = self._search("(&(objectClass=posixAccount)(gidNumber=1000))")
+
+               return iter(sorted(accounts))
+
+       @lazy_property
        def ldap(self):
-               if not hasattr(self, "_ldap"):
-                       # Connect to LDAP server
-                       ldap_uri = self.settings.get("ldap_uri")
-                       self._ldap = ldap.initialize(ldap_uri)
+               # Connect to LDAP server
+               ldap_uri = self.settings.get("ldap_uri")
+               conn = ldap.initialize(ldap_uri)
 
-                       # Bind with username and password
-                       bind_dn = self.settings.get("ldap_bind_dn")
-                       if bind_dn:
-                               bind_pw = self.settings.get("ldap_bind_pw", "")
-                               self._ldap.simple_bind(bind_dn, bind_pw)
+               # Bind with username and password
+               bind_dn = self.settings.get("ldap_bind_dn")
+               if bind_dn:
+                       bind_pw = self.settings.get("ldap_bind_pw", "")
+                       conn.simple_bind(bind_dn, bind_pw)
 
-               return self._ldap
+               return conn
 
-       def _search(self, query, attrlist=None, limit=0):
+       def _query(self, query, attrlist=None, limit=0):
                logging.debug("Performing LDAP query: %s" % query)
 
                search_base = self.settings.get("ldap_search_base")
 
-                try:
-                    results = self.ldap.search_ext_s(search_base, ldap.SCOPE_SUBTREE,
-                            query, attrlist=attrlist, sizelimit=limit)
-                except:
-                    # Close current connection
-                    del self._ldap
+               try:
+                       results = self.ldap.search_ext_s(search_base, ldap.SCOPE_SUBTREE,
+                               query, attrlist=attrlist, sizelimit=limit)
+               except:
+                       # Close current connection
+                       del self.ldap
 
-                    raise
+                       raise
 
                return results
 
-       def search(self, query, limit=0):
-               results = self._search(query, limit=limit)
-
+       def _search(self, query, attrlist=None, limit=0):
                accounts = []
-               for dn, attrs in results:
+
+               for dn, attrs in self._query(query, attrlist=attrlist, limit=limit):
                        account = Account(self.backend, dn, attrs)
                        accounts.append(account)
 
+               return accounts
+
+       def search(self, query):
+               # Search for exact matches
+               accounts = self._search("(&(objectClass=posixAccount) \
+                       (|(uid=%s)(mail=%s)(sipAuthenticationUser=%s)(telephoneNumber=%s)(homePhone=%s)(mobile=%s)))" \
+                       % (query, query, query, query, query, query))
+
+               # Find accounts by name
+               if not accounts:
+                       for account in self._search("(&(objectClass=posixAccount)(|(cn=*%s*)(uid=*%s*)))" % (query, query)):
+                               if not account in accounts:
+                                       accounts.append(account)
+
                return sorted(accounts)
 
-       def search_one(self, query):
-               result = self.search(query, limit=1)
+       def _search_one(self, query):
+               result = self._search(query, limit=1)
                assert len(result) <= 1
 
                if result:
                        return result[0]
 
-       def get_all(self):
-               # Only return developers (group with ID 500)
-               return self.search("(&(objectClass=posixAccount)(gidNumber=500))")
-
-       list = get_all
-
        def get_by_uid(self, uid):
-               return self.search_one("(&(objectClass=posixAccount)(uid=%s))" % uid)
+               return self._search_one("(&(objectClass=posixAccount)(uid=%s))" % uid)
 
        def get_by_mail(self, mail):
-               return self.search_one("(&(objectClass=posixAccount)(mail=%s))" % mail)
+               return self._search_one("(&(objectClass=posixAccount)(mail=%s))" % mail)
 
        find = get_by_uid
 
@@ -81,9 +101,14 @@ class Accounts(Object):
                return self.get_by_mail(s)
 
        def get_by_sip_id(self, sip_id):
-               return self.search_one("(|(&(objectClass=sipUser)(sipAuthenticationUser=%s)) \
+               return self._search_one("(|(&(objectClass=sipUser)(sipAuthenticationUser=%s)) \
                        (&(objectClass=sipRoutingObject)(sipLocalAddress=%s)))" % (sip_id, sip_id))
 
+       def get_by_phone_number(self, number):
+               return self._search_one("(&(objectClass=posixAccount) \
+                       (|(sipAuthenticationUser=%s)(telephoneNumber=%s)(homePhone=%s)(mobile=%s)))" \
+                       % (number, number, number, number))
+
        # Session stuff
 
        def _cleanup_expired_sessions(self):
@@ -133,40 +158,132 @@ class Account(Object):
                Object.__init__(self, backend)
                self.dn = dn
 
-               self.__attrs = attrs or {}
+               self.attributes = attrs or {}
+
+       def __str__(self):
+               return self.name
 
        def __repr__(self):
                return "<%s %s>" % (self.__class__.__name__, self.dn)
 
-       def __cmp__(self, other):
-               return cmp(self.name, other.name)
+       def __eq__(self, other):
+               if isinstance(other, self.__class__):
+                       return self.dn == other.dn
+
+       def __lt__(self, other):
+               if isinstance(other, self.__class__):
+                       return self.name < other.name
 
        @property
        def ldap(self):
                return self.accounts.ldap
 
-       @property
-       def attributes(self):
-               return self.__attrs
+       def _exists(self, key):
+               try:
+                       self.attributes[key]
+               except KeyError:
+                       return False
 
-       def _get_first_attribute(self, attr, default=None):
-               if not self.attributes.has_key(attr):
-                       return default
+               return True
 
-               res = self.attributes.get(attr, [])
-               if res:
-                       return res[0]
+       def _get(self, key):
+               for value in self.attributes.get(key, []):
+                       yield value
 
-       def get(self, key):
-               try:
-                       attribute = self.attributes[key]
-               except KeyError:
-                       raise AttributeError(key)
+       def _get_bytes(self, key, default=None):
+               for value in self._get(key):
+                       return value
+
+               return default
+
+       def _get_strings(self, key):
+               for value in self._get(key):
+                       yield value.decode()
+
+       def _get_string(self, key, default=None):
+               for value in self._get_strings(key):
+                       return value
+
+               return default
+
+       def _get_phone_numbers(self, key):
+               for value in self._get_strings(key):
+                       yield phonenumbers.parse(value, None)
+
+       def _modify(self, modlist):
+               logging.debug("Modifying %s: %s" % (self.dn, modlist))
+
+               # Run modify operation
+               self.ldap.modify_s(self.dn, modlist)
+
+       def _set(self, key, values):
+               current = self._get(key)
+
+               # Don't do anything if nothing has changed
+               if list(current) == values:
+                       return
+
+               # Remove all old values and add all new ones
+               modlist = []
 
-               if len(attribute) == 1:
-                       return attribute[0]
+               if self._exists(key):
+                       modlist.append((ldap.MOD_DELETE, key, None))
 
-               return attribute
+               # Add new values
+               if values:
+                       modlist.append((ldap.MOD_ADD, key, values))
+
+               # Run modify operation
+               self._modify(modlist)
+
+               # Update cache
+               self.attributes.update({ key : values })
+
+       def _set_bytes(self, key, values):
+               return self._set(key, values)
+
+       def _set_strings(self, key, values):
+               return self._set(key, [e.encode() for e in values if e])
+
+       def _set_string(self, key, value):
+               return self._set_strings(key, [value,])
+
+       def _add(self, key, values):
+               modlist = [
+                       (ldap.MOD_ADD, key, values),
+               ]
+
+               self._modify(modlist)
+
+       def _add_strings(self, key, values):
+               return self._add(key, [e.encode() for e in values])
+
+       def _add_string(self, key, value):
+               return self._add_strings(key, [value,])
+
+       def _delete(self, key, values):
+               modlist = [
+                       (ldap.MOD_DELETE, key, values),
+               ]
+
+               self._modify(modlist)
+
+       def _delete_strings(self, key, values):
+               return self._delete(key, [e.encode() for e in values])
+
+       def _delete_string(self, key, value):
+               return self._delete_strings(key, [value,])
+
+       def passwd(self, password):
+               """
+                       Sets a new password
+               """
+               # The new password must have a score of 3 or better
+               quality = self.check_password_quality(password)
+               if quality["score"] < 3:
+                       raise ValueError("Password too weak")
+
+               self.ldap.passwd_s(self.dn, None, password)
 
        def check_password(self, password):
                """
@@ -175,17 +292,34 @@ class Account(Object):
 
                        Raises exceptions from the server on any other errors.
                """
+               if not password:
+                       return
 
                logging.debug("Checking credentials for %s" % self.dn)
+
+               # Create a new LDAP connection
+               ldap_uri = self.backend.settings.get("ldap_uri")
+               conn = ldap.initialize(ldap_uri)
+
                try:
-                       self.ldap.simple_bind_s(self.dn, password.encode("utf-8"))
+                       conn.simple_bind_s(self.dn, password.encode("utf-8"))
                except ldap.INVALID_CREDENTIALS:
-                       logging.debug("Account credentials are invalid.")
+                       logging.debug("Account credentials are invalid for %s" % self)
                        return False
 
-               logging.debug("Successfully authenticated.")
+               logging.info("Successfully authenticated %s" % self)
+
                return True
 
+       def check_password_quality(self, password):
+               """
+                       Passwords are passed through zxcvbn to make sure
+                       that they are strong enough.
+               """
+               return zxcvbn.zxcvbn(password, user_inputs=(
+                       self.first_name, self.last_name,
+               ))
+
        def is_admin(self):
                return "wheel" in self.groups
 
@@ -193,43 +327,85 @@ class Account(Object):
                return "sipUser" in self.classes or "sipRoutingObject" in self.classes \
                        or self.telephone_numbers or self.address
 
+       def can_be_managed_by(self, account):
+               """
+                       Returns True if account is allowed to manage this account
+               """
+               # Admins can manage all accounts
+               if account.is_admin():
+                       return True
+
+               # Users can manage themselves
+               return self == account
+
        @property
        def classes(self):
-               return self.attributes.get("objectClass", [])
+               return self._get_strings("objectClass")
 
        @property
        def uid(self):
-               return self._get_first_attribute("uid")
+               return self._get_string("uid")
 
        @property
        def name(self):
-               return self._get_first_attribute("cn")
+               return self._get_string("cn")
 
-       @property
-       def first_name(self):
-               return self._get_first_attribute("givenName")
+       # First Name
 
-       @property
+       def get_first_name(self):
+               return self._get_string("givenName")
+
+       def set_first_name(self, first_name):
+               self._set_string("givenName", first_name)
+
+               # Update Common Name
+               self._set_string("cn", "%s %s" % (first_name, self.last_name))
+
+       first_name = property(get_first_name, set_first_name)
+
+       # Last Name
+
+       def get_last_name(self):
+               return self._get_string("sn")
+
+       def set_last_name(self, last_name):
+               self._set_string("sn", last_name)
+
+               # Update Common Name
+               self._set_string("cn", "%s %s" % (self.first_name, last_name))
+
+       last_name = property(get_last_name, set_last_name)
+
+       @lazy_property
        def groups(self):
-               if not hasattr(self, "_groups"):
-                       self._groups = []
+               groups = []
 
-                       res = self.accounts._search("(&(objectClass=posixGroup) \
-                               (memberUid=%s))" % self.uid, ["cn"])
+               res = self.accounts._query("(&(objectClass=posixGroup) \
+                       (memberUid=%s))" % self.uid, ["cn"])
 
-                       for dn, attrs in res:
-                               cns = attrs.get("cn")
-                               if cns:
-                                       self._groups.append(cns[0])
+               for dn, attrs in res:
+                       cns = attrs.get("cn")
+                       if cns:
+                               groups.append(cns[0].decode())
 
-               return self._groups
+               return groups
 
-       @property
-       def address(self):
-               address = self._get_first_attribute("homePostalAddress", "")
-               address = address.replace(", ", "\n")
+       # Address
+
+       def get_address(self):
+               address = self._get_string("homePostalAddress")
+
+               if address:
+                       return (line.strip() for line in address.split(","))
+
+               return []
 
-               return address
+       def set_address(self, address):
+               data = ", ".join(address.splitlines())
+
+               self._set_bytes("homePostalAddress", data.encode())
+
+       address = property(get_address, set_address)
 
        @property
        def email(self):
@@ -243,81 +419,188 @@ class Account(Object):
                name = name.replace("ΓΌ", "ue")
 
                for mail in self.attributes.get("mail", []):
-                       if mail.startswith("%s@ipfire.org" % name):
+                       if mail.decode().startswith("%s@ipfire.org" % name):
                                return mail
 
                # If everything else fails, we will go with the UID
                return "%s@ipfire.org" % self.uid
 
+       # Mail Routing Address
+
+       def get_mail_routing_address(self):
+               return self._get_string("mailRoutingAddress", None)
+
+       def set_mail_routing_address(self, address):
+               self._set_string("mailRoutingAddress", address or None)
+
+       mail_routing_address = property(get_mail_routing_address, set_mail_routing_address)
+
        @property
        def sip_id(self):
                if "sipUser" in self.classes:
-                       return self._get_first_attribute("sipAuthenticationUser")
+                       return self._get_string("sipAuthenticationUser")
 
                if "sipRoutingObject" in self.classes:
-                       return self._get_first_attribute("sipLocalAddress")
+                       return self._get_string("sipLocalAddress")
 
        @property
        def sip_password(self):
-               return self._get_first_attribute("sipPassword")
+               return self._get_string("sipPassword")
+
+       @staticmethod
+       def _generate_sip_password():
+               return util.random_string(8)
 
        @property
        def sip_url(self):
                return "%s@ipfire.org" % self.sip_id
 
        def uses_sip_forwarding(self):
-               if self.sip_routing_url:
+               if self.sip_routing_address:
                        return True
 
                return False
 
-       @property
-       def sip_routing_url(self):
+       # SIP Routing
+
+       def get_sip_routing_address(self):
                if "sipRoutingObject" in self.classes:
-                       return self._get_first_attribute("sipRoutingAddress")
+                       return self._get_string("sipRoutingAddress")
 
-       def sip_is_online(self):
-               assert self.sip_id
+       def set_sip_routing_address(self, address):
+               if not address:
+                       address = None
 
-               if not hasattr(self, "_is_online"):
-                       self._is_online = self.backend.talk.user_is_online(self.sip_id)
+               # Don't do anything if nothing has changed
+               if self.get_sip_routing_address() == address:
+                       return
 
-               return self._is_online
+               if address:
+                       modlist = [
+                               # This is no longer a SIP user any more
+                               (ldap.MOD_DELETE, "objectClass", b"sipUser"),
+                               (ldap.MOD_DELETE, "sipAuthenticationUser", None),
+                               (ldap.MOD_DELETE, "sipPassword", None),
+
+                               (ldap.MOD_ADD, "objectClass", b"sipRoutingObject"),
+                               (ldap.MOD_ADD, "sipLocalAddress", self.sip_id.encode()),
+                               (ldap.MOD_ADD, "sipRoutingAddress", address.encode()),
+                       ]
+               else:
+                       modlist = [
+                               (ldap.MOD_DELETE, "objectClass", b"sipRoutingObject"),
+                               (ldap.MOD_DELETE, "sipLocalAddress", None),
+                               (ldap.MOD_DELETE, "sipRoutingAddress", None),
 
-       @property
-       def telephone_numbers(self):
-               return self._telephone_numbers + self.mobile_telephone_numbers \
-                       + self.home_telephone_numbers
+                               (ldap.MOD_ADD, "objectClass", b"sipUser"),
+                               (ldap.MOD_ADD, "sipAuthenticationUser", self.sip_id.encode()),
+                               (ldap.MOD_ADD, "sipPassword", self._generate_sip_password().encode()),
+                       ]
 
-       @property
-       def _telephone_numbers(self):
-               return self.attributes.get("telephoneNumber") or []
+               # Run modification
+               self._modify(modlist)
 
-       @property
-       def home_telephone_numbers(self):
-               return self.attributes.get("homePhone") or []
+               # XXX Cache is invalid here
+
+       sip_routing_address = property(get_sip_routing_address, set_sip_routing_address)
+
+       @lazy_property
+       def sip_registrations(self):
+               sip_registrations = []
+
+               for reg in self.backend.talk.freeswitch.get_sip_registrations(self.sip_url):
+                       reg.account = self
+
+                       sip_registrations.append(reg)
+
+               return sip_registrations
+
+       @lazy_property
+       def sip_channels(self):
+               return self.backend.talk.freeswitch.get_sip_channels(self)
+
+       def get_cdr(self, date=None, limit=None):
+               return self.backend.talk.freeswitch.get_cdr_by_account(self, date=date, limit=limit)
+
+       # Phone Numbers
+
+       @lazy_property
+       def phone_number(self):
+               """
+                       Returns the IPFire phone number
+               """
+               if self.sip_id:
+                       return phonenumbers.parse("+4923636035%s" % self.sip_id)
+
+       @lazy_property
+       def fax_number(self):
+               if self.sip_id:
+                       return phonenumbers.parse("+49236360359%s" % self.sip_id)
+
+       def get_phone_numbers(self):
+               ret = []
+
+               for field in ("telephoneNumber", "homePhone", "mobile"):
+                       for number in self._get_phone_numbers(field):
+                               ret.append(number)
+
+               return ret
+
+       def set_phone_numbers(self, phone_numbers):
+               # Sort phone numbers by landline and mobile
+               _landline_numbers = []
+               _mobile_numbers = []
+
+               for number in phone_numbers:
+                       try:
+                               number = phonenumbers.parse(number, None)
+                       except phonenumbers.phonenumberutil.NumberParseException:
+                               continue
+
+                       # Convert to string (in E.164 format)
+                       s = phonenumbers.format_number(number, phonenumbers.PhoneNumberFormat.E164)
+
+                       # Separate mobile numbers
+                       if phonenumbers.number_type(number) == phonenumbers.PhoneNumberType.MOBILE:
+                               _mobile_numbers.append(s)
+                       else:
+                               _landline_numbers.append(s)
+
+               # Save
+               self._set_strings("telephoneNumber", _landline_numbers)
+               self._set_strings("mobile", _mobile_numbers)
+
+       phone_numbers = property(get_phone_numbers, set_phone_numbers)
 
        @property
-       def mobile_telephone_numbers(self):
-               return self.attributes.get("mobile") or []
+       def _all_telephone_numbers(self):
+               ret = [ self.sip_id, ]
+
+               if self.phone_number:
+                       s = phonenumbers.format_number(self.phone_number, phonenumbers.PhoneNumberFormat.E164)
+                       ret.append(s)
+
+               for number in self.phone_numbers:
+                       s = phonenumbers.format_number(number, phonenumbers.PhoneNumberFormat.E164)
+                       ret.append(s)
+
+               return ret
 
        def avatar_url(self, size=None):
                if self.backend.debug:
-                       hostname = "accounts.dev.ipfire.org"
+                       hostname = "http://people.dev.ipfire.org"
                else:
-                       hostname = "accounts.ipfire.org"
+                       hostname = "https://people.ipfire.org"
 
-               url = "https://%s/avatar/%s.jpg" % (hostname, self.uid)
+               url = "%s/users/%s.jpg" % (hostname, self.uid)
 
                if size:
                        url += "?size=%s" % size
 
                return url
 
-       gravatar_icon = avatar_url
-
        def get_avatar(self, size=None):
-               avatar = self._get_first_attribute("jpegPhoto")
+               avatar = self._get_bytes("jpegPhoto")
                if not avatar:
                        return
 
@@ -327,38 +610,88 @@ class Account(Object):
                return self._resize_avatar(avatar, size)
 
        def _resize_avatar(self, image, size):
-               image = StringIO.StringIO(image)
-               image = PIL.Image.open(image)
+               image = PIL.Image.open(io.BytesIO(image))
 
-               # Resize the image to the desired resolution
-               image.thumbnail((size, size), PIL.Image.ANTIALIAS)
+               # Convert RGBA images into RGB because JPEG doesn't support alpha-channels
+               if image.mode == "RGBA":
+                       image = image.convert("RGB")
 
-               f = StringIO.StringIO()
+               # Resize the image to the desired resolution (and make it square)
+               thumbnail = PIL.ImageOps.fit(image, (size, size), PIL.Image.ANTIALIAS)
 
-               # If writing out the image does not work with optimization,
-               # we try to write it out without any optimization.
-               try:
-                       image.save(f, "JPEG", optimize=True)
-               except:
-                       image.save(f, "JPEG")
+               with io.BytesIO() as f:
+                       # If writing out the image does not work with optimization,
+                       # we try to write it out without any optimization.
+                       try:
+                               thumbnail.save(f, "JPEG", optimize=True, quality=98)
+                       except:
+                               thumbnail.save(f, "JPEG", quality=98)
 
-               return f.getvalue()
+                       return f.getvalue()
 
-       def get_gravatar_url(self, size=128):
-               try:
-                       gravatar_email = self.email.lower()
-               except:
-                       gravatar_email = "nobody@ipfire.org"
+       def upload_avatar(self, avatar):
+               self._set("jpegPhoto", avatar)
 
-               # construct the url
-               gravatar_url = "https://www.gravatar.com/avatar/" + \
-                       hashlib.md5(gravatar_email).hexdigest() + "?"
-               gravatar_url += urllib.urlencode({'d': "mm", 's': str(size)})
+       # SSH Keys
+
+       @lazy_property
+       def ssh_keys(self):
+               ret = []
+
+               for key in self._get_strings("sshPublicKey"):
+                       s = sshpubkeys.SSHKey()
+
+                       try:
+                               s.parse(key)
+                       except (sshpubkeys.InvalidKeyError, NotImplementedError) as e:
+                               logging.warning("Could not parse SSH key %s: %s" % (key, e))
+                               continue
+
+                       ret.append(s)
+
+               return ret
+
+       def get_ssh_key_by_hash_sha256(self, hash_sha256):
+               for key in self.ssh_keys:
+                       if not key.hash_sha256() == hash_sha256:
+                               continue
+
+                       return key
+
+       def add_ssh_key(self, key):
+               k = sshpubkeys.SSHKey()
+
+               # Try to parse the key
+               k.parse(key)
+
+               # Check for types and sufficient sizes
+               if k.key_type == b"ssh-rsa":
+                       if k.bits < 4096:
+                               raise sshpubkeys.TooShortKeyError("RSA keys cannot be smaller than 4096 bits")
+
+               elif k.key_type == b"ssh-dss":
+                       raise sshpubkeys.InvalidKeyError("DSA keys are not supported")
+
+               # Ignore any duplicates
+               if key in (k.keydata for k in self.ssh_keys):
+                       logging.debug("SSH Key has already been added for %s: %s" % (self, key))
+                       return
+
+               # Save key to LDAP
+               self._add_string("sshPublicKey", key)
+
+               # Append to cache
+               self.ssh_keys.append(k)
+
+       def delete_ssh_key(self, key):
+               if not key in (k.keydata for k in self.ssh_keys):
+                       return
 
-               return gravatar_url
+               # Delete key from LDAP
+               self._delete_string("sshPublicKey", key)
 
 
 if __name__ == "__main__":
        a = Accounts()
 
-       print a.list()
+       print(a.list())