]> git.ipfire.org Git - pbs.git/commitdiff
sessions: Use the same design pattern as anything else
authorMichael Tremer <michael.tremer@ipfire.org>
Tue, 24 Oct 2017 20:47:06 +0000 (21:47 +0100)
committerMichael Tremer <michael.tremer@ipfire.org>
Tue, 24 Oct 2017 20:47:06 +0000 (21:47 +0100)
Signed-off-by: Michael Tremer <michael.tremer@ipfire.org>
src/buildservice/sessions.py
src/buildservice/users.py
src/web/handlers.py

index 18f8e4d25218881a63b4e75662bd25f47d98b1cb..c6214898ab299bab37ee2f95ec5b308d7ee0480c 100644 (file)
@@ -6,17 +6,21 @@ from . import users
 from .decorators import *
 
 class Sessions(base.Object):
-       def __iter__(self):
-               query = "SELECT * FROM sessions WHERE valid_until >= NOW() \
-                       ORDER BY valid_until DESC"
+       def _get_session(self, query, *args):
+               res = self.db.get(query, *args)
+
+               if res:
+                       return Session(self.backend, res.id, data=res)
+
+       def _get_sessions(self, query, *args):
+               res = self.db.query(query, *args)
 
-               sessions = []
-               for row in self.db.query(query):
-                       session = Session(self.backend, row.id, data=row)
-                       sessions.append(session)
+               for row in res:
+                       yield Session(self.backend, row.id, data=row)
 
-               # Sort
-               sessions.sort()
+       def __iter__(self):
+               sessions = self._get_sessions("SELECT * FROM sessions \
+                       WHERE valid_until >= NOW() ORDER BY valid_until DESC")
 
                return iter(sessions)
 
@@ -29,18 +33,13 @@ class Sessions(base.Object):
                """
                session_id = users.generate_random_string(48)
 
-               res = self.db.get("INSERT INTO sessions(session_id, user_id, address, user_agent) \
+               return self._get_session("INSERT INTO sessions(session_id, user_id, address, user_agent) \
                        VALUES(%s, %s, %s, %s) RETURNING *", session_id, user.id, address, user_agent)
 
-               return Session(self.backend, res.id, data=res)
-
        def get_by_session_id(self, session_id):
-               res = self.db.get("SELECT * FROM sessions \
+               return self._get_session("SELECT * FROM sessions \
                        WHERE session_id = %s AND valid_until >= NOW()", session_id)
 
-               if res:
-                       return Session(self.backend, res.id, data=res)
-
        # Alias function
        get = get_by_session_id
 
index 4b0269f7a6e4e45cebb6239a29657ca5acdcbe04..8d731d8e55a4645304f75e0c6fa9064689f13a9a 100644 (file)
@@ -246,6 +246,9 @@ class User(base.DataObject):
        def __repr__(self):
                return "<%s %s>" % (self.__class__.__name__, self.realname)
 
+       def __hash__(self):
+               return hash(self.id)
+
        def __eq__(self, other):
                if isinstance(other, self.__class__):
                        return self.id == other.id
@@ -446,6 +449,11 @@ class User(base.DataObject):
                # All others must be checked individually.
                return self.perms.get(perm, False) == True
 
+       @property
+       def sessions(self):
+               return self.backend.sessions._get_sessions("SELECT * FROM sessions \
+                       WHERE user_id = %s AND valid_until >= NOW() ORDER BY created_at")
+
 
 class UserEmail(base.DataObject):
        table = "users_emails"
index a92ab0bb18cd4e0cc57568afe09e3b874f00d766..b6313dd64953108bb53f63e35af3dd647421acaf 100644 (file)
@@ -111,6 +111,7 @@ class SessionsHandler(BaseHandler):
                users = {}
 
                for s in self.backend.sessions:
+                       print s.user, s.user in users
                        try:
                                users[s.user].append(s)
                        except KeyError: