]> git.ipfire.org Git - people/stevee/pakfire.git/blobdiff - python/pakfire/repository/database.py
Improve the repository code.
[people/stevee/pakfire.git] / python / pakfire / repository / database.py
index 1b3872f02d8f324b788a501e61b58bf57ae68a4a..da145a03bf571e96e694d6b52553b70ff028c2bd 100644 (file)
 #                                                                             #
 ###############################################################################
 
-import logging
 import os
 import random
 import shutil
 import sqlite3
 import time
 
+import logging
+log = logging.getLogger("pakfire")
+
 import pakfire.packages as packages
 
 from pakfire.constants import *
+from pakfire.i18n import _
 
 class Cursor(sqlite3.Cursor):
        def execute(self, *args, **kwargs):
@@ -53,9 +56,12 @@ class Database(object):
        def create(self):
                pass
 
+       def migrate(self):
+               pass
+
        def open(self):
                if self._db is None:
-                       logging.debug("Open database %s" % self.filename)
+                       log.debug("Open database %s" % self.filename)
 
                        dirname = os.path.dirname(self.filename)
                        if not os.path.exists(dirname):
@@ -67,16 +73,22 @@ class Database(object):
                        self._db = sqlite3.connect(self.filename)
                        self._db.row_factory = sqlite3.Row
 
-                       # Create the database if it was not there, yet.
-                       if not database_exists:
+                       # In the case, the database was not existant, it is
+                       # filled with content. In case it has been there
+                       # we call the migrate method to update it if neccessary.
+                       if database_exists:
+                               self.migrate()
+                       else:
                                self.create()
 
        def close(self):
-               self.__del__()
+               if self._db:
+                       self._db.close()
+                       self._db = None
 
        def commit(self):
-               self.open()
-               self._db.commit()
+               if self._db:
+                       self._db.commit()
 
        def cursor(self):
                self.open()
@@ -94,8 +106,15 @@ class DatabaseLocal(Database):
                # Generate filename for package database
                filename = os.path.join(pakfire.path, PACKAGES_DB)
 
+               # Cache format number.
+               self.__format = None
+
                Database.__init__(self, pakfire, filename)
 
+               # Check if we actually can open the database.
+               if not self.format in DATABASE_FORMATS_SUPPORTED:
+                       raise DatabaseFormatError, _("The format of the database is not supported by this version of pakfire.")
+
        def __len__(self):
                count = 0
 
@@ -107,6 +126,23 @@ class DatabaseLocal(Database):
 
                return count
 
+       @property
+       def format(self):
+               if self.__format is None:
+                       c = self.cursor()
+
+                       c.execute("SELECT val FROM settings WHERE key = 'version' LIMIT 1")
+                       for row in c:
+                               try:
+                                       self.__format = int(row["val"])
+                                       break
+                               except ValueError:
+                                       pass
+
+                       c.close()
+
+               return self.__format
+
        def create(self):
                c = self.cursor()
                c.executescript("""
@@ -114,19 +150,25 @@ class DatabaseLocal(Database):
                                key                     TEXT,
                                val                     TEXT
                        );
-                       INSERT INTO settings(key, val) VALUES('version', '0');
+                       INSERT INTO settings(key, val) VALUES('version', '%s');
 
                        CREATE TABLE files(
-                               id                      INTEGER PRIMARY KEY,
+                               id              INTEGER PRIMARY KEY,
                                name            TEXT,
-                               pkg                     INTEGER,
+                               pkg             INTEGER,
                                size            INTEGER,
                                type            INTEGER,
-                               hash1           TEXT
+                               config          INTEGER,
+                               mode            INTEGER,
+                               user            TEXT,
+                               `group`         TEXT,
+                               hash1           TEXT,
+                               mtime           INTEGER,
+                               capabilities    TEXT
                        );
 
                        CREATE TABLE packages(
-                               id                      INTEGER PRIMARY KEY,
+                               id              INTEGER PRIMARY KEY,
                                name            TEXT,
                                epoch           INTEGER,
                                version         TEXT,
@@ -144,6 +186,7 @@ class DatabaseLocal(Database):
                                summary         TEXT,
                                description     TEXT,
                                uuid            TEXT,
+                               vendor          TEXT,
                                build_id        TEXT,
                                build_host      TEXT,
                                build_date      TEXT,
@@ -166,13 +209,49 @@ class DatabaseLocal(Database):
                                dependency      TEXT,
                                scriptlet       TEXT
                        );
-               """)
+               """ % DATABASE_FORMAT)
                # XXX add some indexes here
                self.commit()
                c.close()
 
+       def migrate(self):
+               # If we have already the latest version, there is nothing to do.
+               if self.format == DATABASE_FORMAT:
+                       return
+
+               # Check if database version is supported.
+               if self.format > DATABASE_FORMAT:
+                       raise DatabaseError, _("Cannot use database with version greater than %s.") % DATABASE_FORMAT
+
+               log.info(_("Migrating database from format %(old)s to %(new)s.") % \
+                       { "old" : self.format, "new" : DATABASE_FORMAT })
+
+               # Get a database cursor.
+               c = self.cursor()
+
+               # 1) The vendor column was added.
+               if self.format < 1:
+                       c.execute("ALTER TABLE packages ADD COLUMN vendor TEXT AFTER uuid")
+
+               if self.format < 2:
+                       c.execute("ALTER TABLE files ADD COLUMN `config` INTEGER")
+                       c.execute("ALTER TABLE files ADD COLUMN `mode` INTEGER")
+                       c.execute("ALTER TABLE files ADD COLUMN `user` TEXT")
+                       c.execute("ALTER TABLE files ADD COLUMN `group` TEXT")
+                       c.execute("ALTER TABLE files ADD COLUMN `mtime` INTEGER")
+
+               if self.format < 3:
+                       c.execute("ALTER TABLE files ADD COLUMN `capabilities` TEXT")
+
+               # In the end, we can easily update the version of the database.
+               c.execute("UPDATE settings SET val = ? WHERE key = 'version'", (DATABASE_FORMAT,))
+               self.__format = DATABASE_FORMAT
+
+               self.commit()
+               c.close()
+
        def add_package(self, pkg, reason=None):
-               logging.debug("Adding package to database: %s" % pkg.friendly_name)
+               log.debug("Adding package to database: %s" % pkg.friendly_name)
 
                c = self.cursor()
 
@@ -196,6 +275,7 @@ class DatabaseLocal(Database):
                                        summary,
                                        description,
                                        uuid,
+                                       vendor,
                                        build_id,
                                        build_host,
                                        build_date,
@@ -203,7 +283,7 @@ class DatabaseLocal(Database):
                                        installed,
                                        repository,
                                        reason
-                               ) VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
+                               ) VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
                                (
                                        pkg.name,
                                        pkg.epoch,
@@ -222,6 +302,7 @@ class DatabaseLocal(Database):
                                        pkg.summary,
                                        pkg.description,
                                        pkg.uuid,
+                                       pkg.vendor or "",
                                        pkg.build_id,
                                        pkg.build_host,
                                        pkg.build_date,
@@ -234,8 +315,9 @@ class DatabaseLocal(Database):
 
                        pkg_id = c.lastrowid
 
-                       c.executemany("INSERT INTO files(name, pkg, size, hash1) VALUES(?, ?, ?, ?)",
-                               ((f.name, pkg_id, f.size, f.hash1) for f in pkg.filelist))
+                       c.executemany("INSERT INTO files(`name`, `pkg`, `size`, `config`, `type`, `hash1`, `mode`, `user`, `group`, `mtime`, `capabilities`)"
+                                       " VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
+                               ((f.name, pkg_id, f.size, f.is_config(), f.type, f.hash1, f.mode, f.user, f.group, f.mtime, f.capabilities or "") for f in pkg.filelist))
 
                except:
                        raise
@@ -246,13 +328,15 @@ class DatabaseLocal(Database):
                c.close()
 
        def rem_package(self, pkg):
-               logging.debug("Removing package from database: %s" % pkg.friendly_name)
+               log.debug("Removing package from database: %s" % pkg.friendly_name)
 
                assert pkg.uuid
 
                # Get the ID of the package in the database.
                c = self.cursor()
-               c.execute("SELECT id FROM packages WHERE uuid = ? LIMIT 1", (pkg.uuid,))
+               #c.execute("SELECT id FROM packages WHERE uuid = ? LIMIT 1", (pkg.uuid,))
+               c.execute("SELECT id FROM packages WHERE name = ? AND epoch = ? AND version = ?"
+                       " AND release = ? LIMIT 1", (pkg.name, pkg.epoch, pkg.version, pkg.release,))
 
                id = None
                for row in c:
@@ -289,6 +373,18 @@ class DatabaseLocal(Database):
 
                c.close()
 
+       def get_filelist(self):
+               c = self.cursor()
+               c.execute("SELECT DISTINCT name FROM files")
+
+               ret = []
+               for row in c:
+                       ret.append(row["name"])
+
+               c.close()
+
+               return ret
+
        def get_package_from_solv(self, solv_pkg):
                c = self.cursor()
                c.execute("SELECT * FROM packages WHERE uuid = ? LIMIT 1", (solv_pkg.uuid,))