git.ipfire.org Git - dbl.git/commitdiff
Make the entire application async
author: Michael Tremer <michael.tremer@ipfire.org>
Fri, 20 Feb 2026 15:40:03 +0000 (15:40 +0000)
committer: Michael Tremer <michael.tremer@ipfire.org>
Fri, 20 Feb 2026 15:40:03 +0000 (15:40 +0000)
This was needed because we seem to have some issues with the API
starting the next request or trying to access the database after the
session has been closed. Instead we will now have a single database
session per task which can be managed more easily and will only be
closed after the entire task has completed.

As another benefit, we are now able to run many requests simultaneously.
So far this has not been a big bottleneck, but some operations (like
closing a report) can take a moment and would therefore have been
blocking other requests.

Signed-off-by: Michael Tremer <michael.tremer@ipfire.org>
13 files changed:
src/dbl/__init__.py
src/dbl/api/__init__.py
src/dbl/api/lists.py
src/dbl/api/middlewares.py
src/dbl/api/reports.py
src/dbl/auth.py
src/dbl/database.py
src/dbl/domains.py
src/dbl/exporters.py
src/dbl/lists.py
src/dbl/reports.py
src/dbl/sources.py
src/scripts/dbl.in

index eaf374ac62b6cf41871e4e7ce418f2da54e5d9f8..43ff8ea041261ead86b299181af0cef4da8b8bb4 100644 (file)
@@ -18,6 +18,7 @@
 #                                                                             #
 ###############################################################################
 
+import asyncio
 import configparser
 import functools
 import httpx
@@ -117,6 +118,24 @@ class Backend(object):
                # Put everything back together again
                return "%s[.]%s" % (name, tld)
 
+       # Functions to run something in background
+
+       # A list of any background tasks
+       __tasks = set()
+
+       def run_task(self, callback, *args, **kwargs):
+               """
+                       Runs the given coroutine in the background
+               """
+               # Create a new task
+               task = asyncio.create_task(callback(*args, **kwargs))
+
+               # Keep a reference to the task and remove it when the task has finished
+               self.__tasks.add(task)
+               task.add_done_callback(self.__tasks.discard)
+
+               return task
+
        @functools.cached_property
        def auth(self):
                return auth.Auth(self)
@@ -177,7 +196,7 @@ class Backend(object):
 
                return conn
 
-       def search(self, name):
+       async def search(self, name):
                """
                        Searches for a domain
                """
@@ -211,7 +230,7 @@ class Backend(object):
                res = {}
 
                # Group all matches by list
-               for domain in self.db.fetch(stmt):
+               async for domain in self.db.fetch(stmt):
                        try:
                                res[domain.list].append(domain.name)
                        except KeyError:
@@ -226,13 +245,13 @@ class Backend(object):
                        domain = matches.pop()
 
                        # Fetch the history of the longest match
-                       for history in list.get_domain_history(domain, limit=1):
+                       async for history in list.get_domain_history(domain, limit=1):
                                res[list] = history
                                break
 
                return res
 
-       def check(self, domain):
+       async def check(self, domain):
                """
                        Returns the status of a domain
                """
@@ -246,4 +265,4 @@ class Backend(object):
                        )
                )
 
-               return not self.db.fetch_one(stmt)
+               return not await self.db.fetch_one(stmt)
index 53248033be9ac189fbeead65f2820e6761e5e7d9..9720a9823d387694fccc895f421dcb7fbffdf4a2 100644 (file)
@@ -60,7 +60,7 @@ from . import reports
 # Search
 
 @app.get("/search")
-def search(q: str):
+async def search(q: str):
        """
                Performs a simple search
        """
@@ -71,7 +71,7 @@ def search(q: str):
                raise fastapi.HTTPException(400, "Not a valid FQDN: %s" % q)
 
        # Perform the search
-       results = backend.search(q)
+       results = await backend.search(q)
 
        # Format the result
        for list, event in results.items():
index 48550bd33ba58a272e897367b0d544849e28912c..cc2722e9d4902e7e51eb35837cc45b567b1cf46f 100644 (file)
@@ -39,25 +39,25 @@ router = fastapi.APIRouter(
        tags=["Lists"],
 )
 
-def get_list_from_path(list: str = fastapi.Path(...)) -> lists.List:
+async def get_list_from_path(list: str = fastapi.Path(...)) -> lists.List:
        """
                Fetches a list by its slug
        """
-       return backend.lists.get_by_slug(list)
+       return await backend.lists.get_by_slug(list)
 
 @router.get("")
-def get_lists() -> typing.List[lists.List]:
-       return [l for l in backend.lists]
+async def get_lists() -> typing.List[lists.List]:
+       return [l async for l in backend.lists]
 
 @router.get("/{list}")
-def get_list(list = fastapi.Depends(get_list_from_path)) -> lists.List:
+async def get_list(list = fastapi.Depends(get_list_from_path)) -> lists.List:
        if not list:
                raise fastapi.HTTPException(404, "Could not find list")
 
        return list
 
 @router.get("/{list}/history")
-def get_list_history(
+async def get_list_history(
                list = fastapi.Depends(get_list_from_path),
                before: datetime.datetime = None,
                limit: int = 10,
@@ -69,28 +69,28 @@ def get_list_history(
                        "domains_blocked" : e.domains_blocked,
                        "domains_allowed" : e.domains_allowed,
                        "domains_removed" : e.domains_removed,
-               } for e in list.get_history(before=before, limit=limit)
+               } async for e in list.get_history(before=before, limit=limit)
        ]
 
 @router.get("/{list}/sources")
-def get_list_sources(list = fastapi.Depends(get_list_from_path)) -> typing.List[sources.Source]:
+async def get_list_sources(list = fastapi.Depends(get_list_from_path)) -> typing.List[sources.Source]:
        return list.sources
 
 @router.get("/{list}/reports")
-def get_list_reports(
+async def get_list_reports(
                list = fastapi.Depends(get_list_from_path),
                open: bool | None = None,
                name: str | None = None,
                limit: int | None = None
 ) -> typing.List[reports.Report]:
-       return list.get_reports(open=open, name=name, limit=limit)
+       return await list.get_reports(open=open, name=name, limit=limit)
 
 @router.get("/{list}/domains/{name}")
-def get_list_domains(
+async def get_list_domains(
                name: str, list = fastapi.Depends(get_list_from_path),
 ) -> typing.List[domains.DomainEvent]:
        # Fetch the domain history
-       return list.get_domain_history(name)
+       return [e async for e in list.get_domain_history(name)]
 
 
 class CreateReport(pydantic.BaseModel):
@@ -108,19 +108,17 @@ class CreateReport(pydantic.BaseModel):
 
 
 @router.post("/{list}/reports")
-def list_report(
+async def list_report(
        report: CreateReport,
        auth = fastapi.Depends(require_api_key),
        list = fastapi.Depends(get_list_from_path),
 ) -> reports.Report:
-       # Create a new report
-       with backend.db:
-               return list.report(
-                       name        = report.name,
-                       reported_by = report.reported_by,
-                       comment     = report.comment,
-                       block       = report.block,
-               )
+       return await list.report(
+               name        = report.name,
+               reported_by = report.reported_by,
+               comment     = report.comment,
+               block       = report.block,
+       )
 
 # Include our endpoints
 app.include_router(router)
index 6a59f66389716ac6a81c9bcdfa39c2fb3f7ec461..36e249fd0f437e818d3427841e30732f4d1fce30 100644 (file)
@@ -25,5 +25,5 @@ class DatabaseSessionMiddleware(fastapi.applications.BaseHTTPMiddleware):
                backend = request.app.state.backend
 
                # Acquire a new database session and process the request
-               with backend.db.session() as session:
+               async with await backend.db.session() as session:
                        return await call_next(request)
index eb241c9cb517a3794b7c362d504e4f225d52e175..cb13a11fbc0f18d367f92c87825c828ead62cc08 100644 (file)
@@ -35,12 +35,12 @@ router = fastapi.APIRouter(
        tags=["Reports"],
 )
 
-def get_report_from_path(id: uuid.UUID = fastapi.Path(...)) -> reports.Report:
+async def get_report_from_path(id: uuid.UUID = fastapi.Path(...)) -> reports.Report:
        """
                Fetches a report by its ID
        """
        # Fetch the report
-       report = backend.reports.get_by_id(id)
+       report = await backend.reports.get_by_id(id)
 
        # Raise 404 if we could not find the report
        if not report:
@@ -49,11 +49,11 @@ def get_report_from_path(id: uuid.UUID = fastapi.Path(...)) -> reports.Report:
        return report
 
 @router.get("")
-def get_reports() -> typing.List[reports.Report]:
-       return [l for l in backend.reports]
+async def get_reports() -> typing.List[reports.Report]:
+       return [l async for l in backend.reports]
 
 @router.get("/{id}")
-def get_report(report = fastapi.Depends(get_report_from_path)) -> reports.Report:
+async def get_report(report = fastapi.Depends(get_report_from_path)) -> reports.Report:
        return report
 
 class CloseReport(pydantic.BaseModel):
@@ -64,16 +64,14 @@ class CloseReport(pydantic.BaseModel):
        accept: bool = True
 
 @router.post("/{id}/close")
-def close_report(
+async def close_report(
                data: CloseReport,
                report: reports.Report = fastapi.Depends(get_report_from_path),
 ) -> fastapi.Response:
-       # Close the report
-       with backend.db:
-               report.close(
-                       closed_by = data.closed_by,
-                       accept    = data.accept,
-               )
+       await report.close(
+               closed_by = data.closed_by,
+               accept    = data.accept,
+       )
 
        # Send 204
        return fastapi.Response(status_code=fastapi.status.HTTP_204_NO_CONTENT)
index 2c0440c765740926ac177804b6beabfaa1973f56..4159a766e7a15f8530642966b806c61794359741 100644 (file)
@@ -42,12 +42,12 @@ class Auth(object):
        def __init__(self, backend):
                self.backend = backend
 
-       def __call__(self, key):
+       async def __call__(self, key):
                """
                        The main authentication function which takes an API key
                        and checks it against the database.
                """
-               key = self.get_key(key)
+               key = await self.get_key(key)
 
                # If we have found a key, we have successfully authenticated
                if key:
@@ -56,7 +56,7 @@ class Auth(object):
 
                return False
 
-       def get_key(self, key):
+       async def get_key(self, key):
                """
                        Fetches a specific key
                """
@@ -72,7 +72,7 @@ class Auth(object):
                keys = self.get_keys(prefix)
 
                # Check all keys
-               for key in keys:
+               async for key in keys:
                        # Return True if the key matches the secret
                        if key.check(secret):
                                return key
@@ -92,9 +92,9 @@ class Auth(object):
                        )
                )
 
-               return self.backend.db.fetch_as_list(stmt)
+               return self.backend.db.fetch(stmt)
 
-       def create(self, created_by):
+       async def create(self, created_by):
                """
                        Creates a new API key
                """
@@ -105,7 +105,7 @@ class Auth(object):
                secret = secrets.token_urlsafe(32)
 
                # Insert the new token into the database
-               key = self.backend.db.insert(
+               key = await self.backend.db.insert(
                        APIKey,
                        prefix     = prefix,
                        secret     = secret,
@@ -147,13 +147,13 @@ class APIKey(sqlmodel.SQLModel, database.BackendMixin, table=True):
        # Deleted By
        deleted_by : str | None
 
-       def check(self, secret):
+       async def check(self, secret):
                """
                        Checks if the provided secret matches
                """
                return secrets.compare_digest(self.secret, secret)
 
-       def delete(self, deleted_by):
+       async def delete(self, deleted_by):
                """
                        Deletes the key
                """
index edeeb8306217699b71e671474e065bc14ec96878..e86045fa6b54cb4c6d02309f5a78a1fe2b28b2fc 100644 (file)
 #                                                                             #
 ###############################################################################
 
+import asyncio
 import collections
 import functools
 import logging
+import sqlalchemy.ext.asyncio
 import sqlalchemy.orm
 import sqlmodel
 
@@ -35,23 +37,24 @@ class Database(object):
                self.engine = self.connect(uri)
 
                # Create a session maker
-               self.sessionmaker = sqlalchemy.orm.sessionmaker(
+               self.sessionmaker = sqlalchemy.ext.asyncio.async_sessionmaker(
                        self.engine,
                        expire_on_commit = False,
+                       class_           = sqlalchemy.ext.asyncio.AsyncSession,
                        info             = {
                                "backend" : self.backend,
                        },
                )
 
                # Session
-               self.__session = None
+               self.__sessions = {}
 
        def connect(self, uri):
                """
                        Connects to the database
                """
                # Create the database engine
-               engine = sqlmodel.create_engine(
+               engine = sqlalchemy.ext.asyncio.create_async_engine(
                        uri,
 
                        # Log more if we are running in debug mode
@@ -59,34 +62,94 @@ class Database(object):
 
                        # Use our own logger
                        logging_name=log.name,
+
+                       # Increase the pool size
+                       pool_size=128,
                )
 
                return engine
 
-       def session(self):
+       async def session(self):
                """
                        Returns the current database session
                """
-               if self.__session is None:
-                       self.__session = self.sessionmaker()
+               # Fetch the current task
+               task = asyncio.current_task()
+
+               # Check that we have a task
+               assert task, "Could not determine task"
+
+               # Try returning the same session to the same task
+               try:
+                       return self.__sessions[task]
+               except KeyError:
+                       pass
+
+               # Fetch a new session from the engine
+               session = self.__sessions[task] = self.sessionmaker()
+
+               log.debug("Assigning database session %s to %s" % (session, task))
 
-               return self.__session
+               # When the task finishes, release the connection
+               task.add_done_callback(self.release_session)
 
-       def __enter__(self):
-               return self.session()
+               return session
 
-       def __exit__(self, type, exception, traceback):
-               session = self.session()
+       def release_session(self, task):
+               """
+                       Called when a task that requested a session has finished.
+
+                       This method will schedule that the session is being closed.
+               """
+               self.backend.run_task(self.__release_session, task)
 
+       async def __release_session(self, task):
+               """
+                       Called when a task that had a session assigned has finished
+               """
+               exception = None
+
+               # Retrieve the session
+               try:
+                       session = self.__sessions[task]
+               except KeyError:
+                       return
+
+               # Fetch any exception if the task is done
+               if task.done():
+                       exception = task.exception()
+
+               # If there is no exception, we can commit
                if exception is None:
-                       session.commit()
+                       await session.commit()
+
+               log.debug("Releasing database session %s of %s" % (session, task))
+
+               # Delete it
+               del self.__sessions[task]
+
+               # Close the session
+               await session.close()
+
+               # Re-raise the exception
+               if exception:
+                       raise exception
+
+       async def __aenter__(self):
+               return await self.session()
+
+       async def __aexit__(self, type, exception, traceback):
+               session = await self.session()
 
-       def transaction(self):
+               if exception is None:
+                       await session.commit()
+
+       async def transaction(self):
                """
                        Opens a new transaction
                """
                # Fetch our session
-               session = self.session()
+               session = await self.session()
 
                # If we are already in a transaction, begin a nested one
                if session.in_transaction():
@@ -95,22 +158,22 @@ class Database(object):
                # Otherwise begin the first transaction of the session
                return session.begin()
 
-       def execute(self, stmt):
+       async def execute(self, stmt):
                """
                        Executes a statement and returns a result object
                """
                # Fetch our session
-               session = self.session()
+               session = await self.session()
 
                # Execute the statement
-               return session.execute(stmt)
+               return await session.execute(stmt)
 
-       def insert(self, cls, **kwargs):
+       async def insert(self, cls, **kwargs):
                """
                        Inserts a new object into the database
                """
                # Fetch our session
-               session = self.session()
+               session = await self.session()
 
                # Create a new object
                object = cls(**kwargs)
@@ -121,11 +184,11 @@ class Database(object):
                # Return the object
                return object
 
-       def select(self, stmt, batch_size=128):
+       async def select(self, stmt, batch_size=128):
                """
                        Selects custom queries
                """
-               result = self.execute(stmt)
+               result = await self.execute(stmt)
 
                # Process the result in batches
                if batch_size:
@@ -137,11 +200,11 @@ class Database(object):
 
                        yield row(**object)
 
-       def fetch(self, stmt, batch_size=128):
+       async def fetch(self, stmt, batch_size=128):
                """
                        Fetches objects of the given type
                """
-               result = self.execute(stmt)
+               result = await self.execute(stmt)
 
                # Process the result in batches
                if batch_size:
@@ -154,8 +217,8 @@ class Database(object):
                for object in result.scalars():
                        yield object
 
-       def fetch_one(self, stmt):
-               result = self.execute(stmt)
+       async def fetch_one(self, stmt):
+               result = await self.execute(stmt)
 
                # Apply unique filtering
                result = result.unique()
@@ -163,33 +226,67 @@ class Database(object):
                # Return exactly one object or none, but fail otherwise
                return result.scalar_one_or_none()
 
-       def fetch_as_list(self, *args, **kwargs):
+       async def fetch_as_list(self, *args, **kwargs):
                """
                        Fetches objects and returns them as a list instead of an iterator
                """
                objects = self.fetch(*args, **kwargs)
 
                # Return as list
-               return [o for o in objects]
+               return [o async for o in objects]
 
-       def fetch_as_set(self, *args, **kwargs):
+       async def fetch_as_set(self, *args, **kwargs):
                """
                        Fetches objects and returns them as a set instead of an iterator
                """
                objects = self.fetch(*args, **kwargs)
 
                # Return as set
-               return set([o for o in objects])
+               return set([o async for o in objects])
 
-       def commit(self):
+       async def commit(self):
                """
                        Manually triggers a database commit
                """
                # Fetch our session
-               session = self.session()
+               session = await self.session()
 
                # Commit!
-               session.commit()
+               await session.commit()
+
+       async def flush(self, *objects):
+               """
+                       Manually triggers a flush
+               """
+               # Fetch our session
+               session = await self.session()
+
+               # Flush!
+               await session.flush(objects)
+
+       async def refresh(self, o):
+               """
+                       Refreshes the given object
+               """
+               # Fetch our session
+               session = await self.session()
+
+               # Refresh!
+               await session.refresh(o)
+
+       async def flush_and_refresh(self, *objects):
+               """
+                       Flushes and refreshes in one go.
+               """
+               # Fetch our session
+               session = await self.session()
+
+               # Flush!
+               await session.flush(objects)
+
+               # Refresh!
+               for o in objects:
+                       await session.refresh(o)
 
 
 class BackendMixin:
index be70add9bacb7688b6c3430f9ce7d819a22872c8..23f069fc7d52ea8a58d50630f3f7af59be58a61f 100644 (file)
@@ -34,7 +34,7 @@ class Domains(object):
        def __init__(self, backend):
                self.backend = backend
 
-       def get_by_name(self, name):
+       async def get_by_name(self, name):
                """
                        Fetch all domain objects that match the name
                """
@@ -59,7 +59,7 @@ class Domains(object):
                        )
                )
 
-               return self.backend.db.fetch_as_set(stmt)
+               return await self.backend.db.fetch_as_set(stmt)
 
 
 class Domain(sqlmodel.SQLModel, database.BackendMixin, table=True):
@@ -78,7 +78,9 @@ class Domain(sqlmodel.SQLModel, database.BackendMixin, table=True):
        list_id: int = sqlmodel.Field(foreign_key="lists.id")
 
        # List
-       list: "List" = sqlmodel.Relationship()
+       list: "List" = sqlmodel.Relationship(
+               sa_relationship_kwargs={ "lazy" : "selectin" },
+       )
 
        # Name
        name: str
@@ -87,7 +89,9 @@ class Domain(sqlmodel.SQLModel, database.BackendMixin, table=True):
        source_id: int = sqlmodel.Field(foreign_key="sources.id")
 
        # Source
-       source: "Source" = sqlmodel.Relationship(back_populates="domains")
+       source: "Source" = sqlmodel.Relationship(
+               back_populates="domains", sa_relationship_kwargs={ "lazy" : "selectin" },
+       )
 
        # Added At
        added_at: datetime.datetime = sqlmodel.Field(
@@ -165,13 +169,17 @@ class DomainEvent(sqlmodel.SQLModel, table=True):
        id: int = sqlmodel.Field(primary_key=True, foreign_key="domains.id", exclude=True)
 
        # Domain
-       domain: Domain = sqlmodel.Relationship()
+       domain: Domain = sqlmodel.Relationship(
+               sa_relationship_kwargs={ "lazy" : "joined", "innerjoin" : True },
+       )
 
        # List ID
        list_id: int = sqlmodel.Field(foreign_key="lists.id", exclude=True)
 
        # List
-       list: "List" = sqlmodel.Relationship()
+       list: "List" = sqlmodel.Relationship(
+               sa_relationship_kwargs={ "lazy" : "joined", "innerjoin" : True },
+       )
 
        # List Slug
        @pydantic.computed_field
@@ -187,7 +195,9 @@ class DomainEvent(sqlmodel.SQLModel, table=True):
        source_id: int | None = sqlmodel.Field(foreign_key="sources.id", exclude=True)
 
        # Source
-       source: "Source" = sqlmodel.Relationship()
+       source: "Source" = sqlmodel.Relationship(
+               sa_relationship_kwargs={ "lazy" : "selectin" },
+       )
 
        # Source Name
        @pydantic.computed_field
@@ -212,7 +222,9 @@ class DomainEvent(sqlmodel.SQLModel, table=True):
        report_id: uuid.UUID | None = sqlmodel.Field(foreign_key="reports.id")
 
        # Report
-       report: "Report" = sqlmodel.Relationship()
+       report: "Report" = sqlmodel.Relationship(
+               sa_relationship_kwargs={ "lazy" : "selectin" },
+       )
 
        # Blocks
        blocks: int
index ac8149f1f41f92c101676b924d9506f84acbb47d..9e580b869f4941a0b637e78e768e74a4c1434e49 100644 (file)
@@ -38,23 +38,23 @@ class Exporter(abc.ABC):
                self.backend = backend
                self.list = list
 
-       def __call__(self, f, **kwargs):
+       async def __call__(self, f, **kwargs):
                """
                        The main entry point to export something with this exporter...
                """
                # Export!
                with util.Stopwatch(_("Exporting %(name)s using %(exporter)s") % \
                                { "name" : self.list.name, "exporter" : self.__class__.__name__ }):
-                       self.export(f, **kwargs)
+                       await self.export(f, **kwargs)
 
        @abc.abstractmethod
-       def export(self, f, **kwargs):
+       async def export(self, f, **kwargs):
                """
                        Runs the export
                """
                raise NotImplementedError
 
-       def export_to_tarball(self, tarball, name, **kwargs):
+       async def export_to_tarball(self, tarball, name, **kwargs):
                """
                        Exports the list to the tarball using the given exporter
                """
@@ -64,7 +64,7 @@ class Exporter(abc.ABC):
                f = io.BytesIO()
 
                # Export the data
-               self(f, **kwargs)
+               await self(f, **kwargs)
 
                # Expand the filename
                name = name % {
@@ -101,12 +101,12 @@ class NullExporter(Exporter):
        """
                Exports nothing, i.e. writes an empty file
        """
-       def export(self, *args, **kwargs):
+       async def export(self, *args, **kwargs):
                pass
 
 
 class TextExporter(Exporter):
-       def __call__(self, f, **kwargs):
+       async def __call__(self, f, **kwargs):
                detach = False
 
                # Convert any file handles to handle plain text
@@ -115,7 +115,7 @@ class TextExporter(Exporter):
                        detach = True
 
                # Export!
-               super().__call__(f, **kwargs)
+               await super().__call__(f, **kwargs)
 
                # Detach the underlying stream. That way, the wrapper won't close
                # the underlying file handle.
@@ -183,12 +183,12 @@ class DomainsExporter(TextExporter):
        """
                Exports the plain domains
        """
-       def export(self, f):
+       async def export(self, f):
                # Write the header
                self.write_header(f)
 
                # Write all domains
-               for domain in self.list.domains:
+               async for domain in self.list.get_domains():
                        f.write("%s\n" % domain)
 
 
@@ -196,12 +196,12 @@ class HostsExporter(TextExporter):
        """
                Exports a file like /etc/hosts
        """
-       def export(self, f):
+       async def export(self, f):
                # Write the header
                self.write_header(f)
 
                # Write all domains
-               for domain in self.list.domains:
+               async for domain in self.list.get_domains():
                        f.write("0.0.0.0 %s\n" % domain)
 
 
@@ -209,7 +209,7 @@ class AdBlockPlusExporter(TextExporter):
        """
                Exports for AdBlock Plus and compatible clients
        """
-       def export(self, f):
+       async def export(self, f):
                # Write the format
                f.write("[Adblock Plus]\n")
 
@@ -217,12 +217,12 @@ class AdBlockPlusExporter(TextExporter):
                self.write_header(f, "!")
 
                # Write all domains
-               for domain in self.list.domains:
+               async for domain in self.list.get_domains():
                        f.write("||%s^\n" % domain)
 
 
 class ZoneExporter(TextExporter):
-       def export(self, f, ttl=None, primary=None, zonemaster=None,
+       async def export(self, f, ttl=None, primary=None, zonemaster=None,
                        refresh=None, retry=None, expire=None, nameservers=None):
                # Write the header
                self.write_header(f, ";")
@@ -357,7 +357,7 @@ class ZoneExporter(TextExporter):
                f.write("_info IN TXT \"total-domains=%s\"\n" % len(self.list))
 
                # Write all domains
-               for domain in self.list.domains:
+               async for domain in self.list.get_domains():
                        for prefix in ("", "*."):
                                f.write("%s%s IN %s %s\n" % (prefix, domain, self.type, self.content))
 
@@ -390,7 +390,7 @@ class TarballExporter(Exporter):
        """
                This class contains some helper functions to write data to tarballs
        """
-       def export(self, f):
+       async def export(self, f):
                # Accept a tarball object and in that case, simply don't nest them
                if not isinstance(f, tarfile.TarFile):
                        f = tarfile.open(fileobj=f, mode="w|gz")
@@ -401,7 +401,7 @@ class TarballExporter(Exporter):
                        e = exporter(self.backend, self.list)
 
                        # Export to the tarball
-                       e.export_to_tarball(f, file)
+                       await e.export_to_tarball(f, file)
 
 
 class SquidGuardExporter(TarballExporter):
@@ -419,7 +419,7 @@ class SuricataRulesExporter(TextExporter):
        """
                Export domains as a set of rules for Suricata
        """
-       def export(self, f):
+       async def export(self, f):
                # Write the header
                self.write_header(f)
 
@@ -572,11 +572,11 @@ class SuricataDatasetExporter(TextExporter):
        """
                Exports the domains encoded as base64
        """
-       def export(self, f):
+       async def export(self, f):
                # This file cannot have a header because Suricata will try to base64-decode it, too
 
                # Write all domains
-               for domain in self.list.domains:
+               async for domain in self.list.get_domains():
                        # Convert the domain to bytes
                        domain = domain.encode()
 
@@ -601,16 +601,12 @@ class MultiExporter(abc.ABC):
                This is a base class that can export multiple lists at the same time
        """
 
-       def __init__(self, backend, lists=None):
+       def __init__(self, backend, lists):
                self.backend = backend
-
-               if lists is None:
-                       lists = backend.lists
-
                self.lists = lists
 
        @abc.abstractmethod
-       def __call__(self, *args, **kwargs):
+       async def __call__(self, *args, **kwargs):
                raise NotImplementedError
 
        @property
@@ -625,24 +621,24 @@ class CombinedSquidGuardExporter(MultiExporter):
        """
                This is a special exporter which combines all lists into a single tarball
        """
-       def __call__(self, f):
+       async def __call__(self, f):
                # Create a tar file
                with tarfile.open(fileobj=f, mode="w|gz") as tarball:
                        for list in self.lists:
                                exporter = SquidGuardExporter(self.backend, list)
-                               exporter(tarball)
+                               await exporter(tarball)
 
 
 class CombinedSuricataExporter(MultiExporter):
        """
                This is a special exporter which combines all Suricata rulesets into a tarball
        """
-       def __call__(self, f):
+       async def __call__(self, f):
                # Create a tar file
                with tarfile.open(fileobj=f, mode="w|gz") as tarball:
                        for list in self.lists:
                                exporter = SuricataExporter(self.backend, list)
-                               exporter(tarball)
+                               await exporter(tarball)
 
 
 class DirectoryExporter(MultiExporter):
@@ -667,13 +663,13 @@ class DirectoryExporter(MultiExporter):
                "suricata.tar.gz"      : CombinedSuricataExporter,
        }
 
-       def __init__(self, backend, root, lists=None):
+       def __init__(self, backend, root, lists):
                super().__init__(backend, lists)
 
                # Store the root
                self.root = pathlib.Path(root)
 
-       def __call__(self):
+       async def __call__(self):
                # Ensure the root directory exists
                try:
                        self.root.mkdir()
@@ -685,15 +681,15 @@ class DirectoryExporter(MultiExporter):
                        # For MultiExporters, we will have to export everything at once
                        if issubclass(exporter, MultiExporter):
                                e = exporter(self.backend, self.lists)
-                               self.export(e, name)
+                               await self.export(e, name)
 
                        # For regular exporters, we will have to export each list at a time
                        else:
                                for list in self.lists:
                                        e = exporter(self.backend, list)
-                                       self.export(e, name, list=list)
+                                       await self.export(e, name, list=list)
 
-       def export(self, exporter, name, **kwargs):
+       async def export(self, exporter, name, **kwargs):
                """
                        This function takes an exporter instance and runs it
                """
@@ -709,7 +705,7 @@ class DirectoryExporter(MultiExporter):
                # Create a new temporary file
                with tempfile.NamedTemporaryFile(dir=path.parent) as f:
                        # Export everthing to the file
-                       exporter(f)
+                       await exporter(f)
 
                        # Set the modification time (so that clients won't download again
                        # just because we have done a re-export)
index f9104966809e6f3d5fc3a17f9a6aef332b1701ac..ec6d8d99d4753040655e94c60b2f27eb8418959d 100644 (file)
@@ -50,7 +50,7 @@ class Lists(object):
        def __init__(self, backend):
                self.backend = backend
 
-       def __iter__(self):
+       def __aiter__(self):
                """
                        Returns an iterator over all lists
                """
@@ -70,7 +70,7 @@ class Lists(object):
 
                return self.backend.db.fetch(stmt)
 
-       def get_by_slug(self, slug):
+       async def get_by_slug(self, slug):
                stmt = (
                        sqlmodel
                        .select(
@@ -82,26 +82,26 @@ class Lists(object):
                        )
                )
 
-               return self.backend.db.fetch_one(stmt)
+               return await self.backend.db.fetch_one(stmt)
 
-       def _make_slug(self, name):
+       async def _make_slug(self, name):
                i = 0
 
                while True:
                        slug = util.slugify(name, i)
 
                        # Skip if the list already exists
-                       if self.get_by_slug(slug):
+                       if await self.get_by_slug(slug):
                                i += 1
                                continue
 
                        return slug
 
-       def create(self, name, created_by, license, description=None, priority=None):
+       async def create(self, name, created_by, license, description=None, priority=None):
                """
                        Creates a new list
                """
-               slug = self._make_slug(name)
+               slug = await self._make_slug(name)
 
                # Map priority
                try:
@@ -110,7 +110,7 @@ class Lists(object):
                        raise ValueError("Invalid priority: %s" % priority) from e
 
                # Create a new list
-               return self.backend.db.insert(
+               return await self.backend.db.insert(
                        List,
                        name        = name,
                        slug        = slug,
@@ -120,7 +120,7 @@ class Lists(object):
                        priority    = priority,
                )
 
-       def get_listing_domain(self, name):
+       async def get_listing_domain(self, name):
                """
                        Returns all lists that currently contain the given domain.
                """
@@ -149,7 +149,7 @@ class Lists(object):
                        )
                )
 
-               return self.backend.db.fetch_as_list(stmt)
+               return await self.backend.db.fetch_as_list(stmt)
 
 
 class List(sqlmodel.SQLModel, database.BackendMixin, table=True):
@@ -224,12 +224,13 @@ class List(sqlmodel.SQLModel, database.BackendMixin, table=True):
                                Source.deleted_at == None,
                        )""",
                        "order_by"    : "Source.name, Source.url",
+                       "lazy"        : "selectin",
                 },
        )
 
        # Delete Source!
 
-       def delete_source(self, url, **kwargs):
+       async def delete_source(self, url, **kwargs):
                """
                        Removes a source from the list
                """
@@ -350,8 +351,7 @@ class List(sqlmodel.SQLModel, database.BackendMixin, table=True):
 
                return listed_domains
 
-       @property
-       def domains(self):
+       async def get_domains(self):
                """
                        Returns all domains that are on this list
                """
@@ -365,7 +365,7 @@ class List(sqlmodel.SQLModel, database.BackendMixin, table=True):
                canary_inserted = False
 
                # Walk through all domains and insert the canary
-               for domain in domains:
+               async for domain in domains:
                        # If we have not inserted the canary, yet, we will do
                        # it whenever it alphabetically fits
                        if not canary_inserted and domain > self.canary:
@@ -379,12 +379,12 @@ class List(sqlmodel.SQLModel, database.BackendMixin, table=True):
                if not canary_inserted:
                        yield self.canary
 
-       def add_domain(self, name, added_by, report=None, block=True):
+       async def add_domain(self, name, added_by, report=None, block=True):
                """
                        Adds a new domain to the list
                """
                # Check if the domain is already listed
-               domain = self.get_domain(name)
+               domain = await self.get_domain(name)
                if domain:
                        # Silently ignore if the domain is already listed
                        if domain.block == block:
@@ -398,7 +398,7 @@ class List(sqlmodel.SQLModel, database.BackendMixin, table=True):
                        )
 
                # Add the domain to the database
-               domain = self.backend.db.insert(
+               domain = await self.backend.db.insert(
                        domains.Domain,
                        list       = self,
                        name       = name,
@@ -415,7 +415,7 @@ class List(sqlmodel.SQLModel, database.BackendMixin, table=True):
 
                return domain
 
-       def get_domain(self, name):
+       async def get_domain(self, name):
                """
                        Fetches a domain (not including sources)
                """
@@ -439,9 +439,9 @@ class List(sqlmodel.SQLModel, database.BackendMixin, table=True):
                        )
                )
 
-               return self.backend.db.fetch_one(stmt)
+               return await self.backend.db.fetch_one(stmt)
 
-       def get_sources_by_domain(self, name):
+       async def get_sources_by_domain(self, name):
                """
                        Returns all sources that list the given domain
                """
@@ -472,7 +472,7 @@ class List(sqlmodel.SQLModel, database.BackendMixin, table=True):
                        )
                )
 
-               return self.backend.db.fetch_as_set(stmt)
+               return await self.backend.db.fetch_as_set(stmt)
 
        # Total Domains
        total_domains: int = 0
@@ -482,7 +482,7 @@ class List(sqlmodel.SQLModel, database.BackendMixin, table=True):
 
        # Delete!
 
-       def delete(self, deleted_by):
+       async def delete(self, deleted_by):
                """
                        Deletes the list
                """
@@ -493,11 +493,11 @@ class List(sqlmodel.SQLModel, database.BackendMixin, table=True):
 
        # Update!
 
-       def update(self, **kwargs):
+       async def update(self, **kwargs):
                """
                        Updates the list
                """
-               with self.backend.db.transaction():
+               with await self.backend.db.transaction():
                        updated = False
 
                        # Update all sources
@@ -510,12 +510,12 @@ class List(sqlmodel.SQLModel, database.BackendMixin, table=True):
                                self.updated_at = sqlmodel.func.current_timestamp()
 
                        # Optimize the list
-                       self.optimize(update_stats=False)
+                       await self.optimize(update_stats=False)
 
                        # Update the stats
-                       self.update_stats()
+                       await self.update_stats()
 
-       def update_stats(self):
+       async def update_stats(self):
                stmt = (
                        sqlmodel
                        .select(
@@ -526,10 +526,10 @@ class List(sqlmodel.SQLModel, database.BackendMixin, table=True):
                )
 
                # Store the number of total domains
-               self.total_domains = self.backend.db.fetch_one(stmt)
+               self.total_domains = await self.backend.db.fetch_one(stmt)
 
                # Store the number of subsumed domains
-               self.subsumed_domains = self.backend.db.fetch_one(
+               self.subsumed_domains = await self.backend.db.fetch_one(
                        sqlmodel
                        .select(
                                sqlmodel.func.count(),
@@ -550,7 +550,7 @@ class List(sqlmodel.SQLModel, database.BackendMixin, table=True):
                )
 
                # Store the stats history
-               self.backend.db.insert(
+               await self.backend.db.insert(
                        ListStats,
                        list             = self,
                        total_domains    = self.total_domains,
@@ -559,7 +559,7 @@ class List(sqlmodel.SQLModel, database.BackendMixin, table=True):
 
        # Export!
 
-       def export(self, f, format, **kwargs):
+       async def export(self, f, format, **kwargs):
                """
                        Exports the list
                """
@@ -587,7 +587,7 @@ class List(sqlmodel.SQLModel, database.BackendMixin, table=True):
        # Reports
        reports : typing.List["Report"] = sqlmodel.Relationship(back_populates="list")
 
-       def get_reports(self, open=None, name=None, limit=None):
+       async def get_reports(self, open=None, name=None, limit=None):
                """
                        Fetches the most recent reports
                """
@@ -624,15 +624,15 @@ class List(sqlmodel.SQLModel, database.BackendMixin, table=True):
                if limit:
                        stmt = stmt.limit(limit)
 
-               return self.backend.db.fetch_as_list(stmt)
+               return await self.backend.db.fetch_as_list(stmt)
 
        # Report!
 
-       def report(self, **kwargs):
+       async def report(self, **kwargs):
                """
                        Creates a new report for this list
                """
-               return self.backend.reports.create(list=self, **kwargs)
+               return await self.backend.reports.create(list=self, **kwargs)
 
        # Pending Reports
 
@@ -824,14 +824,14 @@ class List(sqlmodel.SQLModel, database.BackendMixin, table=True):
 
                return self.backend.db.fetch(stmt)
 
-       def optimize(self, update_stats=True):
+       async def optimize(self, update_stats=True):
                """
                        Optimizes this list
                """
                log.info("Optimizing %s..." % self)
 
                # Fetch all domains on this list
-               names = self.backend.db.fetch_as_set(
+               names = await self.backend.db.fetch_as_set(
                        sqlmodel
                        .select(
                                domains.Domain.name
@@ -862,7 +862,7 @@ class List(sqlmodel.SQLModel, database.BackendMixin, table=True):
                log.info(_("Identified %s redunduant domain(s)") % len(redundant_names))
 
                # Reset the status for all domains
-               self.backend.db.execute(
+               await self.backend.db.execute(
                        sqlmodel
                        .update(
                                domains.Domain,
@@ -878,7 +878,7 @@ class List(sqlmodel.SQLModel, database.BackendMixin, table=True):
                )
 
                # De-list the redundant domains
-               self.backend.db.execute(
+               await self.backend.db.execute(
                        sqlmodel
                        .update(
                                domains.Domain,
@@ -897,7 +897,7 @@ class List(sqlmodel.SQLModel, database.BackendMixin, table=True):
 
                # Update all stats afterwards
                if update_stats:
-                       self.update_stats()
+                       await self.update_stats()
 
 
 class ListStats(sqlmodel.SQLModel, table=True):
index 854bd817ed95689870e83bd59c815d4682158ae8..534eccac2ff313b52248ba3b17aab8227fc82a9f 100644 (file)
@@ -35,7 +35,7 @@ class Reports(object):
        def __init__(self, backend):
                self.backend = backend
 
-       def __iter__(self):
+       def __aiter__(self):
                stmt = (
                        sqlmodel
                        .select(
@@ -51,7 +51,7 @@ class Reports(object):
 
                return self.backend.db.fetch(stmt)
 
-       def get_by_id(self, id):
+       async def get_by_id(self, id):
                stmt = (
                        sqlmodel
                        .select(
@@ -62,25 +62,28 @@ class Reports(object):
                        )
                )
 
-               return self.backend.db.fetch_one(stmt)
+               return await self.backend.db.fetch_one(stmt)
 
-       def create(self, **kwargs):
+       async def create(self, **kwargs):
                """
                        Creates a new report
                """
-               report = self.backend.db.insert(
+               report = await self.backend.db.insert(
                        Report, **kwargs,
                )
 
+               # Manifest the object in the database immediately to assign the ID
+               await self.backend.db.flush_and_refresh(report)
+
                # Increment the counter of the list
                report.list.pending_reports += 1
 
                # Send a notification to the reporter
-               report._send_opening_notification()
+               await report._send_opening_notification()
 
                return report
 
-       def notify(self):
+       async def notify(self):
                """
                        Notifies moderators about any pending reports
                """
@@ -102,7 +105,7 @@ class Reports(object):
                lists = {}
 
                # Group reports by list
-               for report in reports:
+               async for report in reports:
                        try:
                                lists[report.list].append(report)
                        except KeyError:
@@ -190,7 +193,13 @@ class Report(sqlmodel.SQLModel, database.BackendMixin, table=True):
        list_id : int = sqlmodel.Field(foreign_key="lists.id", exclude=True)
 
        # List
-       list : "List" = sqlmodel.Relationship(back_populates="reports")
+       list : "List" = sqlmodel.Relationship(
+               back_populates         = "reports",
+               sa_relationship_kwargs = {
+                       "lazy"      : "joined",
+                       "innerjoin" : True,
+               },
+       )
 
        @pydantic.computed_field
        @property
@@ -225,7 +234,7 @@ class Report(sqlmodel.SQLModel, database.BackendMixin, table=True):
 
        # Close!
 
-       def close(self, closed_by=None, accept=True, update_stats=True):
+       async def close(self, closed_by=None, accept=True, update_stats=True):
                """
                        Called when a moderator has made a decision
                """
@@ -246,14 +255,14 @@ class Report(sqlmodel.SQLModel, database.BackendMixin, table=True):
                self.list.pending_reports -= 1
 
                # Send a message to the reporter?
-               self._send_closing_notification()
+               await self._send_closing_notification()
 
                # We are done if the report has not been accepted
                if not self.accepted:
                        return
 
                # Add the domain to the list (the add function will do the rest)
-               self.list.add_domain(
+               await self.list.add_domain(
                        name     = self.name,
                        added_by = self.reported_by,
                        report   = self,
@@ -266,13 +275,13 @@ class Report(sqlmodel.SQLModel, database.BackendMixin, table=True):
                # Update list stats
                if update_stats:
                        # Update stats for all sources that list this domain
-                       for source in self.list.get_sources_by_domain(self.name):
-                               source.update_stats()
+                       for source in await self.list.get_sources_by_domain(self.name):
+                               await source.update_stats()
 
                        # Update the list's stats
-                       self.list.update_stats()
+                       await self.list.update_stats()
 
-       def _send_opening_notification(self):
+       async def _send_opening_notification(self):
                """
                        Sends a notification to the reporter when this report is opened.
                """
@@ -309,7 +318,7 @@ class Report(sqlmodel.SQLModel, database.BackendMixin, table=True):
                        "Subject" : _("Your DBL Report Has Been Received"),
                })
 
-       def _send_closing_notification(self):
+       async def _send_closing_notification(self):
                """
                        Sends a notification to the reporter when this report gets closed.
                """
index dd67147af40ae218077fd6b8e7e9058374b3b0c3..6561782f74971f2622b039d0a100ecf428e85cd4 100644 (file)
@@ -54,7 +54,7 @@ class Sources(object):
        def __init__(self, backend):
                self.backend = backend
 
-       def __iter__(self):
+       def __aiter__(self):
                stmt = (
                        sqlmodel
                        .select(
@@ -70,7 +70,7 @@ class Sources(object):
 
                return self.backend.db.fetch(stmt)
 
-       def get_by_id(self, id):
+       async def get_by_id(self, id):
                stmt = (
                        sqlmodel
                        .select(
@@ -81,13 +81,13 @@ class Sources(object):
                        )
                )
 
-               return self.backend.db.fetch_one(stmt)
+               return await self.backend.db.fetch_one(stmt)
 
-       def create(self, list, name, url, created_by, license):
+       async def create(self, list, name, url, created_by, license):
                """
                        Creates a new source
                """
-               return self.backend.db.insert(
+               return await self.backend.db.insert(
                        Source,
                        list       = list,
                        name       = name,
@@ -159,7 +159,7 @@ class Source(sqlmodel.SQLModel, database.BackendMixin, table=True):
 
        # Delete!
 
-       def delete(self, deleted_by):
+       async def delete(self, deleted_by):
                """
                        Deletes the source
                """
@@ -180,12 +180,12 @@ class Source(sqlmodel.SQLModel, database.BackendMixin, table=True):
                                removed_at=sqlmodel.func.current_timestamp(),
                        )
                )
-               self.backend.db.execute(stmt)
+               await self.backend.db.execute(stmt)
 
                # Log action
                log.info(_("Source '%s' has been deleted from '%s'") % (self, self.list))
 
-       def update(self, force=False):
+       async def update(self, force=False):
                """
                        Updates this source.
                """
@@ -202,7 +202,7 @@ class Source(sqlmodel.SQLModel, database.BackendMixin, table=True):
                        log.debug("Forcing update of %s because it has no data" % self)
                        force = True
 
-               with self.db.transaction():
+               with await self.db.transaction():
                        with self.backend.client() as client:
                                # Compose some request headers
                                headers = self._make_headers(force=force)
@@ -328,7 +328,7 @@ class Source(sqlmodel.SQLModel, database.BackendMixin, table=True):
                                                        return False
 
                                                # Add all domains to the database
-                                               self.add_domains(domains)
+                                               await self.add_domains(domains)
 
                                        # The list has now been updated
                                        self.updated_at = sqlmodel.func.current_timestamp()
@@ -346,10 +346,10 @@ class Source(sqlmodel.SQLModel, database.BackendMixin, table=True):
                                        return False
 
                        # Mark all domains that have not been updated as removed
-                       self.__prune()
+                       await self.__prune()
 
                        # Update the stats
-                       self.update_stats()
+                       await self.update_stats()
 
                # Signal that we have actually fetched new data
                return True
@@ -521,7 +521,7 @@ class Source(sqlmodel.SQLModel, database.BackendMixin, table=True):
                """
                return line
 
-       def add_domains(self, _domains):
+       async def add_domains(self, _domains):
                """
                        Adds or updates a domain.
                """
@@ -553,9 +553,10 @@ class Source(sqlmodel.SQLModel, database.BackendMixin, table=True):
                                        }
                                )
                        )
-                       self.backend.db.execute(stmt)
 
-       def __prune(self):
+                       await self.backend.db.execute(stmt)
+
+       async def __prune(self):
                """
                        Prune any domains that have not been updated.
 
@@ -577,9 +578,9 @@ class Source(sqlmodel.SQLModel, database.BackendMixin, table=True):
                                domains.Domain.removed_at == None,
                        )
                )
-               self.backend.db.execute(stmt)
+               await self.backend.db.execute(stmt)
 
-       def duplicates(self):
+       async def duplicates(self):
                """
                        Finds the number of duplicates against other sources
                """
@@ -617,7 +618,7 @@ class Source(sqlmodel.SQLModel, database.BackendMixin, table=True):
                        )
 
                        # Run the query
-                       sources[source] = self.backend.db.fetch_one(stmt)
+                       sources[source] = await self.backend.db.fetch_one(stmt)
 
                return sources
 
@@ -629,7 +630,7 @@ class Source(sqlmodel.SQLModel, database.BackendMixin, table=True):
 
        false_positives: int = 0
 
-       def update_stats(self):
+       async def update_stats(self):
                """
                        Updates the stats of this source
                """
@@ -648,7 +649,7 @@ class Source(sqlmodel.SQLModel, database.BackendMixin, table=True):
                )
 
                # Store the total number of domains
-               self.total_domains = self.backend.db.fetch_one(stmt)
+               self.total_domains = await self.backend.db.fetch_one(stmt)
 
                stmt = (
                        sqlmodel
@@ -668,7 +669,7 @@ class Source(sqlmodel.SQLModel, database.BackendMixin, table=True):
                )
 
                # Store the total number of dead domains
-               self.dead_domains = self.backend.db.fetch_one(stmt)
+               self.dead_domains = await self.backend.db.fetch_one(stmt)
 
                whitelisted_domains = (
                        sqlmodel
@@ -721,10 +722,10 @@ class Source(sqlmodel.SQLModel, database.BackendMixin, table=True):
                )
 
                # Store the number of false positives
-               self.false_positives = self.backend.db.fetch_one(stmt)
+               self.false_positives = await self.backend.db.fetch_one(stmt)
 
                # Store the stats history
-               self.backend.db.insert(
+               await self.backend.db.insert(
                        SourceStats,
                        source          = self,
                        total_domains   = self.total_domains,
index aa4560f6968fa16e27c4f01b700c6f1940b881be..6e2b0aa209b3d89a7d951240b1bcea0741edb618 100644 (file)
@@ -19,6 +19,7 @@
 #                                                                             #
 ###############################################################################
 
+import asyncio
 import argparse
 import babel.dates
 import babel.numbers
@@ -222,7 +223,7 @@ class CLI(object):
 
                return args
 
-       def run(self):
+       async def run(self):
                # Parse the command line
                args = self.parse_cli()
 
@@ -239,18 +240,18 @@ class CLI(object):
                )
 
                # Call the handler function
-               with backend.db:
-                       ret = args.func(backend, args)
+               async with backend.db:
+                       ret = await args.func(backend, args)
 
                # Exit with the returned error code
                if ret:
-                       sys.exit(ret)
+                       return ret
 
                # Otherwise just exit
-               sys.exit(0)
+               return 0
 
        @staticmethod
-       def terminate(message, code=2):
+       async def terminate(message, code=2):
                """
                        Convenience function to terminate the program gracefully with a message
                """
@@ -261,20 +262,20 @@ class CLI(object):
                # Terminate with the given code
                raise SystemExit(code)
 
-       def __get_list(self, backend, slug):
+       async def __get_list(self, backend, slug):
                """
                        Fetches a list or terminates the program if we could not find the list.
                """
                # Fetch the list
-               list = backend.lists.get_by_slug(slug)
+               list = await backend.lists.get_by_slug(slug)
 
                # Terminate because we could not find the list
                if not list:
-                       self.terminate("Could not find list '%s'" % slug)
+                       await self.terminate("Could not find list '%s'" % slug)
 
                return list
 
-       def __list(self, backend, args):
+       async def __list(self, backend, args):
                table = rich.table.Table(title=_("Lists"))
                table.add_column(_("Name"))
                table.add_column(_("Slug"))
@@ -283,7 +284,7 @@ class CLI(object):
                table.add_column(_("Listed Unique Domains"), justify="right")
 
                # Show all lists
-               for list in backend.lists:
+               async for list in backend.lists:
                        table.add_row(
                                list.name,
                                list.slug,
@@ -295,11 +296,11 @@ class CLI(object):
                # Print the table
                self.console.print(table)
 
-       def __create(self, backend, args):
+       async def __create(self, backend, args):
                """
                        Creates a new list
                """
-               backend.lists.create(
+               await backend.lists.create(
                        name        = args.name,
                        created_by  = args.created_by,
                        license     = args.license,
@@ -307,22 +308,22 @@ class CLI(object):
                        priority    = args.priority,
                )
 
-       def __delete(self, backend, args):
+       async def __delete(self, backend, args):
                """
                        Deletes a list
                """
-               list = self.__get_list(backend, args.list)
+               list = await self.__get_list(backend, args.list)
 
                # Delete!
-               list.delete(
+               await list.delete(
                        deleted_by = args.deleted_by,
                )
 
-       def __show(self, backend, args):
+       async def __show(self, backend, args):
                """
                        Shows information about a list
                """
-               list = self.__get_list(backend, args.list)
+               list = await self.__get_list(backend, args.list)
 
                table = rich.table.Table(title=list.name)
                table.add_column(_("Property"))
@@ -365,50 +366,51 @@ class CLI(object):
                # Print the sources
                self.console.print(table)
 
-       def __update(self, backend, args):
+       async def __update(self, backend, args):
                """
                        Updates a single list
                """
                # Fetch the list
-               list = self.__get_list(backend, args.list)
+               list = await self.__get_list(backend, args.list)
 
                # Update!
-               list.update(force=args.force)
+               await list.update(force=args.force)
 
-       def __update_all(self, backend, args):
+       async def __update_all(self, backend, args):
                """
                        Updates all lists
                """
-               for list in backend.lists:
-                       list.update(force=args.force)
+               async for list in backend.lists:
+                       await list.update(force=args.force)
 
-       def __export(self, backend, args):
+       async def __export(self, backend, args):
                """
                        Exports a list
                """
                # Fetch the list
-               list = self.__get_list(backend, args.list)
+               list = await self.__get_list(backend, args.list)
 
                # Export!
-               list.export(args.output, format=args.format)
+               await list.export(args.output, format=args.format)
 
-       def __export_all(self, backend, args):
+       async def __export_all(self, backend, args):
                """
                        Exports all lists
                """
                # Launch the DirectoryExporter
-               exporter = dbl.exporters.DirectoryExporter(backend, root=args.directory)
-               exporter()
+               exporter = dbl.exporters.DirectoryExporter(backend,
+                       root=args.directory, lists=[l async for l in backend.lists])
+               await exporter()
 
-       def __add_source(self, backend, args):
+       async def __add_source(self, backend, args):
                """
                        Adds a new source to a list
                """
                # Fetch the list
-               list = self.__get_list(backend, args.list)
+               list = await self.__get_list(backend, args.list)
 
                # Create the source
-               backend.sources.create(
+               await backend.sources.create(
                        list       = list,
                        name       = args.name,
                        url        = args.url,
@@ -416,25 +418,25 @@ class CLI(object):
                        license    = args.license,
                )
 
-       def __delete_source(self, backend, args):
+       async def __delete_source(self, backend, args):
                """
                        Removes a source from a list
                """
                # Fetch the list
-               list = self.__get_list(backend, args.list)
+               list = await self.__get_list(backend, args.list)
 
                # Remove the source
-               list.delete_source(
+               await list.delete_source(
                        url        = args.url,
                        deleted_by = args.deleted_by,
                )
 
-       def __search(self, backend, args):
+       async def __search(self, backend, args):
                """
                        Searches for a domain name
                """
                # Check if a domain is active
-               active = backend.check(args.domain)
+               active = await backend.check(args.domain)
 
                # If the domain is dead, we show a warning
                if active is False:
@@ -445,7 +447,7 @@ class CLI(object):
                        self.console.print(warning)
 
                # Search!
-               lists = backend.search(args.domain)
+               lists = await backend.search(args.domain)
 
                # Do nothing if nothing was found
                if not lists:
@@ -467,17 +469,17 @@ class CLI(object):
                # Print the table
                self.console.print(table)
 
-       def __analyze(self, backend, args):
+       async def __analyze(self, backend, args):
                """
                        Analyzes a list
                """
                # Fetch the list
-               list = self.__get_list(backend, args.list)
+               list = await self.__get_list(backend, args.list)
 
                # Show duplicates
-               self.__analyze_duplicates(list)
+               await self.__analyze_duplicates(list)
 
-       def __analyze_duplicates(self, list):
+       async def __analyze_duplicates(self, list):
                table = rich.table.Table(title=_("Duplication"))
 
                table.add_column(_("List"))
@@ -489,7 +491,7 @@ class CLI(object):
                # Check duplicates
                for source in list.sources:
                        # Determine all duplicates against other sources
-                       duplicates = source.duplicates()
+                       duplicates = await source.duplicates()
 
                        columns = []
 
@@ -519,22 +521,22 @@ class CLI(object):
                # Print the table
                self.console.print(table)
 
-       def __optimize(self, backend, args):
+       async def __optimize(self, backend, args):
                """
                        Optimizes a list
                """
                # Fetch the list
-               list = self.__get_list(backend, args.list)
+               list = await self.__get_list(backend, args.list)
 
                with dbl.util.Stopwatch(_("Optimizing %s") % list):
-                       list.optimize()
+                       await list.optimize()
 
-       def __history(self, backend, args):
+       async def __history(self, backend, args):
                """
                        Shows the history of a list
                """
                # Fetch the list
-               list = self.__get_list(backend, args.list)
+               list = await self.__get_list(backend, args.list)
 
                # Fetch the history
                history = list.get_history(limit=args.limit)
@@ -549,7 +551,7 @@ class CLI(object):
                table.add_column(_("Domains Removed"))
 
                # Add the history
-               for events in history:
+               async for events in history:
                        table.add_row(
                                babel.dates.format_datetime(events.ts),
                                "\n".join(events.domains_blocked),
@@ -560,67 +562,70 @@ class CLI(object):
                # Print the table
                self.console.print(table)
 
-       def __check_domains(self, backend, args):
+       async def __check_domains(self, backend, args):
                """
                        Runs the checker over all domains
                """
                checker = dbl.checker.Checker(backend)
-               checker.check(*args.domain)
+               await checker.check(*args.domain)
 
        # Authentication
 
-       def __create_api_key(self, backend, args):
+       async def __create_api_key(self, backend, args):
                """
                        Creates a new API key
                """
-               key = backend.auth.create(created_by=args.created_by)
+               key = await backend.auth.create(created_by=args.created_by)
 
                # Show the new key
                print(_("Your new API key has been created: %s") % key)
 
-       def __delete_api_key(self, backend, args):
+       async def __delete_api_key(self, backend, args):
                """
                        Creates a new API key
                """
-               key = backend.auth.get_key(args.key)
+               key = await backend.auth.get_key(args.key)
 
                # If we could not find a key, we cannot delete it
                if not key:
-                       self.terminate("Could not find key %s" % args.key)
+                       await self.terminate("Could not find key %s" % args.key)
 
                # Delete the key
-               key.delete(deleted_by=args.deleted_by)
+               await key.delete(deleted_by=args.deleted_by)
 
        # Reports
 
-       def __close_report(self, backend, args):
+       async def __close_report(self, backend, args):
                """
                        Closes a report
                """
-               report = backend.reports.get_by_id(args.id)
+               report = await backend.reports.get_by_id(args.id)
 
                # Fail if we cannot find the report
                if not report:
-                       self.terminate("Could not find report %s" % args.id)
+                       await self.terminate("Could not find report %s" % args.id)
 
                # Close the report
-               report.close(
+               await report.close(
                        closed_by = args.closed_by,
                        accept    = args.accept and not args.reject,
                )
 
        # Notify
 
-       def __notify(self, backend, args):
+       async def __notify(self, backend, args):
                """
                        Notifies moderators about any pending reports
                """
-               backend.reports.notify()
+               await backend.reports.notify()
 
 
 def main():
        c = CLI()
-       c.run()
+       ret = asyncio.run(c.run())
+
+       # Terminate passing the return code
+       sys.exit(ret)
 
 if __name__ == "__main__":
        main()