]> git.ipfire.org Git - pbs.git/commitdiff
cache: Implement an async redis cache driver
authorMichael Tremer <michael.tremer@ipfire.org>
Sat, 5 Aug 2023 10:11:09 +0000 (10:11 +0000)
committerMichael Tremer <michael.tremer@ipfire.org>
Sat, 5 Aug 2023 10:11:09 +0000 (10:11 +0000)
Signed-off-by: Michael Tremer <michael.tremer@ipfire.org>
src/buildservice/cache.py

index 9b371607eb10a388be8c64b04ce4ffe2bb36289f..0b70c4f296118212d3694433f7b5fefeca596992 100644 (file)
@@ -1,26 +1,83 @@
 #!/usr/bin/python3
 
+import asyncio
 import datetime
+import logging
 import pickle
-import redis
-
-from . import base
+import redis.asyncio
 
 from .decorators import *
 
-class Cache(base.Object):
-       @lazy_property
-       def redis(self):
+# Setup logging
+log = logging.getLogger("pbs.cache")
+
+class Cache(object):
+       def __init__(self, backend):
+               self.backend = backend
+
+               # Stores connections assigned to tasks
+               self.__connections = {}
+
+               # Create a connection pool
+               self.pool = redis.asyncio.connection.ConnectionPool.from_url(
+                       "redis://localhost:6379/0",
+               )
+
+       async def connection(self, *args, **kwargs):
                """
-                       Connects to a local redis server
+                       Returns a connection from the pool
                """
-               return redis.Redis(host="localhost", port=6379, db=0)
+               # Fetch the current task
+               task = asyncio.current_task()
+
+               assert task, "Could not determine task"
+
+               # Try returning the same connection to the same task
+               try:
+                       return self.__connections[task]
+               except KeyError:
+                       pass
+
+               # Fetch a new connection from the pool
+               conn = await redis.asyncio.Redis(
+                       connection_pool=self.pool,
+                       single_connection_client=True,
+               )
+
+               # Store the connection
+               self.__connections[task] = conn
+
+               log.debug("Assigning cache connection %s to %s" % (conn, task))
+
+               # When the task finishes, release the connection
+               task.add_done_callback(self.__release_connection)
 
-       def get(self, key):
+               return conn
+
+       def __release_connection(self, task):
+               loop = asyncio.get_running_loop()
+
+               # Retrieve the connection
+               try:
+                       conn = self.__connections[task]
+               except KeyError:
+                       return
+
+               log.debug("Releasing cache connection %s of %s" % (conn, task))
+
+               # Delete it
+               del self.__connections[task]
+
+               # Return the connection back into the pool
+               asyncio.run_coroutine_threadsafe(conn.close(), loop)
+
+       async def get(self, key):
                """
                        Fetches the value of a cached key
                """
-               value = self.redis.get(key)
+               conn = await self.connection()
+
+               value = await conn.get(key)
 
                # Nothing found
                if not value:
@@ -35,10 +92,12 @@ class Cache(base.Object):
                except pickle.UnpicklingError:
                        return
 
-       def set(self, key, value, expires=None):
+       async def set(self, key, value, expires=None):
                """
                        Puts something into the cache
                """
+               conn = await self.connection()
+
                # Figure out when this expires
                if expires and isinstance(expires, datetime.timedelta):
                        expires = expires.total_seconds()
@@ -47,10 +106,12 @@ class Cache(base.Object):
                value = pickle.dumps(value)
 
                # Send to redis
-               return self.redis.set(key, value, ex=expires)
+               return await conn.set(key, value, ex=expires)
 
-       def delete(self, key):
+       async def delete(self, key):
                """
                        Deletes the key from the cache
                """
-               return self.redis.delete(key)
+               conn = await self.connection()
+
+               return await conn.delete(key)