]> git.ipfire.org Git - people/stevee/pakfire.git/blobdiff - python/pakfire/repository/database.py
Improve the repository code.
[people/stevee/pakfire.git] / python / pakfire / repository / database.py
index 24d81541d5654e604bbd395dee3a42c5b5ec1e1f..da145a03bf571e96e694d6b52553b70ff028c2bd 100644 (file)
 #                                                                             #
 ###############################################################################
 
-import logging
 import os
 import random
 import shutil
 import sqlite3
 import time
 
+import logging
+log = logging.getLogger("pakfire")
+
 import pakfire.packages as packages
 
 from pakfire.constants import *
@@ -59,7 +61,7 @@ class Database(object):
 
        def open(self):
                if self._db is None:
-                       logging.debug("Open database %s" % self.filename)
+                       log.debug("Open database %s" % self.filename)
 
                        dirname = os.path.dirname(self.filename)
                        if not os.path.exists(dirname):
@@ -80,11 +82,13 @@ class Database(object):
                                self.create()
 
        def close(self):
-               self.__del__()
+               if self._db:
+                       self._db.close()
+                       self._db = None
 
        def commit(self):
-               self.open()
-               self._db.commit()
+               if self._db:
+                       self._db.commit()
 
        def cursor(self):
                self.open()
@@ -159,7 +163,8 @@ class DatabaseLocal(Database):
                                user            TEXT,
                                `group`         TEXT,
                                hash1           TEXT,
-                               mtime           INTEGER
+                               mtime           INTEGER,
+                               capabilities    TEXT
                        );
 
                        CREATE TABLE packages(
@@ -214,7 +219,12 @@ class DatabaseLocal(Database):
                if self.format == DATABASE_FORMAT:
                        return
 
-               logging.info(_("Migrating database from format %s to %s.") % (self.format, DATABASE_FORMAT))
+               # Check if database version is supported.
+               if self.format > DATABASE_FORMAT:
+                       raise DatabaseError, _("Cannot use database with version greater than %s.") % DATABASE_FORMAT
+
+               log.info(_("Migrating database from format %(old)s to %(new)s.") % \
+                       { "old" : self.format, "new" : DATABASE_FORMAT })
 
                # Get a database cursor.
                c = self.cursor()
@@ -230,6 +240,9 @@ class DatabaseLocal(Database):
                        c.execute("ALTER TABLE files ADD COLUMN `group` TEXT")
                        c.execute("ALTER TABLE files ADD COLUMN `mtime` INTEGER")
 
+               if self.format < 3:
+                       c.execute("ALTER TABLE files ADD COLUMN `capabilities` TEXT")
+
                # In the end, we can easily update the version of the database.
                c.execute("UPDATE settings SET val = ? WHERE key = 'version'", (DATABASE_FORMAT,))
                self.__format = DATABASE_FORMAT
@@ -238,7 +251,7 @@ class DatabaseLocal(Database):
                c.close()
 
        def add_package(self, pkg, reason=None):
-               logging.debug("Adding package to database: %s" % pkg.friendly_name)
+               log.debug("Adding package to database: %s" % pkg.friendly_name)
 
                c = self.cursor()
 
@@ -302,9 +315,9 @@ class DatabaseLocal(Database):
 
                        pkg_id = c.lastrowid
 
-                       c.executemany("INSERT INTO files(`name`, `pkg`, `size`, `config`, `type`, `hash1`, `mode`, `user`, `group`, `mtime`)"
-                                       " VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
-                               ((f.name, pkg_id, f.size, f.is_config(), f.type, f.hash1, f.mode, f.user, f.group, f.mtime) for f in pkg.filelist))
+                       c.executemany("INSERT INTO files(`name`, `pkg`, `size`, `config`, `type`, `hash1`, `mode`, `user`, `group`, `mtime`, `capabilities`)"
+                                       " VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
+                               ((f.name, pkg_id, f.size, f.is_config(), f.type, f.hash1, f.mode, f.user, f.group, f.mtime, f.capabilities or "") for f in pkg.filelist))
 
                except:
                        raise
@@ -315,7 +328,7 @@ class DatabaseLocal(Database):
                c.close()
 
        def rem_package(self, pkg):
-               logging.debug("Removing package from database: %s" % pkg.friendly_name)
+               log.debug("Removing package from database: %s" % pkg.friendly_name)
 
                assert pkg.uuid
 
@@ -360,6 +373,18 @@ class DatabaseLocal(Database):
 
                c.close()
 
+       def get_filelist(self):
+               c = self.cursor()
+               c.execute("SELECT DISTINCT name FROM files")
+
+               ret = []
+               for row in c:
+                       ret.append(row["name"])
+
+               c.close()
+
+               return ret
+
        def get_package_from_solv(self, solv_pkg):
                c = self.cursor()
                c.execute("SELECT * FROM packages WHERE uuid = ? LIMIT 1", (solv_pkg.uuid,))