git.ipfire.org Git - dbl.git/commitdiff
dnsbl: Add an analyze command to show duplicates in lists
author Michael Tremer <michael.tremer@ipfire.org>
Wed, 10 Dec 2025 15:46:02 +0000 (15:46 +0000)
committer Michael Tremer <michael.tremer@ipfire.org>
Wed, 10 Dec 2025 15:46:02 +0000 (15:46 +0000)
Signed-off-by: Michael Tremer <michael.tremer@ipfire.org>
src/dnsbl/sources.py
src/scripts/dnsbl.in

index 88c99fcdf825968db61c31d9a7d360c44b6ac5d1..e467101f5baa508fb95f4fcd19807409e527634e 100644 (file)
@@ -93,6 +93,13 @@ class Source(sqlmodel.SQLModel, database.BackendMixin, table=True):
 
                return self.backend.db.fetch_one(stmt)
 
+       def __hash__(self):
+               # The hash is derived from the ID, so it only exists once an ID has been assigned
+               if self.id is not None:
+                       return hash(self.id)
+
+               raise TypeError("Cannot hash Source objects before they are persisted and have an ID")
+
        # ID
        id : int = sqlmodel.Field(primary_key=True)
 
@@ -348,6 +355,48 @@ class Source(sqlmodel.SQLModel, database.BackendMixin, table=True):
                )
                self.backend.db.execute(stmt)
 
+       def duplicates(self):
+               """
+                       Counts the domains that this source shares with each of the
+                       other sources of the same list.
+
+                       Only domains that have not been removed on either side are
+                       taken into account.
+
+                       Returns a dictionary which maps each of the other sources to
+                       the result of its count query.
+               """
+               # Join the domains table with itself, using one alias for each side
+               ours   = sqlalchemy.orm.aliased(SourceDomain)
+               theirs = sqlalchemy.orm.aliased(SourceDomain)
+
+               # Maps each of the other sources to its count
+               result = {}
+
+               for other in self.list.sources:
+                       # There is no point in comparing the source with itself
+                       if other == self:
+                               continue
+
+                       # Count all names that are listed by both sources
+                       query = sqlmodel.select(sqlmodel.func.count())
+                       query = query.select_from(ours)
+                       query = query.join(theirs, theirs.name == ours.name)
+
+                       # Pick the two sources and ignore any removed domains
+                       # NOTE: "== None" is intentional as SQLAlchemy renders it as "IS NULL"
+                       query = query.where(
+                               ours.source == self,
+                               theirs.source == other,
+                               ours.removed_at == None,
+                               theirs.removed_at == None,
+                       )
+
+                       # Run the query
+                       result[other] = self.backend.db.fetch_one(query)
+
+               return result
+
 
 class SourceDomain(sqlmodel.SQLModel, database.BackendMixin, table=True):
        __tablename__ = "source_domains"
index 6a90e0f19255e4fcf2c82e7551c2d9c7a8bf255d..d87cda70a16049367077082e1b139c5289834324 100644 (file)
@@ -144,6 +144,11 @@ class CLI(object):
                search.add_argument("domain", help=_("The domain name"))
                search.set_defaults(func=self.__search)
 
+               # analyze
+               analyze = subparsers.add_parser("analyze", help=_("Analyzes a list"))
+               analyze.add_argument("list", help=_("The name of the list"))
+               analyze.set_defaults(func=self.__analyze)
+
                # Parse all arguments
                args = parser.parse_args()
 
@@ -415,6 +420,50 @@ class CLI(object):
                # Print the table
                self.console.print(table)
 
+       def __analyze(self, backend, args):
+               """
+                       Handles the "analyze" command
+               """
+               # Look up the list by the slug that was given on the command line
+               dnsbl = backend.lists.get_by_slug(args.list)
+
+               # Show how much the sources of the list overlap
+               self.__analyze_duplicates(dnsbl)
+
+       def __analyze_duplicates(self, list):
+               """
+                       Prints a table which shows how much each source overlaps with the others
+               """
+               table = rich.table.Table(title=_("Duplication"))
+               table.add_column(_("List"))
+
+               # Add one column for each source
+               for source in list.sources:
+                       table.add_column(source.name, justify="right")
+
+               # Add one row for each source
+               for source in list.sources:
+                       # Determine all duplicates against other sources
+                       duplicates = source.duplicates()
+
+                       # Fetch the size only once per row (presumably the number of domains)
+                       total = len(source)
+                       columns = []
+
+                       for other in list.sources:
+                               # There is no value for the source itself
+                               if other not in duplicates:
+                                       columns.append("")
+                               # Empty sources have no duplicates (and we must not divide by zero)
+                               elif not total:
+                                       columns.append("%.2f%%" % 0)
+                               else:
+                                       columns.append("%.2f%%" % (duplicates[other] / total * 100))
+
+                       table.add_row(source.name, *columns)
+
+               self.console.print(table)
+
 
 def main():
        c = CLI()