From: Michael Tremer Date: Fri, 20 Feb 2026 15:40:03 +0000 (+0000) Subject: Make the entire application async X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=d2b7de9db18058f04e6154dcaa2278d919c94cd9;p=dbl.git Make the entire application async This was needed because we seem to have some issues with the API starting the next request or trying to access the database after the session has been closed. Instead we will now have a single database session per task which can be managed easier and will only be closed after the entire task has completed. As another benefit, we are now able to run many requests simultaneously. So far this has not been a big bottleneck, but some operations (like closing a report) can take a moment and would therefore have been blocking other requests. Signed-off-by: Michael Tremer --- diff --git a/src/dbl/__init__.py b/src/dbl/__init__.py index eaf374a..43ff8ea 100644 --- a/src/dbl/__init__.py +++ b/src/dbl/__init__.py @@ -18,6 +18,7 @@ # # ############################################################################### +import asyncio import configparser import functools import httpx @@ -117,6 +118,24 @@ class Backend(object): # Put everything back together again return "%s[.]%s" % (name, tld) + # Functions to run something in background + + # A list of any background tasks + __tasks = set() + + def run_task(self, callback, *args, **kwargs): + """ + Runs the given coroutine in the background + """ + # Create a new task + task = asyncio.create_task(callback(*args, **kwargs)) + + # Keep a reference to the task and remove it when the task has finished + self.__tasks.add(task) + task.add_done_callback(self.__tasks.discard) + + return task + @functools.cached_property def auth(self): return auth.Auth(self) @@ -177,7 +196,7 @@ class Backend(object): return conn - def search(self, name): + async def search(self, name): """ Searches for a domain """ @@ -211,7 +230,7 @@ class Backend(object): res = {} # Group all matches by list - for domain in self.db.fetch(stmt): + async for domain in self.db.fetch(stmt): try: res[domain.list].append(domain.name) except KeyError: @@ -226,13 +245,13 @@ class Backend(object): domain = matches.pop() # Fetch the history of the longest match - for history in list.get_domain_history(domain, limit=1): + async for history in list.get_domain_history(domain, limit=1): res[list] = history break return res - def check(self, domain): + async def check(self, domain): """ Returns the status of a domain """ @@ -246,4 +265,4 @@ class Backend(object): ) ) - return not self.db.fetch_one(stmt) + return not await self.db.fetch_one(stmt) diff --git a/src/dbl/api/__init__.py b/src/dbl/api/__init__.py index 5324803..9720a98 100644 --- a/src/dbl/api/__init__.py +++ b/src/dbl/api/__init__.py @@ -60,7 +60,7 @@ from . import reports # Search @app.get("/search") -def search(q: str): +async def search(q: str): """ Performs a simple search """ @@ -71,7 +71,7 @@ def search(q: str): raise fastapi.HTTPException(400, "Not a valid FQDN: %s" % q) # Perform the search - results = backend.search(q) + results = await backend.search(q) # Format the result for list, event in results.items(): diff --git a/src/dbl/api/lists.py b/src/dbl/api/lists.py index 48550bd..cc2722e 100644 --- a/src/dbl/api/lists.py +++ b/src/dbl/api/lists.py @@ -39,25 +39,25 @@ router = fastapi.APIRouter( tags=["Lists"], ) -def get_list_from_path(list: str = fastapi.Path(...)) -> lists.List: +async def get_list_from_path(list: str = fastapi.Path(...)) -> lists.List: """ Fetches a list by its slug """ - return backend.lists.get_by_slug(list) + return await backend.lists.get_by_slug(list) @router.get("") -def get_lists() -> typing.List[lists.List]: - return [l for l in backend.lists] +async def get_lists() -> typing.List[lists.List]: + return [l async for l in backend.lists] @router.get("/{list}") -def get_list(list = fastapi.Depends(get_list_from_path)) -> lists.List: +async def get_list(list = fastapi.Depends(get_list_from_path)) -> lists.List: if not list: raise fastapi.HTTPException(404, "Could not find list") return list @router.get("/{list}/history") -def get_list_history( +async def get_list_history( list = fastapi.Depends(get_list_from_path), before: datetime.datetime = None, limit: int = 10, @@ -69,28 +69,28 @@ def get_list_history( "domains_blocked" : e.domains_blocked, "domains_allowed" : e.domains_allowed, "domains_removed" : e.domains_removed, - } for e in list.get_history(before=before, limit=limit) + } async for e in list.get_history(before=before, limit=limit) ] @router.get("/{list}/sources") -def get_list_sources(list = fastapi.Depends(get_list_from_path)) -> typing.List[sources.Source]: +async def get_list_sources(list = fastapi.Depends(get_list_from_path)) -> typing.List[sources.Source]: return list.sources @router.get("/{list}/reports") -def get_list_reports( +async def get_list_reports( list = fastapi.Depends(get_list_from_path), open: bool | None = None, name: str | None = None, limit: int | None = None ) -> typing.List[reports.Report]: - return list.get_reports(open=open, name=name, limit=limit) + return await list.get_reports(open=open, name=name, limit=limit) @router.get("/{list}/domains/{name}") -def get_list_domains( +async def get_list_domains( name: str, list = fastapi.Depends(get_list_from_path), ) -> typing.List[domains.DomainEvent]: # Fetch the domain history - return list.get_domain_history(name) + return [e async for e in list.get_domain_history(name)] class CreateReport(pydantic.BaseModel): @@ -108,19 +108,17 @@ class CreateReport(pydantic.BaseModel): @router.post("/{list}/reports") -def list_report( +async def list_report( report: CreateReport, auth = fastapi.Depends(require_api_key), list = fastapi.Depends(get_list_from_path), ) -> reports.Report: - # Create a new report - with backend.db: - return list.report( - name = report.name, - reported_by = report.reported_by, - comment = report.comment, - block = report.block, - ) + return await list.report( + name = report.name, + reported_by = report.reported_by, + comment = report.comment, + block = report.block, + ) # Include our endpoints app.include_router(router) diff --git a/src/dbl/api/middlewares.py b/src/dbl/api/middlewares.py index 6a59f66..36e249f 100644 --- a/src/dbl/api/middlewares.py +++ b/src/dbl/api/middlewares.py @@ -25,5 +25,5 @@ class DatabaseSessionMiddleware(fastapi.applications.BaseHTTPMiddleware): backend = request.app.state.backend # Acquire a new database session and process the request - with backend.db.session() as session: + async with await backend.db.session() as session: return await call_next(request) diff --git a/src/dbl/api/reports.py b/src/dbl/api/reports.py index eb241c9..cb13a11 100644 --- a/src/dbl/api/reports.py +++ b/src/dbl/api/reports.py @@ -35,12 +35,12 @@ router = fastapi.APIRouter( tags=["Reports"], ) -def get_report_from_path(id: uuid.UUID = fastapi.Path(...)) -> reports.Report: +async def get_report_from_path(id: uuid.UUID = fastapi.Path(...)) -> reports.Report: """ Fetches a report by its ID """ # Fetch the report - report = backend.reports.get_by_id(id) + report = await backend.reports.get_by_id(id) # Raise 404 if we could not find the report if not report: @@ -49,11 +49,11 @@ def get_report_from_path(id: uuid.UUID = fastapi.Path(...)) -> reports.Report: return report @router.get("") -def get_reports() -> typing.List[reports.Report]: - return [l for l in backend.reports] +async def get_reports() -> typing.List[reports.Report]: + return [l async for l in backend.reports] @router.get("/{id}") -def get_report(report = fastapi.Depends(get_report_from_path)) -> reports.Report: +async def get_report(report = fastapi.Depends(get_report_from_path)) -> reports.Report: return report class CloseReport(pydantic.BaseModel): @@ -64,16 +64,14 @@ class CloseReport(pydantic.BaseModel): accept: bool = True @router.post("/{id}/close") -def close_report( +async def close_report( data: CloseReport, report: reports.Report = fastapi.Depends(get_report_from_path), ) -> fastapi.Response: - # Close the report - with backend.db: - report.close( - closed_by = data.closed_by, - accept = data.accept, - ) + await report.close( + closed_by = data.closed_by, + accept = data.accept, + ) # Send 204 return fastapi.Response(status_code=fastapi.status.HTTP_204_NO_CONTENT) diff --git a/src/dbl/auth.py b/src/dbl/auth.py index 2c0440c..4159a76 100644 --- a/src/dbl/auth.py +++ b/src/dbl/auth.py @@ -42,12 +42,12 @@ class Auth(object): def __init__(self, backend): self.backend = backend - def __call__(self, key): + async def __call__(self, key): """ The main authentication function which takes an API key and checks it against the database. """ - key = self.get_key(key) + key = await self.get_key(key) # If we have found a key, we have successfully authenticated if key: @@ -56,7 +56,7 @@ class Auth(object): return False - def get_key(self, key): + async def get_key(self, key): """ Fetches a specific key """ @@ -72,7 +72,7 @@ class Auth(object): keys = self.get_keys(prefix) # Check all keys - for key in keys: + async for key in keys: # Return True if the key matches the secret if key.check(secret): return key @@ -92,9 +92,9 @@ class Auth(object): ) ) - return self.backend.db.fetch_as_list(stmt) + return self.backend.db.fetch(stmt) - def create(self, created_by): + async def create(self, created_by): """ Creates a new API key """ @@ -105,7 +105,7 @@ class Auth(object): secret = secrets.token_urlsafe(32) # Insert the new token into the database - key = self.backend.db.insert( + key = await self.backend.db.insert( APIKey, prefix = prefix, secret = secret, @@ -147,13 +147,13 @@ class APIKey(sqlmodel.SQLModel, database.BackendMixin, table=True): # Deleted By deleted_by : str | None - def check(self, secret): + async def check(self, secret): """ Checks if the provided secret matches """ return secrets.compare_digest(self.secret, secret) - def delete(self, deleted_by): + async def delete(self, deleted_by): """ Deletes the key """ diff --git a/src/dbl/database.py b/src/dbl/database.py index edeeb83..e86045f 100644 --- a/src/dbl/database.py +++ b/src/dbl/database.py @@ -18,9 +18,11 @@ # # ############################################################################### +import asyncio import collections import functools import logging +import sqlalchemy.ext.asyncio import sqlalchemy.orm import sqlmodel @@ -35,23 +37,24 @@ class Database(object): self.engine = self.connect(uri) # Create a session maker - self.sessionmaker = sqlalchemy.orm.sessionmaker( + self.sessionmaker = sqlalchemy.ext.asyncio.async_sessionmaker( self.engine, expire_on_commit = False, + class_ = sqlalchemy.ext.asyncio.AsyncSession, info = { "backend" : self.backend, }, ) # Session - self.__session = None + self.__sessions = {} def connect(self, uri): """ Connects to the database """ # Create the database engine - engine = sqlmodel.create_engine( + engine = sqlalchemy.ext.asyncio.create_async_engine( uri, # Log more if we are running in debug mode @@ -59,34 +62,94 @@ class Database(object): # Use our own logger logging_name=log.name, + + # Increase the pool size + pool_size=128, ) return engine - def session(self): + async def session(self): """ Returns the current database session """ - if self.__session is None: - self.__session = self.sessionmaker() + # Fetch the current task + task = asyncio.current_task() + + # Check that we have a task + assert task, "Could not determine task" + + # Try returning the same session to the same task + try: + return self.__sessions[task] + except KeyError: + pass + + # Fetch a new session from the engine + session = self.__sessions[task] = self.sessionmaker() + + log.debug("Assigning database session %s to %s" % (session, task)) - return self.__session + # When the task finishes, release the connection + task.add_done_callback(self.release_session) - def __enter__(self): - return self.session() + return session - def __exit__(self, type, exception, traceback): - session = self.session() + def release_session(self, task): + """ + Called when a task that requested a session has finished. + + This method will schedule that the session is being closed. + """ + self.backend.run_task(self.__release_session, task) + async def __release_session(self, task): + """ + Called when a task that had a session assigned has finished + """ + exception = None + + # Retrieve the session + try: + session = self.__sessions[task] + except KeyError: + return + + # Fetch any exception if the task is done + if task.done(): + exception = task.exception() + + # If there is no exception, we can commit if exception is None: - session.commit() + await session.commit() + + log.debug("Releasing database session %s of %s" % (session, task)) + + # Delete it + del self.__sessions[task] + + # Close the session + await session.close() + + # Re-raise the exception + if exception: + raise exception + + async def __aenter__(self): + return await self.session() + + async def __aexit__(self, type, exception, traceback): + session = await self.session() - def transaction(self): + if exception is None: + await session.commit() + + async def transaction(self): """ Opens a new transaction """ # Fetch our session - session = self.session() + session = await self.session() # If we are already in a transaction, begin a nested one if session.in_transaction(): @@ -95,22 +158,22 @@ class Database(object): # Otherwise begin the first transaction of the session return session.begin() - def execute(self, stmt): + async def execute(self, stmt): """ Executes a statement and returns a result object """ # Fetch our session - session = self.session() + session = await self.session() # Execute the statement - return session.execute(stmt) + return await session.execute(stmt) - def insert(self, cls, **kwargs): + async def insert(self, cls, **kwargs): """ Inserts a new object into the database """ # Fetch our session - session = self.session() + session = await self.session() # Create a new object object = cls(**kwargs) @@ -121,11 +184,11 @@ class Database(object): # Return the object return object - def select(self, stmt, batch_size=128): + async def select(self, stmt, batch_size=128): """ Selects custom queries """ - result = self.execute(stmt) + result = await self.execute(stmt) # Process the result in batches if batch_size: @@ -137,11 +200,11 @@ class Database(object): yield row(**object) - def fetch(self, stmt, batch_size=128): + async def fetch(self, stmt, batch_size=128): """ Fetches objects of the given type """ - result = self.execute(stmt) + result = await self.execute(stmt) # Process the result in batches if batch_size: @@ -154,8 +217,8 @@ class Database(object): for object in result.scalars(): yield object - def fetch_one(self, stmt): - result = self.execute(stmt) + async def fetch_one(self, stmt): + result = await self.execute(stmt) # Apply unique filtering result = result.unique() @@ -163,33 +226,67 @@ class Database(object): # Return exactly one object or none, but fail otherwise return result.scalar_one_or_none() - def fetch_as_list(self, *args, **kwargs): + async def fetch_as_list(self, *args, **kwargs): """ Fetches objects and returns them as a list instead of an iterator """ objects = self.fetch(*args, **kwargs) # Return as list - return [o for o in objects] + return [o async for o in objects] - def fetch_as_set(self, *args, **kwargs): + async def fetch_as_set(self, *args, **kwargs): """ Fetches objects and returns them as a set instead of an iterator """ objects = self.fetch(*args, **kwargs) # Return as set - return set([o for o in objects]) + return set([o async for o in objects]) - def commit(self): + async def commit(self): """ Manually triggers a database commit """ # Fetch our session - session = self.session() + session = await self.session() # Commit! - session.commit() + await session.commit() + + async def flush(self, *objects): + """ + Manually triggers a flush + """ + # Fetch our session + session = await self.session() + + # Flush! + await session.flush(objects) + + async def refresh(self, o): + """ + Refreshes the given object + """ + # Fetch our session + session = await self.session() + + # Refresh! + await session.refresh(o) + + async def flush_and_refresh(self, *objects): + """ + Flushes and refreshes in one go. + """ + # Fetch our session + session = await self.session() + + # Flush! + await session.flush(objects) + + # Refresh! + for o in objects: + await session.refresh(o) class BackendMixin: diff --git a/src/dbl/domains.py b/src/dbl/domains.py index be70add..23f069f 100644 --- a/src/dbl/domains.py +++ b/src/dbl/domains.py @@ -34,7 +34,7 @@ class Domains(object): def __init__(self, backend): self.backend = backend - def get_by_name(self, name): + async def get_by_name(self, name): """ Fetch all domain objects that match the name """ @@ -59,7 +59,7 @@ class Domains(object): ) ) - return self.backend.db.fetch_as_set(stmt) + return await self.backend.db.fetch_as_set(stmt) class Domain(sqlmodel.SQLModel, database.BackendMixin, table=True): @@ -78,7 +78,9 @@ class Domain(sqlmodel.SQLModel, database.BackendMixin, table=True): list_id: int = sqlmodel.Field(foreign_key="lists.id") # List - list: "List" = sqlmodel.Relationship() + list: "List" = sqlmodel.Relationship( + sa_relationship_kwargs={ "lazy" : "selectin" }, + ) # Name name: str @@ -87,7 +89,9 @@ class Domain(sqlmodel.SQLModel, database.BackendMixin, table=True): source_id: int = sqlmodel.Field(foreign_key="sources.id") # Source - source: "Source" = sqlmodel.Relationship(back_populates="domains") + source: "Source" = sqlmodel.Relationship( + back_populates="domains", sa_relationship_kwargs={ "lazy" : "selectin" }, + ) # Added At added_at: datetime.datetime = sqlmodel.Field( @@ -165,13 +169,17 @@ class DomainEvent(sqlmodel.SQLModel, table=True): id: int = sqlmodel.Field(primary_key=True, foreign_key="domains.id", exclude=True) # Domain - domain: Domain = sqlmodel.Relationship() + domain: Domain = sqlmodel.Relationship( + sa_relationship_kwargs={ "lazy" : "joined", "innerjoin" : True }, + ) # List ID list_id: int = sqlmodel.Field(foreign_key="lists.id", exclude=True) # List - list: "List" = sqlmodel.Relationship() + list: "List" = sqlmodel.Relationship( + sa_relationship_kwargs={ "lazy" : "joined", "innerjoin" : True }, + ) # List Slug @pydantic.computed_field @@ -187,7 +195,9 @@ class DomainEvent(sqlmodel.SQLModel, table=True): source_id: int | None = sqlmodel.Field(foreign_key="sources.id", exclude=True) # Source - source: "Source" = sqlmodel.Relationship() + source: "Source" = sqlmodel.Relationship( + sa_relationship_kwargs={ "lazy" : "selectin" }, + ) # Source Name @pydantic.computed_field @@ -212,7 +222,9 @@ class DomainEvent(sqlmodel.SQLModel, table=True): report_id: uuid.UUID | None = sqlmodel.Field(foreign_key="reports.id") # Report - report: "Report" = sqlmodel.Relationship() + report: "Report" = sqlmodel.Relationship( + sa_relationship_kwargs={ "lazy" : "selectin" }, + ) # Blocks blocks: int diff --git a/src/dbl/exporters.py b/src/dbl/exporters.py index ac8149f..9e580b8 100644 --- a/src/dbl/exporters.py +++ b/src/dbl/exporters.py @@ -38,23 +38,23 @@ class Exporter(abc.ABC): self.backend = backend self.list = list - def __call__(self, f, **kwargs): + async def __call__(self, f, **kwargs): """ The main entry point to export something with this exporter... """ # Export! with util.Stopwatch(_("Exporting %(name)s using %(exporter)s") % \ { "name" : self.list.name, "exporter" : self.__class__.__name__ }): - self.export(f, **kwargs) + await self.export(f, **kwargs) @abc.abstractmethod - def export(self, f, **kwargs): + async def export(self, f, **kwargs): """ Runs the export """ raise NotImplementedError - def export_to_tarball(self, tarball, name, **kwargs): + async def export_to_tarball(self, tarball, name, **kwargs): """ Exports the list to the tarball using the given exporter """ @@ -64,7 +64,7 @@ class Exporter(abc.ABC): f = io.BytesIO() # Export the data - self(f, **kwargs) + await self(f, **kwargs) # Expand the filename name = name % { @@ -101,12 +101,12 @@ class NullExporter(Exporter): """ Exports nothing, i.e. writes an empty file """ - def export(self, *args, **kwargs): + async def export(self, *args, **kwargs): pass class TextExporter(Exporter): - def __call__(self, f, **kwargs): + async def __call__(self, f, **kwargs): detach = False # Convert any file handles to handle plain text @@ -115,7 +115,7 @@ class TextExporter(Exporter): detach = True # Export! - super().__call__(f, **kwargs) + await super().__call__(f, **kwargs) # Detach the underlying stream. That way, the wrapper won't close # the underlying file handle. @@ -183,12 +183,12 @@ class DomainsExporter(TextExporter): """ Exports the plain domains """ - def export(self, f): + async def export(self, f): # Write the header self.write_header(f) # Write all domains - for domain in self.list.domains: + async for domain in self.list.get_domains(): f.write("%s\n" % domain) @@ -196,12 +196,12 @@ class HostsExporter(TextExporter): """ Exports a file like /etc/hosts """ - def export(self, f): + async def export(self, f): # Write the header self.write_header(f) # Write all domains - for domain in self.list.domains: + async for domain in self.list.get_domains(): f.write("0.0.0.0 %s\n" % domain) @@ -209,7 +209,7 @@ class AdBlockPlusExporter(TextExporter): """ Exports for AdBlock Plus and compatible clients """ - def export(self, f): + async def export(self, f): # Write the format f.write("[Adblock Plus]\n") @@ -217,12 +217,12 @@ class AdBlockPlusExporter(TextExporter): self.write_header(f, "!") # Write all domains - for domain in self.list.domains: + async for domain in self.list.get_domains(): f.write("||%s^\n" % domain) class ZoneExporter(TextExporter): - def export(self, f, ttl=None, primary=None, zonemaster=None, + async def export(self, f, ttl=None, primary=None, zonemaster=None, refresh=None, retry=None, expire=None, nameservers=None): # Write the header self.write_header(f, ";") @@ -357,7 +357,7 @@ class ZoneExporter(TextExporter): f.write("_info IN TXT \"total-domains=%s\"\n" % len(self.list)) # Write all domains - for domain in self.list.domains: + async for domain in self.list.get_domains(): for prefix in ("", "*."): f.write("%s%s IN %s %s\n" % (prefix, domain, self.type, self.content)) @@ -390,7 +390,7 @@ class TarballExporter(Exporter): """ This class contains some helper functions to write data to tarballs """ - def export(self, f): + async def export(self, f): # Accept a tarball object and in that case, simply don't nest them if not isinstance(f, tarfile.TarFile): f = tarfile.open(fileobj=f, mode="w|gz") @@ -401,7 +401,7 @@ class TarballExporter(Exporter): e = exporter(self.backend, self.list) # Export to the tarball - e.export_to_tarball(f, file) + await e.export_to_tarball(f, file) class SquidGuardExporter(TarballExporter): @@ -419,7 +419,7 @@ class SuricataRulesExporter(TextExporter): """ Export domains as a set of rules for Suricata """ - def export(self, f): + async def export(self, f): # Write the header self.write_header(f) @@ -572,11 +572,11 @@ class SuricataDatasetExporter(TextExporter): """ Exports the domains encoded as base64 """ - def export(self, f): + async def export(self, f): # This file cannot have a header because Suricata will try to base64-decode it, too # Write all domains - for domain in self.list.domains: + async for domain in self.list.get_domains(): # Convert the domain to bytes domain = domain.encode() @@ -601,16 +601,12 @@ class MultiExporter(abc.ABC): This is a base class that can export multiple lists at the same time """ - def __init__(self, backend, lists=None): + def __init__(self, backend, lists): self.backend = backend - - if lists is None: - lists = backend.lists - self.lists = lists @abc.abstractmethod - def __call__(self, *args, **kwargs): + async def __call__(self, *args, **kwargs): raise NotImplementedError @property @@ -625,24 +621,24 @@ class CombinedSquidGuardExporter(MultiExporter): """ This is a special exporter which combines all lists into a single tarball """ - def __call__(self, f): + async def __call__(self, f): # Create a tar file with tarfile.open(fileobj=f, mode="w|gz") as tarball: for list in self.lists: exporter = SquidGuardExporter(self.backend, list) - exporter(tarball) + await exporter(tarball) class CombinedSuricataExporter(MultiExporter): """ This is a special exporter which combines all Suricata rulesets into a tarball """ - def __call__(self, f): + async def __call__(self, f): # Create a tar file with tarfile.open(fileobj=f, mode="w|gz") as tarball: for list in self.lists: exporter = SuricataExporter(self.backend, list) - exporter(tarball) + await exporter(tarball) class DirectoryExporter(MultiExporter): @@ -667,13 +663,13 @@ class DirectoryExporter(MultiExporter): "suricata.tar.gz" : CombinedSuricataExporter, } - def __init__(self, backend, root, lists=None): + def __init__(self, backend, root, lists): super().__init__(backend, lists) # Store the root self.root = pathlib.Path(root) - def __call__(self): + async def __call__(self): # Ensure the root directory exists try: self.root.mkdir() @@ -685,15 +681,15 @@ class DirectoryExporter(MultiExporter): # For MultiExporters, we will have to export everything at once if issubclass(exporter, MultiExporter): e = exporter(self.backend, self.lists) - self.export(e, name) + await self.export(e, name) # For regular exporters, we will have to export each list at a time else: for list in self.lists: e = exporter(self.backend, list) - self.export(e, name, list=list) + await self.export(e, name, list=list) - def export(self, exporter, name, **kwargs): + async def export(self, exporter, name, **kwargs): """ This function takes an exporter instance and runs it """ @@ -709,7 +705,7 @@ class DirectoryExporter(MultiExporter): # Create a new temporary file with tempfile.NamedTemporaryFile(dir=path.parent) as f: # Export everthing to the file - exporter(f) + await exporter(f) # Set the modification time (so that clients won't download again # just because we have done a re-export) diff --git a/src/dbl/lists.py b/src/dbl/lists.py index f910496..ec6d8d9 100644 --- a/src/dbl/lists.py +++ b/src/dbl/lists.py @@ -50,7 +50,7 @@ class Lists(object): def __init__(self, backend): self.backend = backend - def __iter__(self): + def __aiter__(self): """ Returns an iterator over all lists """ @@ -70,7 +70,7 @@ class Lists(object): return self.backend.db.fetch(stmt) - def get_by_slug(self, slug): + async def get_by_slug(self, slug): stmt = ( sqlmodel .select( @@ -82,26 +82,26 @@ class Lists(object): ) ) - return self.backend.db.fetch_one(stmt) + return await self.backend.db.fetch_one(stmt) - def _make_slug(self, name): + async def _make_slug(self, name): i = 0 while True: slug = util.slugify(name, i) # Skip if the list already exists - if self.get_by_slug(slug): + if await self.get_by_slug(slug): i += 1 continue return slug - def create(self, name, created_by, license, description=None, priority=None): + async def create(self, name, created_by, license, description=None, priority=None): """ Creates a new list """ - slug = self._make_slug(name) + slug = await self._make_slug(name) # Map priority try: @@ -110,7 +110,7 @@ class Lists(object): raise ValueError("Invalid priority: %s" % priority) from e # Create a new list - return self.backend.db.insert( + return await self.backend.db.insert( List, name = name, slug = slug, @@ -120,7 +120,7 @@ class Lists(object): priority = priority, ) - def get_listing_domain(self, name): + async def get_listing_domain(self, name): """ Returns all lists that currently contain the given domain. """ @@ -149,7 +149,7 @@ class Lists(object): ) ) - return self.backend.db.fetch_as_list(stmt) + return await self.backend.db.fetch_as_list(stmt) class List(sqlmodel.SQLModel, database.BackendMixin, table=True): @@ -224,12 +224,13 @@ class List(sqlmodel.SQLModel, database.BackendMixin, table=True): Source.deleted_at == None, )""", "order_by" : "Source.name, Source.url", + "lazy" : "selectin", }, ) # Delete Source! - def delete_source(self, url, **kwargs): + async def delete_source(self, url, **kwargs): """ Removes a source from the list """ @@ -350,8 +351,7 @@ class List(sqlmodel.SQLModel, database.BackendMixin, table=True): return listed_domains - @property - def domains(self): + async def get_domains(self): """ Returns all domains that are on this list """ @@ -365,7 +365,7 @@ class List(sqlmodel.SQLModel, database.BackendMixin, table=True): canary_inserted = False # Walk through all domains and insert the canary - for domain in domains: + async for domain in domains: # If we have not inserted the canary, yet, we will do # it whenever it alphabetically fits if not canary_inserted and domain > self.canary: @@ -379,12 +379,12 @@ class List(sqlmodel.SQLModel, database.BackendMixin, table=True): if not canary_inserted: yield self.canary - def add_domain(self, name, added_by, report=None, block=True): + async def add_domain(self, name, added_by, report=None, block=True): """ Adds a new domain to the list """ # Check if the domain is already listed - domain = self.get_domain(name) + domain = await self.get_domain(name) if domain: # Silently ignore if the domain is already listed if domain.block == block: @@ -398,7 +398,7 @@ class List(sqlmodel.SQLModel, database.BackendMixin, table=True): ) # Add the domain to the database - domain = self.backend.db.insert( + domain = await self.backend.db.insert( domains.Domain, list = self, name = name, @@ -415,7 +415,7 @@ class List(sqlmodel.SQLModel, database.BackendMixin, table=True): return domain - def get_domain(self, name): + async def get_domain(self, name): """ Fetches a domain (not including sources) """ @@ -439,9 +439,9 @@ class List(sqlmodel.SQLModel, database.BackendMixin, table=True): ) ) - return self.backend.db.fetch_one(stmt) + return await self.backend.db.fetch_one(stmt) - def get_sources_by_domain(self, name): + async def get_sources_by_domain(self, name): """ Returns all sources that list the given domain """ @@ -472,7 +472,7 @@ class List(sqlmodel.SQLModel, database.BackendMixin, table=True): ) ) - return self.backend.db.fetch_as_set(stmt) + return await self.backend.db.fetch_as_set(stmt) # Total Domains total_domains: int = 0 @@ -482,7 +482,7 @@ class List(sqlmodel.SQLModel, database.BackendMixin, table=True): # Delete! - def delete(self, deleted_by): + async def delete(self, deleted_by): """ Deletes the list """ @@ -493,11 +493,11 @@ class List(sqlmodel.SQLModel, database.BackendMixin, table=True): # Update! - def update(self, **kwargs): + async def update(self, **kwargs): """ Updates the list """ - with self.backend.db.transaction(): + with await self.backend.db.transaction(): updated = False # Update all sources @@ -510,12 +510,12 @@ class List(sqlmodel.SQLModel, database.BackendMixin, table=True): self.updated_at = sqlmodel.func.current_timestamp() # Optimize the list - self.optimize(update_stats=False) + await self.optimize(update_stats=False) # Update the stats - self.update_stats() + await self.update_stats() - def update_stats(self): + async def update_stats(self): stmt = ( sqlmodel .select( @@ -526,10 +526,10 @@ class List(sqlmodel.SQLModel, database.BackendMixin, table=True): ) # Store the number of total domains - self.total_domains = self.backend.db.fetch_one(stmt) + self.total_domains = await self.backend.db.fetch_one(stmt) # Store the number of subsumed domains - self.subsumed_domains = self.backend.db.fetch_one( + self.subsumed_domains = await self.backend.db.fetch_one( sqlmodel .select( sqlmodel.func.count(), @@ -550,7 +550,7 @@ class List(sqlmodel.SQLModel, database.BackendMixin, table=True): ) # Store the stats history - self.backend.db.insert( + await self.backend.db.insert( ListStats, list = self, total_domains = self.total_domains, @@ -559,7 +559,7 @@ class List(sqlmodel.SQLModel, database.BackendMixin, table=True): # Export! - def export(self, f, format, **kwargs): + async def export(self, f, format, **kwargs): """ Exports the list """ @@ -587,7 +587,7 @@ class List(sqlmodel.SQLModel, database.BackendMixin, table=True): # Reports reports : typing.List["Report"] = sqlmodel.Relationship(back_populates="list") - def get_reports(self, open=None, name=None, limit=None): + async def get_reports(self, open=None, name=None, limit=None): """ Fetches the most recent reports """ @@ -624,15 +624,15 @@ class List(sqlmodel.SQLModel, database.BackendMixin, table=True): if limit: stmt = stmt.limit(limit) - return self.backend.db.fetch_as_list(stmt) + return await self.backend.db.fetch_as_list(stmt) # Report! - def report(self, **kwargs): + async def report(self, **kwargs): """ Creates a new report for this list """ - return self.backend.reports.create(list=self, **kwargs) + return await self.backend.reports.create(list=self, **kwargs) # Pending Reports @@ -824,14 +824,14 @@ class List(sqlmodel.SQLModel, database.BackendMixin, table=True): return self.backend.db.fetch(stmt) - def optimize(self, update_stats=True): + async def optimize(self, update_stats=True): """ Optimizes this list """ log.info("Optimizing %s..." % self) # Fetch all domains on this list - names = self.backend.db.fetch_as_set( + names = await self.backend.db.fetch_as_set( sqlmodel .select( domains.Domain.name @@ -862,7 +862,7 @@ class List(sqlmodel.SQLModel, database.BackendMixin, table=True): log.info(_("Identified %s redunduant domain(s)") % len(redundant_names)) # Reset the status for all domains - self.backend.db.execute( + await self.backend.db.execute( sqlmodel .update( domains.Domain, @@ -878,7 +878,7 @@ class List(sqlmodel.SQLModel, database.BackendMixin, table=True): ) # De-list the redundant domains - self.backend.db.execute( + await self.backend.db.execute( sqlmodel .update( domains.Domain, @@ -897,7 +897,7 @@ class List(sqlmodel.SQLModel, database.BackendMixin, table=True): # Update all stats afterwards if update_stats: - self.update_stats() + await self.update_stats() class ListStats(sqlmodel.SQLModel, table=True): diff --git a/src/dbl/reports.py b/src/dbl/reports.py index 854bd81..534ecca 100644 --- a/src/dbl/reports.py +++ b/src/dbl/reports.py @@ -35,7 +35,7 @@ class Reports(object): def __init__(self, backend): self.backend = backend - def __iter__(self): + def __aiter__(self): stmt = ( sqlmodel .select( @@ -51,7 +51,7 @@ class Reports(object): return self.backend.db.fetch(stmt) - def get_by_id(self, id): + async def get_by_id(self, id): stmt = ( sqlmodel .select( @@ -62,25 +62,28 @@ class Reports(object): ) ) - return self.backend.db.fetch_one(stmt) + return await self.backend.db.fetch_one(stmt) - def create(self, **kwargs): + async def create(self, **kwargs): """ Creates a new report """ - report = self.backend.db.insert( + report = await self.backend.db.insert( Report, **kwargs, ) + # Manifest the object in the database immediately to assign the ID + await self.backend.db.flush_and_refresh(report) + # Increment the counter of the list report.list.pending_reports += 1 # Send a notification to the reporter - report._send_opening_notification() + await report._send_opening_notification() return report - def notify(self): + async def notify(self): """ Notifies moderators about any pending reports """ @@ -102,7 +105,7 @@ class Reports(object): lists = {} # Group reports by list - for report in reports: + async for report in reports: try: lists[report.list].append(report) except KeyError: @@ -190,7 +193,13 @@ class Report(sqlmodel.SQLModel, database.BackendMixin, table=True): list_id : int = sqlmodel.Field(foreign_key="lists.id", exclude=True) # List - list : "List" = sqlmodel.Relationship(back_populates="reports") + list : "List" = sqlmodel.Relationship( + back_populates = "reports", + sa_relationship_kwargs = { + "lazy" : "joined", + "innerjoin" : True, + }, + ) @pydantic.computed_field @property @@ -225,7 +234,7 @@ class Report(sqlmodel.SQLModel, database.BackendMixin, table=True): # Close! - def close(self, closed_by=None, accept=True, update_stats=True): + async def close(self, closed_by=None, accept=True, update_stats=True): """ Called when a moderator has made a decision """ @@ -246,14 +255,14 @@ class Report(sqlmodel.SQLModel, database.BackendMixin, table=True): self.list.pending_reports -= 1 # Send a message to the reporter? - self._send_closing_notification() + await self._send_closing_notification() # We are done if the report has not been accepted if not self.accepted: return # Add the domain to the list (the add function will do the rest) - self.list.add_domain( + await self.list.add_domain( name = self.name, added_by = self.reported_by, report = self, @@ -266,13 +275,13 @@ class Report(sqlmodel.SQLModel, database.BackendMixin, table=True): # Update list stats if update_stats: # Update stats for all sources that list this domain - for source in self.list.get_sources_by_domain(self.name): + for source in await self.list.get_sources_by_domain(self.name): source.update_stats() # Update the list's stats - self.list.update_stats() + await self.list.update_stats() - def _send_opening_notification(self): + async def _send_opening_notification(self): """ Sends a notification to the reporter when this report is opened. """ @@ -309,7 +318,7 @@ class Report(sqlmodel.SQLModel, database.BackendMixin, table=True): "Subject" : _("Your DBL Report Has Been Received"), }) - def _send_closing_notification(self): + async def _send_closing_notification(self): """ Sends a notification to the reporter when this report gets closed. """ diff --git a/src/dbl/sources.py b/src/dbl/sources.py index dd67147..6561782 100644 --- a/src/dbl/sources.py +++ b/src/dbl/sources.py @@ -54,7 +54,7 @@ class Sources(object): def __init__(self, backend): self.backend = backend - def __iter__(self): + def __aiter__(self): stmt = ( sqlmodel .select( @@ -70,7 +70,7 @@ class Sources(object): return self.backend.db.fetch(stmt) - def get_by_id(self, id): + async def get_by_id(self, id): stmt = ( sqlmodel .select( @@ -81,13 +81,13 @@ class Sources(object): ) ) - return self.backend.db.fetch_one(stmt) + return await self.backend.db.fetch_one(stmt) - def create(self, list, name, url, created_by, license): + async def create(self, list, name, url, created_by, license): """ Creates a new source """ - return self.backend.db.insert( + return await self.backend.db.insert( Source, list = list, name = name, @@ -159,7 +159,7 @@ class Source(sqlmodel.SQLModel, database.BackendMixin, table=True): # Delete! - def delete(self, deleted_by): + async def delete(self, deleted_by): """ Deletes the source """ @@ -180,12 +180,12 @@ class Source(sqlmodel.SQLModel, database.BackendMixin, table=True): removed_at=sqlmodel.func.current_timestamp(), ) ) - self.backend.db.execute(stmt) + await self.backend.db.execute(stmt) # Log action log.info(_("Source '%s' has been deleted from '%s'") % (self, self.list)) - def update(self, force=False): + async def update(self, force=False): """ Updates this source. """ @@ -202,7 +202,7 @@ class Source(sqlmodel.SQLModel, database.BackendMixin, table=True): log.debug("Forcing update of %s because it has no data" % self) force = True - with self.db.transaction(): + with await self.db.transaction(): with self.backend.client() as client: # Compose some request headers headers = self._make_headers(force=force) @@ -328,7 +328,7 @@ class Source(sqlmodel.SQLModel, database.BackendMixin, table=True): return False # Add all domains to the database - self.add_domains(domains) + await self.add_domains(domains) # The list has now been updated self.updated_at = sqlmodel.func.current_timestamp() @@ -346,10 +346,10 @@ class Source(sqlmodel.SQLModel, database.BackendMixin, table=True): return False # Mark all domains that have not been updated as removed - self.__prune() + await self.__prune() # Update the stats - self.update_stats() + await self.update_stats() # Signal that we have actually fetched new data return True @@ -521,7 +521,7 @@ class Source(sqlmodel.SQLModel, database.BackendMixin, table=True): """ return line - def add_domains(self, _domains): + async def add_domains(self, _domains): """ Adds or updates a domain. """ @@ -553,9 +553,10 @@ class Source(sqlmodel.SQLModel, database.BackendMixin, table=True): } ) ) - self.backend.db.execute(stmt) - def __prune(self): + await self.backend.db.execute(stmt) + + async def __prune(self): """ Prune any domains that have not been updated. @@ -577,9 +578,9 @@ class Source(sqlmodel.SQLModel, database.BackendMixin, table=True): domains.Domain.removed_at == None, ) ) - self.backend.db.execute(stmt) + await self.backend.db.execute(stmt) - def duplicates(self): + async def duplicates(self): """ Finds the number of duplicates against other sources """ @@ -617,7 +618,7 @@ class Source(sqlmodel.SQLModel, database.BackendMixin, table=True): ) # Run the query - sources[source] = self.backend.db.fetch_one(stmt) + sources[source] = await self.backend.db.fetch_one(stmt) return sources @@ -629,7 +630,7 @@ class Source(sqlmodel.SQLModel, database.BackendMixin, table=True): false_positives: int = 0 - def update_stats(self): + async def update_stats(self): """ Updates the stats of this source """ @@ -648,7 +649,7 @@ class Source(sqlmodel.SQLModel, database.BackendMixin, table=True): ) # Store the total number of domains - self.total_domains = self.backend.db.fetch_one(stmt) + self.total_domains = await self.backend.db.fetch_one(stmt) stmt = ( sqlmodel @@ -668,7 +669,7 @@ class Source(sqlmodel.SQLModel, database.BackendMixin, table=True): ) # Store the total number of dead domains - self.dead_domains = self.backend.db.fetch_one(stmt) + self.dead_domains = await self.backend.db.fetch_one(stmt) whitelisted_domains = ( sqlmodel @@ -721,10 +722,10 @@ class Source(sqlmodel.SQLModel, database.BackendMixin, table=True): ) # Store the number of false positives - self.false_positives = self.backend.db.fetch_one(stmt) + self.false_positives = await self.backend.db.fetch_one(stmt) # Store the stats history - self.backend.db.insert( + await self.backend.db.insert( SourceStats, source = self, total_domains = self.total_domains, diff --git a/src/scripts/dbl.in b/src/scripts/dbl.in index aa4560f..6e2b0aa 100644 --- a/src/scripts/dbl.in +++ b/src/scripts/dbl.in @@ -19,6 +19,7 @@ # # ############################################################################### +import asyncio import argparse import babel.dates import babel.numbers @@ -222,7 +223,7 @@ class CLI(object): return args - def run(self): + async def run(self): # Parse the command line args = self.parse_cli() @@ -239,18 +240,18 @@ class CLI(object): ) # Call the handler function - with backend.db: - ret = args.func(backend, args) + async with backend.db: + ret = await args.func(backend, args) # Exit with the returned error code if ret: - sys.exit(ret) + return ret # Otherwise just exit - sys.exit(0) + return 0 @staticmethod - def terminate(message, code=2): + async def terminate(message, code=2): """ Convenience function to terminate the program gracefully with a message """ @@ -261,20 +262,20 @@ class CLI(object): # Terminate with the given code raise SystemExit(code) - def __get_list(self, backend, slug): + async def __get_list(self, backend, slug): """ Fetches a list or terminates the program if we could not find the list. """ # Fetch the list - list = backend.lists.get_by_slug(slug) + list = await backend.lists.get_by_slug(slug) # Terminate because we could not find the list if not list: - self.terminate("Could not find list '%s'" % slug) + await self.terminate("Could not find list '%s'" % slug) return list - def __list(self, backend, args): + async def __list(self, backend, args): table = rich.table.Table(title=_("Lists")) table.add_column(_("Name")) table.add_column(_("Slug")) @@ -283,7 +284,7 @@ class CLI(object): table.add_column(_("Listed Unique Domains"), justify="right") # Show all lists - for list in backend.lists: + async for list in backend.lists: table.add_row( list.name, list.slug, @@ -295,11 +296,11 @@ class CLI(object): # Print the table self.console.print(table) - def __create(self, backend, args): + async def __create(self, backend, args): """ Creates a new list """ - backend.lists.create( + await backend.lists.create( name = args.name, created_by = args.created_by, license = args.license, @@ -307,22 +308,22 @@ class CLI(object): priority = args.priority, ) - def __delete(self, backend, args): + async def __delete(self, backend, args): """ Deletes a list """ - list = self.__get_list(backend, args.list) + list = await self.__get_list(backend, args.list) # Delete! - list.delete( + await list.delete( deleted_by = args.deleted_by, ) - def __show(self, backend, args): + async def __show(self, backend, args): """ Shows information about a list """ - list = self.__get_list(backend, args.list) + list = await self.__get_list(backend, args.list) table = rich.table.Table(title=list.name) table.add_column(_("Property")) @@ -365,50 +366,51 @@ class CLI(object): # Print the sources self.console.print(table) - def __update(self, backend, args): + async def __update(self, backend, args): """ Updates a single list """ # Fetch the list - list = self.__get_list(backend, args.list) + list = await self.__get_list(backend, args.list) # Update! - list.update(force=args.force) + await list.update(force=args.force) - def __update_all(self, backend, args): + async def __update_all(self, backend, args): """ Updates all lists """ - for list in backend.lists: - list.update(force=args.force) + async for list in backend.lists: + await list.update(force=args.force) - def __export(self, backend, args): + async def __export(self, backend, args): """ Exports a list """ # Fetch the list - list = self.__get_list(backend, args.list) + list = await self.__get_list(backend, args.list) # Export! - list.export(args.output, format=args.format) + await list.export(args.output, format=args.format) - def __export_all(self, backend, args): + async def __export_all(self, backend, args): """ Exports all lists """ # Launch the DirectoryExporter - exporter = dbl.exporters.DirectoryExporter(backend, root=args.directory) - exporter() + exporter = dbl.exporters.DirectoryExporter(backend, + root=args.directory, lists=[l async for l in backend.lists]) + await exporter() - def __add_source(self, backend, args): + async def __add_source(self, backend, args): """ Adds a new source to a list """ # Fetch the list - list = self.__get_list(backend, args.list) + list = await self.__get_list(backend, args.list) # Create the source - backend.sources.create( + await backend.sources.create( list = list, name = args.name, url = args.url, @@ -416,25 +418,25 @@ class CLI(object): license = args.license, ) - def __delete_source(self, backend, args): + async def __delete_source(self, backend, args): """ Removes a source from a list """ # Fetch the list - list = self.__get_list(backend, args.list) + list = await self.__get_list(backend, args.list) # Remove the source - list.delete_source( + await list.delete_source( url = args.url, deleted_by = args.deleted_by, ) - def __search(self, backend, args): + async def __search(self, backend, args): """ Searches for a domain name """ # Check if a domain is active - active = backend.check(args.domain) + active = await backend.check(args.domain) # If the domain is dead, we show a warning if active is False: @@ -445,7 +447,7 @@ class CLI(object): self.console.print(warning) # Search! - lists = backend.search(args.domain) + lists = await backend.search(args.domain) # Do nothing if nothing was found if not lists: @@ -467,17 +469,17 @@ class CLI(object): # Print the table self.console.print(table) - def __analyze(self, backend, args): + async def __analyze(self, backend, args): """ Analyzes a list """ # Fetch the list - list = self.__get_list(backend, args.list) + list = await self.__get_list(backend, args.list) # Show duplicates - self.__analyze_duplicates(list) + await self.__analyze_duplicates(list) - def __analyze_duplicates(self, list): + async def __analyze_duplicates(self, list): table = rich.table.Table(title=_("Duplication")) table.add_column(_("List")) @@ -489,7 +491,7 @@ class CLI(object): # Check duplicates for source in list.sources: # Determine all duplicates against other sources - duplicates = source.duplicates() + duplicates = await source.duplicates() columns = [] @@ -519,22 +521,22 @@ class CLI(object): # Print the table self.console.print(table) - def __optimize(self, backend, args): + async def __optimize(self, backend, args): """ Optimizes a list """ # Fetch the list - list = self.__get_list(backend, args.list) + list = await self.__get_list(backend, args.list) with dbl.util.Stopwatch(_("Optimizing %s") % list): - list.optimize() + await list.optimize() - def __history(self, backend, args): + async def __history(self, backend, args): """ Shows the history of a list """ # Fetch the list - list = self.__get_list(backend, args.list) + list = await self.__get_list(backend, args.list) # Fetch the history history = list.get_history(limit=args.limit) @@ -549,7 +551,7 @@ class CLI(object): table.add_column(_("Domains Removed")) # Add the history - for events in history: + async for events in history: table.add_row( babel.dates.format_datetime(events.ts), "\n".join(events.domains_blocked), @@ -560,67 +562,70 @@ class CLI(object): # Print the table self.console.print(table) - def __check_domains(self, backend, args): + async def __check_domains(self, backend, args): """ Runs the checker over all domains """ checker = dbl.checker.Checker(backend) - checker.check(*args.domain) + await checker.check(*args.domain) # Authentication - def __create_api_key(self, backend, args): + async def __create_api_key(self, backend, args): """ Creates a new API key """ - key = backend.auth.create(created_by=args.created_by) + key = await backend.auth.create(created_by=args.created_by) # Show the new key print(_("Your new API key has been created: %s") % key) - def __delete_api_key(self, backend, args): + async def __delete_api_key(self, backend, args): """ Creates a new API key """ - key = backend.auth.get_key(args.key) + key = await backend.auth.get_key(args.key) # If we could not find a key, we cannot delete it if not key: - self.terminate("Could not find key %s" % args.key) + await self.terminate("Could not find key %s" % args.key) # Delete the key - key.delete(deleted_by=args.deleted_by) + await key.delete(deleted_by=args.deleted_by) # Reports - def __close_report(self, backend, args): + async def __close_report(self, backend, args): """ Closes a report """ - report = backend.reports.get_by_id(args.id) + report = await backend.reports.get_by_id(args.id) # Fail if we cannot find the report if not report: - self.terminate("Could not find report %s" % args.id) + await self.terminate("Could not find report %s" % args.id) # Close the report - report.close( + await report.close( closed_by = args.closed_by, accept = args.accept and not args.reject, ) # Notify - def __notify(self, backend, args): + async def __notify(self, backend, args): """ Notifies moderators about any pending reports """ - backend.reports.notify() + await backend.reports.notify() def main(): c = CLI() - c.run() + ret = asyncio.run(c.run()) + + # Terminate passing the return code + sys.exit(ret) if __name__ == "__main__": main()