]> git.ipfire.org Git - pbs.git/commitdiff
uploads: Migrate to SQLModel
authorMichael Tremer <michael.tremer@ipfire.org>
Sun, 15 Jun 2025 15:21:21 +0000 (15:21 +0000)
committerMichael Tremer <michael.tremer@ipfire.org>
Sun, 15 Jun 2025 15:21:21 +0000 (15:21 +0000)
Signed-off-by: Michael Tremer <michael.tremer@ipfire.org>
src/buildservice/uploads.py

index e8efd51c6a5a861fac5255b7e9261bb0320e93b2..19c1bacdd27b28c968119c3c069a5b83f5ef409b 100644 (file)
@@ -1,6 +1,7 @@
 #!/usr/bin/python3
 
 import asyncio
+import datetime
 import hashlib
 import hmac
 import logging
@@ -9,8 +10,8 @@ import shutil
 import sqlalchemy
 import tempfile
 
-from sqlalchemy import Column, ForeignKey
-from sqlalchemy import BigInteger, DateTime, Integer, LargeBinary, Text, UUID
+import sqlmodel
+from uuid import UUID
 
 from . import base
 from . import database
@@ -24,7 +25,8 @@ log = logging.getLogger("pbs.uploads")
 class Uploads(base.Object):
        def __aiter__(self):
                stmt = (
-                       sqlalchemy.select(Upload)
+                       sqlmodel
+                       .select(Upload)
 
                        # Order them by creation time
                        .order_by(Upload.created_at)
@@ -35,7 +37,8 @@ class Uploads(base.Object):
 
        async def get_by_uuid(self, uuid):
                stmt = (
-                       sqlalchemy.select(Upload)
+                       sqlmodel
+                       .select(Upload)
                        .where(
                                Upload.uuid == uuid,
                                Upload.expires_at > sqlalchemy.func.current_timestamp(),
@@ -109,7 +112,7 @@ class Uploads(base.Object):
        async def cleanup(self):
                # Find all expired uploads
                stmt = (
-                       sqlalchemy
+                       sqlmodel
                        .select(Upload)
                        .where(
                                Upload.expires_at <= sqlalchemy.func.current_timestamp(),
@@ -144,7 +147,7 @@ class Uploads(base.Object):
                return h.digest()
 
 
-class Upload(database.Base, database.BackendMixin):
+class Upload(sqlmodel.SQLModel, database.BackendMixin, table=True):
        __tablename__ = "uploads"
 
        def __str__(self):
@@ -152,16 +155,20 @@ class Upload(database.Base, database.BackendMixin):
 
        # ID
 
-       id = Column(Integer, primary_key=True)
+       id : int = sqlmodel.Field(primary_key=True)
 
        # UUID
 
-       uuid = Column(UUID, unique=True, nullable=False,
-               server_default=sqlalchemy.func.gen_random_uuid())
+       uuid: UUID = sqlmodel.Field(
+               unique = True,
+               sa_column_kwargs = {
+                       "server_default" : sqlalchemy.text("gen_random_uuid()"),
+               },
+       )
 
        # Filename
 
-       filename = Column(Text, nullable=False)
+       filename : str
 
        # Extension
 
@@ -173,31 +180,31 @@ class Upload(database.Base, database.BackendMixin):
 
        # Path
 
-       path = Column(Text, nullable=False)
+       path : str
 
        # Size
 
-       size = Column(BigInteger, nullable=False)
+       size : int # Ensure this is BIGINT
 
        # Digest
 
-       digest_blake2b512 = Column(LargeBinary, nullable=False)
+       digest_blake2b512 : bytes
 
        # Builder ID
 
-       builder_id = Column(Integer, ForeignKey("builders.id"))
+       builder_id : int = sqlmodel.Field(foreign_key="builders.id")
 
        # Builder
 
-       builder = sqlalchemy.orm.relationship("Builder", foreign_keys=[builder_id], lazy="joined")
+       builder : "Builder" = sqlmodel.Relationship()
 
        # User ID
 
-       user_id = Column(Integer, ForeignKey("users.id"))
+       user_id : int = sqlmodel.Field(foreign_key="users.id")
 
        # User
 
-       user = sqlalchemy.orm.relationship("User", foreign_keys=[user_id], lazy="joined")
+       user : "User" = sqlmodel.Relationship()
 
        # Has Perms?
 
@@ -225,13 +232,15 @@ class Upload(database.Base, database.BackendMixin):
 
        # Created At
 
-       created_at = Column(DateTime(timezone=False), nullable=False,
-               server_default=sqlalchemy.func.current_timestamp())
+       created_at : datetime.datetime = sqlmodel.Field(
+               sa_column_kwargs = {"server_default" : sqlalchemy.text("CURRENT_TIMESTAMP")}
+       )
 
        # Expires At
 
-       expires_at = Column(DateTime(timezone=False), nullable=False,
-               server_default=sqlalchemy.text("CURRENT_TIMESTAMP + INTERVAL '24 hours'"))
+       expires_at : datetime.datetime = sqlmodel.Field(
+               sa_column_kwargs = {"server_default" : sqlalchemy.text("CURRENT_TIMESTAMP + INTERVAL '24 hours'")}
+       )
 
        # Has Payload?