# #
###############################################################################
+import asyncio
import configparser
import functools
import httpx
# Put everything back together again
return "%s[.]%s" % (name, tld)
+ # Functions to run something in background
+
+ # A list of any background tasks
+ __tasks = set()
+
+ def run_task(self, callback, *args, **kwargs):
+ """
+ Runs the given coroutine in the background
+ """
+ # Create a new task
+ task = asyncio.create_task(callback(*args, **kwargs))
+
+ # Keep a reference to the task and remove it when the task has finished
+ self.__tasks.add(task)
+ task.add_done_callback(self.__tasks.discard)
+
+ return task
+
@functools.cached_property
def auth(self):
return auth.Auth(self)
return conn
- def search(self, name):
+ async def search(self, name):
"""
Searches for a domain
"""
res = {}
# Group all matches by list
- for domain in self.db.fetch(stmt):
+ async for domain in self.db.fetch(stmt):
try:
res[domain.list].append(domain.name)
except KeyError:
domain = matches.pop()
# Fetch the history of the longest match
- for history in list.get_domain_history(domain, limit=1):
+ async for history in list.get_domain_history(domain, limit=1):
res[list] = history
break
return res
- def check(self, domain):
+ async def check(self, domain):
"""
Returns the status of a domain
"""
)
)
- return not self.db.fetch_one(stmt)
+ return not await self.db.fetch_one(stmt)
# Search
@app.get("/search")
-def search(q: str):
+async def search(q: str):
"""
Performs a simple search
"""
raise fastapi.HTTPException(400, "Not a valid FQDN: %s" % q)
# Perform the search
- results = backend.search(q)
+ results = await backend.search(q)
# Format the result
for list, event in results.items():
tags=["Lists"],
)
-def get_list_from_path(list: str = fastapi.Path(...)) -> lists.List:
+async def get_list_from_path(list: str = fastapi.Path(...)) -> lists.List:
"""
Fetches a list by its slug
"""
- return backend.lists.get_by_slug(list)
+ return await backend.lists.get_by_slug(list)
@router.get("")
-def get_lists() -> typing.List[lists.List]:
- return [l for l in backend.lists]
+async def get_lists() -> typing.List[lists.List]:
+ return [l async for l in backend.lists]
@router.get("/{list}")
-def get_list(list = fastapi.Depends(get_list_from_path)) -> lists.List:
+async def get_list(list = fastapi.Depends(get_list_from_path)) -> lists.List:
if not list:
raise fastapi.HTTPException(404, "Could not find list")
return list
@router.get("/{list}/history")
-def get_list_history(
+async def get_list_history(
list = fastapi.Depends(get_list_from_path),
before: datetime.datetime = None,
limit: int = 10,
"domains_blocked" : e.domains_blocked,
"domains_allowed" : e.domains_allowed,
"domains_removed" : e.domains_removed,
- } for e in list.get_history(before=before, limit=limit)
+ } async for e in list.get_history(before=before, limit=limit)
]
@router.get("/{list}/sources")
-def get_list_sources(list = fastapi.Depends(get_list_from_path)) -> typing.List[sources.Source]:
+async def get_list_sources(list = fastapi.Depends(get_list_from_path)) -> typing.List[sources.Source]:
return list.sources
@router.get("/{list}/reports")
-def get_list_reports(
+async def get_list_reports(
list = fastapi.Depends(get_list_from_path),
open: bool | None = None,
name: str | None = None,
limit: int | None = None
) -> typing.List[reports.Report]:
- return list.get_reports(open=open, name=name, limit=limit)
+ return await list.get_reports(open=open, name=name, limit=limit)
@router.get("/{list}/domains/{name}")
-def get_list_domains(
+async def get_list_domains(
name: str, list = fastapi.Depends(get_list_from_path),
) -> typing.List[domains.DomainEvent]:
# Fetch the domain history
- return list.get_domain_history(name)
+ return [e async for e in list.get_domain_history(name)]
class CreateReport(pydantic.BaseModel):
@router.post("/{list}/reports")
-def list_report(
+async def list_report(
report: CreateReport,
auth = fastapi.Depends(require_api_key),
list = fastapi.Depends(get_list_from_path),
) -> reports.Report:
- # Create a new report
- with backend.db:
- return list.report(
- name = report.name,
- reported_by = report.reported_by,
- comment = report.comment,
- block = report.block,
- )
+ return await list.report(
+ name = report.name,
+ reported_by = report.reported_by,
+ comment = report.comment,
+ block = report.block,
+ )
# Include our endpoints
app.include_router(router)
backend = request.app.state.backend
# Acquire a new database session and process the request
- with backend.db.session() as session:
+ async with await backend.db.session() as session:
return await call_next(request)
tags=["Reports"],
)
-def get_report_from_path(id: uuid.UUID = fastapi.Path(...)) -> reports.Report:
+async def get_report_from_path(id: uuid.UUID = fastapi.Path(...)) -> reports.Report:
"""
Fetches a report by its ID
"""
# Fetch the report
- report = backend.reports.get_by_id(id)
+ report = await backend.reports.get_by_id(id)
# Raise 404 if we could not find the report
if not report:
return report
@router.get("")
-def get_reports() -> typing.List[reports.Report]:
- return [l for l in backend.reports]
+async def get_reports() -> typing.List[reports.Report]:
+ return [l async for l in backend.reports]
@router.get("/{id}")
-def get_report(report = fastapi.Depends(get_report_from_path)) -> reports.Report:
+async def get_report(report = fastapi.Depends(get_report_from_path)) -> reports.Report:
return report
class CloseReport(pydantic.BaseModel):
accept: bool = True
@router.post("/{id}/close")
-def close_report(
+async def close_report(
data: CloseReport,
report: reports.Report = fastapi.Depends(get_report_from_path),
) -> fastapi.Response:
- # Close the report
- with backend.db:
- report.close(
- closed_by = data.closed_by,
- accept = data.accept,
- )
+ await report.close(
+ closed_by = data.closed_by,
+ accept = data.accept,
+ )
# Send 204
return fastapi.Response(status_code=fastapi.status.HTTP_204_NO_CONTENT)
def __init__(self, backend):
self.backend = backend
- def __call__(self, key):
+ async def __call__(self, key):
"""
The main authentication function which takes an API key
and checks it against the database.
"""
- key = self.get_key(key)
+ key = await self.get_key(key)
# If we have found a key, we have successfully authenticated
if key:
return False
- def get_key(self, key):
+ async def get_key(self, key):
"""
Fetches a specific key
"""
keys = self.get_keys(prefix)
# Check all keys
- for key in keys:
+ async for key in keys:
# Return True if the key matches the secret
if key.check(secret):
return key
)
)
- return self.backend.db.fetch_as_list(stmt)
+ return self.backend.db.fetch(stmt)
- def create(self, created_by):
+ async def create(self, created_by):
"""
Creates a new API key
"""
secret = secrets.token_urlsafe(32)
# Insert the new token into the database
- key = self.backend.db.insert(
+ key = await self.backend.db.insert(
APIKey,
prefix = prefix,
secret = secret,
# Deleted By
deleted_by : str | None
- def check(self, secret):
+ async def check(self, secret):
"""
Checks if the provided secret matches
"""
return secrets.compare_digest(self.secret, secret)
- def delete(self, deleted_by):
+ async def delete(self, deleted_by):
"""
Deletes the key
"""
# #
###############################################################################
+import asyncio
import collections
import functools
import logging
+import sqlalchemy.ext.asyncio
import sqlalchemy.orm
import sqlmodel
self.engine = self.connect(uri)
# Create a session maker
- self.sessionmaker = sqlalchemy.orm.sessionmaker(
+ self.sessionmaker = sqlalchemy.ext.asyncio.async_sessionmaker(
self.engine,
expire_on_commit = False,
+ class_ = sqlalchemy.ext.asyncio.AsyncSession,
info = {
"backend" : self.backend,
},
)
# Session
- self.__session = None
+ self.__sessions = {}
def connect(self, uri):
"""
Connects to the database
"""
# Create the database engine
- engine = sqlmodel.create_engine(
+ engine = sqlalchemy.ext.asyncio.create_async_engine(
uri,
# Log more if we are running in debug mode
# Use our own logger
logging_name=log.name,
+
+ # Increase the pool size
+ pool_size=128,
)
return engine
- def session(self):
+ async def session(self):
"""
Returns the current database session
"""
- if self.__session is None:
- self.__session = self.sessionmaker()
+ # Fetch the current task
+ task = asyncio.current_task()
+
+ # Check that we have a task
+ assert task, "Could not determine task"
+
+ # Try returning the same session to the same task
+ try:
+ return self.__sessions[task]
+ except KeyError:
+ pass
+
+ # Fetch a new session from the engine
+ session = self.__sessions[task] = self.sessionmaker()
+
+ log.debug("Assigning database session %s to %s" % (session, task))
- return self.__session
+ # When the task finishes, release the connection
+ task.add_done_callback(self.release_session)
- def __enter__(self):
- return self.session()
+ return session
- def __exit__(self, type, exception, traceback):
- session = self.session()
+ def release_session(self, task):
+ """
+ Called when a task that requested a session has finished.
+
+ This method will schedule that the session is being closed.
+ """
+ self.backend.run_task(self.__release_session, task)
+ async def __release_session(self, task):
+ """
+ Called when a task that had a session assigned has finished
+ """
+ exception = None
+
+ # Retrieve the session
+ try:
+ session = self.__sessions[task]
+ except KeyError:
+ return
+
+ # Fetch any exception if the task is done
+ if task.done():
+ exception = task.exception()
+
+ # If there is no exception, we can commit
if exception is None:
- session.commit()
+ await session.commit()
+
+ log.debug("Releasing database session %s of %s" % (session, task))
+
+ # Delete it
+ del self.__sessions[task]
+
+ # Close the session
+ await session.close()
+
+ # Re-raise the exception
+ if exception:
+ raise exception
+
+ async def __aenter__(self):
+ return await self.session()
+
+ async def __aexit__(self, type, exception, traceback):
+ session = await self.session()
- def transaction(self):
+ if exception is None:
+ await session.commit()
+
+ async def transaction(self):
"""
Opens a new transaction
"""
# Fetch our session
- session = self.session()
+ session = await self.session()
# If we are already in a transaction, begin a nested one
if session.in_transaction():
# Otherwise begin the first transaction of the session
return session.begin()
- def execute(self, stmt):
+ async def execute(self, stmt):
"""
Executes a statement and returns a result object
"""
# Fetch our session
- session = self.session()
+ session = await self.session()
# Execute the statement
- return session.execute(stmt)
+ return await session.execute(stmt)
- def insert(self, cls, **kwargs):
+ async def insert(self, cls, **kwargs):
"""
Inserts a new object into the database
"""
# Fetch our session
- session = self.session()
+ session = await self.session()
# Create a new object
object = cls(**kwargs)
# Return the object
return object
- def select(self, stmt, batch_size=128):
+ async def select(self, stmt, batch_size=128):
"""
Selects custom queries
"""
- result = self.execute(stmt)
+ result = await self.execute(stmt)
# Process the result in batches
if batch_size:
yield row(**object)
- def fetch(self, stmt, batch_size=128):
+ async def fetch(self, stmt, batch_size=128):
"""
Fetches objects of the given type
"""
- result = self.execute(stmt)
+ result = await self.execute(stmt)
# Process the result in batches
if batch_size:
for object in result.scalars():
yield object
- def fetch_one(self, stmt):
- result = self.execute(stmt)
+ async def fetch_one(self, stmt):
+ result = await self.execute(stmt)
# Apply unique filtering
result = result.unique()
# Return exactly one object or none, but fail otherwise
return result.scalar_one_or_none()
- def fetch_as_list(self, *args, **kwargs):
+ async def fetch_as_list(self, *args, **kwargs):
"""
Fetches objects and returns them as a list instead of an iterator
"""
objects = self.fetch(*args, **kwargs)
# Return as list
- return [o for o in objects]
+ return [o async for o in objects]
- def fetch_as_set(self, *args, **kwargs):
+ async def fetch_as_set(self, *args, **kwargs):
"""
Fetches objects and returns them as a set instead of an iterator
"""
objects = self.fetch(*args, **kwargs)
# Return as set
- return set([o for o in objects])
+ return set([o async for o in objects])
- def commit(self):
+ async def commit(self):
"""
Manually triggers a database commit
"""
# Fetch our session
- session = self.session()
+ session = await self.session()
# Commit!
- session.commit()
+ await session.commit()
+
+ async def flush(self, *objects):
+ """
+ Manually triggers a flush
+ """
+ # Fetch our session
+ session = await self.session()
+
+ # Flush!
+ await session.flush(objects)
+
+ async def refresh(self, o):
+ """
+ Refreshes the given object
+ """
+ # Fetch our session
+ session = await self.session()
+
+ # Refresh!
+ await session.refresh(o)
+
+ async def flush_and_refresh(self, *objects):
+ """
+ Flushes and refreshes in one go.
+ """
+ # Fetch our session
+ session = await self.session()
+
+ # Flush!
+ await session.flush(objects)
+
+ # Refresh!
+ for o in objects:
+ await session.refresh(o)
class BackendMixin:
def __init__(self, backend):
self.backend = backend
- def get_by_name(self, name):
+ async def get_by_name(self, name):
"""
Fetch all domain objects that match the name
"""
)
)
- return self.backend.db.fetch_as_set(stmt)
+ return await self.backend.db.fetch_as_set(stmt)
class Domain(sqlmodel.SQLModel, database.BackendMixin, table=True):
list_id: int = sqlmodel.Field(foreign_key="lists.id")
# List
- list: "List" = sqlmodel.Relationship()
+ list: "List" = sqlmodel.Relationship(
+ sa_relationship_kwargs={ "lazy" : "selectin" },
+ )
# Name
name: str
source_id: int = sqlmodel.Field(foreign_key="sources.id")
# Source
- source: "Source" = sqlmodel.Relationship(back_populates="domains")
+ source: "Source" = sqlmodel.Relationship(
+ back_populates="domains", sa_relationship_kwargs={ "lazy" : "selectin" },
+ )
# Added At
added_at: datetime.datetime = sqlmodel.Field(
id: int = sqlmodel.Field(primary_key=True, foreign_key="domains.id", exclude=True)
# Domain
- domain: Domain = sqlmodel.Relationship()
+ domain: Domain = sqlmodel.Relationship(
+ sa_relationship_kwargs={ "lazy" : "joined", "innerjoin" : True },
+ )
# List ID
list_id: int = sqlmodel.Field(foreign_key="lists.id", exclude=True)
# List
- list: "List" = sqlmodel.Relationship()
+ list: "List" = sqlmodel.Relationship(
+ sa_relationship_kwargs={ "lazy" : "joined", "innerjoin" : True },
+ )
# List Slug
@pydantic.computed_field
source_id: int | None = sqlmodel.Field(foreign_key="sources.id", exclude=True)
# Source
- source: "Source" = sqlmodel.Relationship()
+ source: "Source" = sqlmodel.Relationship(
+ sa_relationship_kwargs={ "lazy" : "selectin" },
+ )
# Source Name
@pydantic.computed_field
report_id: uuid.UUID | None = sqlmodel.Field(foreign_key="reports.id")
# Report
- report: "Report" = sqlmodel.Relationship()
+ report: "Report" = sqlmodel.Relationship(
+ sa_relationship_kwargs={ "lazy" : "selectin" },
+ )
# Blocks
blocks: int
self.backend = backend
self.list = list
- def __call__(self, f, **kwargs):
+ async def __call__(self, f, **kwargs):
"""
The main entry point to export something with this exporter...
"""
# Export!
with util.Stopwatch(_("Exporting %(name)s using %(exporter)s") % \
{ "name" : self.list.name, "exporter" : self.__class__.__name__ }):
- self.export(f, **kwargs)
+ await self.export(f, **kwargs)
@abc.abstractmethod
- def export(self, f, **kwargs):
+ async def export(self, f, **kwargs):
"""
Runs the export
"""
raise NotImplementedError
- def export_to_tarball(self, tarball, name, **kwargs):
+ async def export_to_tarball(self, tarball, name, **kwargs):
"""
Exports the list to the tarball using the given exporter
"""
f = io.BytesIO()
# Export the data
- self(f, **kwargs)
+ await self(f, **kwargs)
# Expand the filename
name = name % {
"""
Exports nothing, i.e. writes an empty file
"""
- def export(self, *args, **kwargs):
+ async def export(self, *args, **kwargs):
pass
class TextExporter(Exporter):
- def __call__(self, f, **kwargs):
+ async def __call__(self, f, **kwargs):
detach = False
# Convert any file handles to handle plain text
detach = True
# Export!
- super().__call__(f, **kwargs)
+ await super().__call__(f, **kwargs)
# Detach the underlying stream. That way, the wrapper won't close
# the underlying file handle.
"""
Exports the plain domains
"""
- def export(self, f):
+ async def export(self, f):
# Write the header
self.write_header(f)
# Write all domains
- for domain in self.list.domains:
+ async for domain in self.list.get_domains():
f.write("%s\n" % domain)
"""
Exports a file like /etc/hosts
"""
- def export(self, f):
+ async def export(self, f):
# Write the header
self.write_header(f)
# Write all domains
- for domain in self.list.domains:
+ async for domain in self.list.get_domains():
f.write("0.0.0.0 %s\n" % domain)
"""
Exports for AdBlock Plus and compatible clients
"""
- def export(self, f):
+ async def export(self, f):
# Write the format
f.write("[Adblock Plus]\n")
self.write_header(f, "!")
# Write all domains
- for domain in self.list.domains:
+ async for domain in self.list.get_domains():
f.write("||%s^\n" % domain)
class ZoneExporter(TextExporter):
- def export(self, f, ttl=None, primary=None, zonemaster=None,
+ async def export(self, f, ttl=None, primary=None, zonemaster=None,
refresh=None, retry=None, expire=None, nameservers=None):
# Write the header
self.write_header(f, ";")
f.write("_info IN TXT \"total-domains=%s\"\n" % len(self.list))
# Write all domains
- for domain in self.list.domains:
+ async for domain in self.list.get_domains():
for prefix in ("", "*."):
f.write("%s%s IN %s %s\n" % (prefix, domain, self.type, self.content))
"""
This class contains some helper functions to write data to tarballs
"""
- def export(self, f):
+ async def export(self, f):
# Accept a tarball object and in that case, simply don't nest them
if not isinstance(f, tarfile.TarFile):
f = tarfile.open(fileobj=f, mode="w|gz")
e = exporter(self.backend, self.list)
# Export to the tarball
- e.export_to_tarball(f, file)
+ await e.export_to_tarball(f, file)
class SquidGuardExporter(TarballExporter):
"""
Export domains as a set of rules for Suricata
"""
- def export(self, f):
+ async def export(self, f):
# Write the header
self.write_header(f)
"""
Exports the domains encoded as base64
"""
- def export(self, f):
+ async def export(self, f):
# This file cannot have a header because Suricata will try to base64-decode it, too
# Write all domains
- for domain in self.list.domains:
+ async for domain in self.list.get_domains():
# Convert the domain to bytes
domain = domain.encode()
This is a base class that can export multiple lists at the same time
"""
- def __init__(self, backend, lists=None):
+ def __init__(self, backend, lists):
self.backend = backend
-
- if lists is None:
- lists = backend.lists
-
self.lists = lists
@abc.abstractmethod
- def __call__(self, *args, **kwargs):
+ async def __call__(self, *args, **kwargs):
raise NotImplementedError
@property
"""
This is a special exporter which combines all lists into a single tarball
"""
- def __call__(self, f):
+ async def __call__(self, f):
# Create a tar file
with tarfile.open(fileobj=f, mode="w|gz") as tarball:
for list in self.lists:
exporter = SquidGuardExporter(self.backend, list)
- exporter(tarball)
+ await exporter(tarball)
class CombinedSuricataExporter(MultiExporter):
"""
This is a special exporter which combines all Suricata rulesets into a tarball
"""
- def __call__(self, f):
+ async def __call__(self, f):
# Create a tar file
with tarfile.open(fileobj=f, mode="w|gz") as tarball:
for list in self.lists:
exporter = SuricataExporter(self.backend, list)
- exporter(tarball)
+ await exporter(tarball)
class DirectoryExporter(MultiExporter):
"suricata.tar.gz" : CombinedSuricataExporter,
}
- def __init__(self, backend, root, lists=None):
+ def __init__(self, backend, root, lists):
super().__init__(backend, lists)
# Store the root
self.root = pathlib.Path(root)
- def __call__(self):
+ async def __call__(self):
# Ensure the root directory exists
try:
self.root.mkdir()
# For MultiExporters, we will have to export everything at once
if issubclass(exporter, MultiExporter):
e = exporter(self.backend, self.lists)
- self.export(e, name)
+ await self.export(e, name)
# For regular exporters, we will have to export each list at a time
else:
for list in self.lists:
e = exporter(self.backend, list)
- self.export(e, name, list=list)
+ await self.export(e, name, list=list)
- def export(self, exporter, name, **kwargs):
+ async def export(self, exporter, name, **kwargs):
"""
This function takes an exporter instance and runs it
"""
# Create a new temporary file
with tempfile.NamedTemporaryFile(dir=path.parent) as f:
# Export everthing to the file
- exporter(f)
+ await exporter(f)
# Set the modification time (so that clients won't download again
# just because we have done a re-export)
def __init__(self, backend):
self.backend = backend
- def __iter__(self):
+ def __aiter__(self):
"""
Returns an iterator over all lists
"""
return self.backend.db.fetch(stmt)
- def get_by_slug(self, slug):
+ async def get_by_slug(self, slug):
stmt = (
sqlmodel
.select(
)
)
- return self.backend.db.fetch_one(stmt)
+ return await self.backend.db.fetch_one(stmt)
- def _make_slug(self, name):
+ async def _make_slug(self, name):
i = 0
while True:
slug = util.slugify(name, i)
# Skip if the list already exists
- if self.get_by_slug(slug):
+ if await self.get_by_slug(slug):
i += 1
continue
return slug
- def create(self, name, created_by, license, description=None, priority=None):
+ async def create(self, name, created_by, license, description=None, priority=None):
"""
Creates a new list
"""
- slug = self._make_slug(name)
+ slug = await self._make_slug(name)
# Map priority
try:
raise ValueError("Invalid priority: %s" % priority) from e
# Create a new list
- return self.backend.db.insert(
+ return await self.backend.db.insert(
List,
name = name,
slug = slug,
priority = priority,
)
- def get_listing_domain(self, name):
+ async def get_listing_domain(self, name):
"""
Returns all lists that currently contain the given domain.
"""
)
)
- return self.backend.db.fetch_as_list(stmt)
+ return await self.backend.db.fetch_as_list(stmt)
class List(sqlmodel.SQLModel, database.BackendMixin, table=True):
Source.deleted_at == None,
)""",
"order_by" : "Source.name, Source.url",
+ "lazy" : "selectin",
},
)
# Delete Source!
- def delete_source(self, url, **kwargs):
+ async def delete_source(self, url, **kwargs):
"""
Removes a source from the list
"""
return listed_domains
- @property
- def domains(self):
+ async def get_domains(self):
"""
Returns all domains that are on this list
"""
canary_inserted = False
# Walk through all domains and insert the canary
- for domain in domains:
+ async for domain in domains:
# If we have not inserted the canary, yet, we will do
# it whenever it alphabetically fits
if not canary_inserted and domain > self.canary:
if not canary_inserted:
yield self.canary
- def add_domain(self, name, added_by, report=None, block=True):
+ async def add_domain(self, name, added_by, report=None, block=True):
"""
Adds a new domain to the list
"""
# Check if the domain is already listed
- domain = self.get_domain(name)
+ domain = await self.get_domain(name)
if domain:
# Silently ignore if the domain is already listed
if domain.block == block:
)
# Add the domain to the database
- domain = self.backend.db.insert(
+ domain = await self.backend.db.insert(
domains.Domain,
list = self,
name = name,
return domain
- def get_domain(self, name):
+ async def get_domain(self, name):
"""
Fetches a domain (not including sources)
"""
)
)
- return self.backend.db.fetch_one(stmt)
+ return await self.backend.db.fetch_one(stmt)
- def get_sources_by_domain(self, name):
+ async def get_sources_by_domain(self, name):
"""
Returns all sources that list the given domain
"""
)
)
- return self.backend.db.fetch_as_set(stmt)
+ return await self.backend.db.fetch_as_set(stmt)
# Total Domains
total_domains: int = 0
# Delete!
- def delete(self, deleted_by):
+ async def delete(self, deleted_by):
"""
Deletes the list
"""
# Update!
- def update(self, **kwargs):
+ async def update(self, **kwargs):
"""
Updates the list
"""
- with self.backend.db.transaction():
+ with await self.backend.db.transaction():
updated = False
# Update all sources
self.updated_at = sqlmodel.func.current_timestamp()
# Optimize the list
- self.optimize(update_stats=False)
+ await self.optimize(update_stats=False)
# Update the stats
- self.update_stats()
+ await self.update_stats()
- def update_stats(self):
+ async def update_stats(self):
stmt = (
sqlmodel
.select(
)
# Store the number of total domains
- self.total_domains = self.backend.db.fetch_one(stmt)
+ self.total_domains = await self.backend.db.fetch_one(stmt)
# Store the number of subsumed domains
- self.subsumed_domains = self.backend.db.fetch_one(
+ self.subsumed_domains = await self.backend.db.fetch_one(
sqlmodel
.select(
sqlmodel.func.count(),
)
# Store the stats history
- self.backend.db.insert(
+ await self.backend.db.insert(
ListStats,
list = self,
total_domains = self.total_domains,
# Export!
- def export(self, f, format, **kwargs):
+ async def export(self, f, format, **kwargs):
"""
Exports the list
"""
# Reports
reports : typing.List["Report"] = sqlmodel.Relationship(back_populates="list")
- def get_reports(self, open=None, name=None, limit=None):
+ async def get_reports(self, open=None, name=None, limit=None):
"""
Fetches the most recent reports
"""
if limit:
stmt = stmt.limit(limit)
- return self.backend.db.fetch_as_list(stmt)
+ return await self.backend.db.fetch_as_list(stmt)
# Report!
- def report(self, **kwargs):
+ async def report(self, **kwargs):
"""
Creates a new report for this list
"""
- return self.backend.reports.create(list=self, **kwargs)
+ return await self.backend.reports.create(list=self, **kwargs)
# Pending Reports
return self.backend.db.fetch(stmt)
- def optimize(self, update_stats=True):
+ async def optimize(self, update_stats=True):
"""
Optimizes this list
"""
log.info("Optimizing %s..." % self)
# Fetch all domains on this list
- names = self.backend.db.fetch_as_set(
+ names = await self.backend.db.fetch_as_set(
sqlmodel
.select(
domains.Domain.name
log.info(_("Identified %s redunduant domain(s)") % len(redundant_names))
# Reset the status for all domains
- self.backend.db.execute(
+ await self.backend.db.execute(
sqlmodel
.update(
domains.Domain,
)
# De-list the redundant domains
- self.backend.db.execute(
+ await self.backend.db.execute(
sqlmodel
.update(
domains.Domain,
# Update all stats afterwards
if update_stats:
- self.update_stats()
+ await self.update_stats()
class ListStats(sqlmodel.SQLModel, table=True):
def __init__(self, backend):
self.backend = backend
- def __iter__(self):
+ def __aiter__(self):
stmt = (
sqlmodel
.select(
return self.backend.db.fetch(stmt)
- def get_by_id(self, id):
+ async def get_by_id(self, id):
stmt = (
sqlmodel
.select(
)
)
- return self.backend.db.fetch_one(stmt)
+ return await self.backend.db.fetch_one(stmt)
- def create(self, **kwargs):
+ async def create(self, **kwargs):
"""
Creates a new report
"""
- report = self.backend.db.insert(
+ report = await self.backend.db.insert(
Report, **kwargs,
)
+ # Manifest the object in the database immediately to assign the ID
+ await self.backend.db.flush_and_refresh(report)
+
# Increment the counter of the list
report.list.pending_reports += 1
# Send a notification to the reporter
- report._send_opening_notification()
+ await report._send_opening_notification()
return report
- def notify(self):
+ async def notify(self):
"""
Notifies moderators about any pending reports
"""
lists = {}
# Group reports by list
- for report in reports:
+ async for report in reports:
try:
lists[report.list].append(report)
except KeyError:
list_id : int = sqlmodel.Field(foreign_key="lists.id", exclude=True)
# List
- list : "List" = sqlmodel.Relationship(back_populates="reports")
+ list : "List" = sqlmodel.Relationship(
+ back_populates = "reports",
+ sa_relationship_kwargs = {
+ "lazy" : "joined",
+ "innerjoin" : True,
+ },
+ )
@pydantic.computed_field
@property
# Close!
- def close(self, closed_by=None, accept=True, update_stats=True):
+ async def close(self, closed_by=None, accept=True, update_stats=True):
"""
Called when a moderator has made a decision
"""
self.list.pending_reports -= 1
# Send a message to the reporter?
- self._send_closing_notification()
+ await self._send_closing_notification()
# We are done if the report has not been accepted
if not self.accepted:
return
# Add the domain to the list (the add function will do the rest)
- self.list.add_domain(
+ await self.list.add_domain(
name = self.name,
added_by = self.reported_by,
report = self,
# Update list stats
if update_stats:
# Update stats for all sources that list this domain
- for source in self.list.get_sources_by_domain(self.name):
+ for source in await self.list.get_sources_by_domain(self.name):
source.update_stats()
# Update the list's stats
- self.list.update_stats()
+ await self.list.update_stats()
- def _send_opening_notification(self):
+ async def _send_opening_notification(self):
"""
Sends a notification to the reporter when this report is opened.
"""
"Subject" : _("Your DBL Report Has Been Received"),
})
- def _send_closing_notification(self):
+ async def _send_closing_notification(self):
"""
Sends a notification to the reporter when this report gets closed.
"""
def __init__(self, backend):
self.backend = backend
- def __iter__(self):
+ def __aiter__(self):
stmt = (
sqlmodel
.select(
return self.backend.db.fetch(stmt)
- def get_by_id(self, id):
+ async def get_by_id(self, id):
stmt = (
sqlmodel
.select(
)
)
- return self.backend.db.fetch_one(stmt)
+ return await self.backend.db.fetch_one(stmt)
- def create(self, list, name, url, created_by, license):
+ async def create(self, list, name, url, created_by, license):
"""
Creates a new source
"""
- return self.backend.db.insert(
+ return await self.backend.db.insert(
Source,
list = list,
name = name,
# Delete!
- def delete(self, deleted_by):
+ async def delete(self, deleted_by):
"""
Deletes the source
"""
removed_at=sqlmodel.func.current_timestamp(),
)
)
- self.backend.db.execute(stmt)
+ await self.backend.db.execute(stmt)
# Log action
log.info(_("Source '%s' has been deleted from '%s'") % (self, self.list))
- def update(self, force=False):
+ async def update(self, force=False):
"""
Updates this source.
"""
log.debug("Forcing update of %s because it has no data" % self)
force = True
- with self.db.transaction():
+ with await self.db.transaction():
with self.backend.client() as client:
# Compose some request headers
headers = self._make_headers(force=force)
return False
# Add all domains to the database
- self.add_domains(domains)
+ await self.add_domains(domains)
# The list has now been updated
self.updated_at = sqlmodel.func.current_timestamp()
return False
# Mark all domains that have not been updated as removed
- self.__prune()
+ await self.__prune()
# Update the stats
- self.update_stats()
+ await self.update_stats()
# Signal that we have actually fetched new data
return True
"""
return line
- def add_domains(self, _domains):
+ async def add_domains(self, _domains):
"""
Adds or updates a domain.
"""
}
)
)
- self.backend.db.execute(stmt)
- def __prune(self):
+ await self.backend.db.execute(stmt)
+
+ async def __prune(self):
"""
Prune any domains that have not been updated.
domains.Domain.removed_at == None,
)
)
- self.backend.db.execute(stmt)
+ await self.backend.db.execute(stmt)
- def duplicates(self):
+ async def duplicates(self):
"""
Finds the number of duplicates against other sources
"""
)
# Run the query
- sources[source] = self.backend.db.fetch_one(stmt)
+ sources[source] = await self.backend.db.fetch_one(stmt)
return sources
false_positives: int = 0
- def update_stats(self):
+ async def update_stats(self):
"""
Updates the stats of this source
"""
)
# Store the total number of domains
- self.total_domains = self.backend.db.fetch_one(stmt)
+ self.total_domains = await self.backend.db.fetch_one(stmt)
stmt = (
sqlmodel
)
# Store the total number of dead domains
- self.dead_domains = self.backend.db.fetch_one(stmt)
+ self.dead_domains = await self.backend.db.fetch_one(stmt)
whitelisted_domains = (
sqlmodel
)
# Store the number of false positives
- self.false_positives = self.backend.db.fetch_one(stmt)
+ self.false_positives = await self.backend.db.fetch_one(stmt)
# Store the stats history
- self.backend.db.insert(
+ await self.backend.db.insert(
SourceStats,
source = self,
total_domains = self.total_domains,
# #
###############################################################################
+import asyncio
import argparse
import babel.dates
import babel.numbers
return args
- def run(self):
+ async def run(self):
# Parse the command line
args = self.parse_cli()
)
# Call the handler function
- with backend.db:
- ret = args.func(backend, args)
+ async with backend.db:
+ ret = await args.func(backend, args)
# Exit with the returned error code
if ret:
- sys.exit(ret)
+ return ret
# Otherwise just exit
- sys.exit(0)
+ return 0
@staticmethod
- def terminate(message, code=2):
+ async def terminate(message, code=2):
"""
Convenience function to terminate the program gracefully with a message
"""
# Terminate with the given code
raise SystemExit(code)
- def __get_list(self, backend, slug):
+ async def __get_list(self, backend, slug):
"""
Fetches a list or terminates the program if we could not find the list.
"""
# Fetch the list
- list = backend.lists.get_by_slug(slug)
+ list = await backend.lists.get_by_slug(slug)
# Terminate because we could not find the list
if not list:
- self.terminate("Could not find list '%s'" % slug)
+ await self.terminate("Could not find list '%s'" % slug)
return list
- def __list(self, backend, args):
+ async def __list(self, backend, args):
table = rich.table.Table(title=_("Lists"))
table.add_column(_("Name"))
table.add_column(_("Slug"))
table.add_column(_("Listed Unique Domains"), justify="right")
# Show all lists
- for list in backend.lists:
+ async for list in backend.lists:
table.add_row(
list.name,
list.slug,
# Print the table
self.console.print(table)
- def __create(self, backend, args):
+ async def __create(self, backend, args):
"""
Creates a new list
"""
- backend.lists.create(
+ await backend.lists.create(
name = args.name,
created_by = args.created_by,
license = args.license,
priority = args.priority,
)
- def __delete(self, backend, args):
+ async def __delete(self, backend, args):
"""
Deletes a list
"""
- list = self.__get_list(backend, args.list)
+ list = await self.__get_list(backend, args.list)
# Delete!
- list.delete(
+ await list.delete(
deleted_by = args.deleted_by,
)
- def __show(self, backend, args):
+ async def __show(self, backend, args):
"""
Shows information about a list
"""
- list = self.__get_list(backend, args.list)
+ list = await self.__get_list(backend, args.list)
table = rich.table.Table(title=list.name)
table.add_column(_("Property"))
# Print the sources
self.console.print(table)
- def __update(self, backend, args):
+ async def __update(self, backend, args):
"""
Updates a single list
"""
# Fetch the list
- list = self.__get_list(backend, args.list)
+ list = await self.__get_list(backend, args.list)
# Update!
- list.update(force=args.force)
+ await list.update(force=args.force)
- def __update_all(self, backend, args):
+ async def __update_all(self, backend, args):
"""
Updates all lists
"""
- for list in backend.lists:
- list.update(force=args.force)
+ async for list in backend.lists:
+ await list.update(force=args.force)
- def __export(self, backend, args):
+ async def __export(self, backend, args):
"""
Exports a list
"""
# Fetch the list
- list = self.__get_list(backend, args.list)
+ list = await self.__get_list(backend, args.list)
# Export!
- list.export(args.output, format=args.format)
+ await list.export(args.output, format=args.format)
- def __export_all(self, backend, args):
+ async def __export_all(self, backend, args):
"""
Exports all lists
"""
# Launch the DirectoryExporter
- exporter = dbl.exporters.DirectoryExporter(backend, root=args.directory)
- exporter()
+ exporter = dbl.exporters.DirectoryExporter(backend,
+ root=args.directory, lists=[l async for l in backend.lists])
+ await exporter()
- def __add_source(self, backend, args):
+ async def __add_source(self, backend, args):
"""
Adds a new source to a list
"""
# Fetch the list
- list = self.__get_list(backend, args.list)
+ list = await self.__get_list(backend, args.list)
# Create the source
- backend.sources.create(
+ await backend.sources.create(
list = list,
name = args.name,
url = args.url,
license = args.license,
)
- def __delete_source(self, backend, args):
+ async def __delete_source(self, backend, args):
"""
Removes a source from a list
"""
# Fetch the list
- list = self.__get_list(backend, args.list)
+ list = await self.__get_list(backend, args.list)
# Remove the source
- list.delete_source(
+ await list.delete_source(
url = args.url,
deleted_by = args.deleted_by,
)
- def __search(self, backend, args):
+ async def __search(self, backend, args):
"""
Searches for a domain name
"""
# Check if a domain is active
- active = backend.check(args.domain)
+ active = await backend.check(args.domain)
# If the domain is dead, we show a warning
if active is False:
self.console.print(warning)
# Search!
- lists = backend.search(args.domain)
+ lists = await backend.search(args.domain)
# Do nothing if nothing was found
if not lists:
# Print the table
self.console.print(table)
- def __analyze(self, backend, args):
+ async def __analyze(self, backend, args):
"""
Analyzes a list
"""
# Fetch the list
- list = self.__get_list(backend, args.list)
+ list = await self.__get_list(backend, args.list)
# Show duplicates
- self.__analyze_duplicates(list)
+ await self.__analyze_duplicates(list)
- def __analyze_duplicates(self, list):
+ async def __analyze_duplicates(self, list):
table = rich.table.Table(title=_("Duplication"))
table.add_column(_("List"))
# Check duplicates
for source in list.sources:
# Determine all duplicates against other sources
- duplicates = source.duplicates()
+ duplicates = await source.duplicates()
columns = []
# Print the table
self.console.print(table)
- def __optimize(self, backend, args):
+ async def __optimize(self, backend, args):
"""
Optimizes a list
"""
# Fetch the list
- list = self.__get_list(backend, args.list)
+ list = await self.__get_list(backend, args.list)
with dbl.util.Stopwatch(_("Optimizing %s") % list):
- list.optimize()
+ await list.optimize()
- def __history(self, backend, args):
+ async def __history(self, backend, args):
"""
Shows the history of a list
"""
# Fetch the list
- list = self.__get_list(backend, args.list)
+ list = await self.__get_list(backend, args.list)
# Fetch the history
history = list.get_history(limit=args.limit)
table.add_column(_("Domains Removed"))
# Add the history
- for events in history:
+ async for events in history:
table.add_row(
babel.dates.format_datetime(events.ts),
"\n".join(events.domains_blocked),
# Print the table
self.console.print(table)
- def __check_domains(self, backend, args):
+ async def __check_domains(self, backend, args):
"""
Runs the checker over all domains
"""
checker = dbl.checker.Checker(backend)
- checker.check(*args.domain)
+ await checker.check(*args.domain)
# Authentication
- def __create_api_key(self, backend, args):
+ async def __create_api_key(self, backend, args):
"""
Creates a new API key
"""
- key = backend.auth.create(created_by=args.created_by)
+ key = await backend.auth.create(created_by=args.created_by)
# Show the new key
print(_("Your new API key has been created: %s") % key)
- def __delete_api_key(self, backend, args):
+ async def __delete_api_key(self, backend, args):
"""
Creates a new API key
"""
- key = backend.auth.get_key(args.key)
+ key = await backend.auth.get_key(args.key)
# If we could not find a key, we cannot delete it
if not key:
- self.terminate("Could not find key %s" % args.key)
+ await self.terminate("Could not find key %s" % args.key)
# Delete the key
- key.delete(deleted_by=args.deleted_by)
+ await key.delete(deleted_by=args.deleted_by)
# Reports
- def __close_report(self, backend, args):
+ async def __close_report(self, backend, args):
"""
Closes a report
"""
- report = backend.reports.get_by_id(args.id)
+ report = await backend.reports.get_by_id(args.id)
# Fail if we cannot find the report
if not report:
- self.terminate("Could not find report %s" % args.id)
+ await self.terminate("Could not find report %s" % args.id)
# Close the report
- report.close(
+ await report.close(
closed_by = args.closed_by,
accept = args.accept and not args.reject,
)
# Notify
- def __notify(self, backend, args):
+ async def __notify(self, backend, args):
"""
Notifies moderators about any pending reports
"""
- backend.reports.notify()
+ await backend.reports.notify()
def main():
c = CLI()
- c.run()
+ ret = asyncio.run(c.run())
+
+ # Terminate passing the return code
+ sys.exit(ret)
if __name__ == "__main__":
main()