]> git.ipfire.org Git - pakfire.git/blobdiff - pakfire/repository/database.py
One huge commit, that breaks pakfire.
[pakfire.git] / pakfire / repository / database.py
index e360e20798c535c1d127cbc0ce703b8051e4503e..f8dc8ca8d0eda0e85b918162cbdb69921baead86 100644 (file)
@@ -1,4 +1,23 @@
 #!/usr/bin/python
+###############################################################################
+#                                                                             #
+# Pakfire - The IPFire package management system                              #
+# Copyright (C) 2011 Pakfire development team                                 #
+#                                                                             #
+# This program is free software: you can redistribute it and/or modify        #
+# it under the terms of the GNU General Public License as published by        #
+# the Free Software Foundation, either version 3 of the License, or           #
+# (at your option) any later version.                                         #
+#                                                                             #
+# This program is distributed in the hope that it will be useful,             #
+# but WITHOUT ANY WARRANTY; without even the implied warranty of              #
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the               #
+# GNU General Public License for more details.                                #
+#                                                                             #
+# You should have received a copy of the GNU General Public License           #
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.       #
+#                                                                             #
+###############################################################################
 
 import logging
 import os
@@ -7,6 +26,8 @@ import shutil
 import sqlite3
 import time
 
+import pakfire.packages as packages
+
 from pakfire.constants import *
 
 class Cursor(sqlite3.Cursor):
@@ -20,29 +41,20 @@ class Cursor(sqlite3.Cursor):
 class Database(object):
        def __init__(self, pakfire, filename):
                self.pakfire = pakfire
-               self._db = None
-
-               self._tmp = False
-
-               if filename == ":memory:":
-                       self._tmp = True
-
-                       filename = "/tmp/.%s-db" % random.randint(0, 1024**2)
-
                self.filename = filename
 
-               self.open()
+               self._db = None
 
        def __del__(self):
                if self._db:
-                       #self._db.commit()
                        self._db.close()
+                       self._db = None
 
        def create(self):
                pass
 
        def open(self):
-               if not self._db:
+               if self._db is None:
                        logging.debug("Open database %s" % self.filename)
 
                        dirname = os.path.dirname(self.filename)
@@ -60,35 +72,50 @@ class Database(object):
                                self.create()
 
        def close(self):
-               self._db.close()
-               self._db = None
-
-               if self._tmp:
-                       os.unlink(self.filename)
+               self.__del__()
 
        def commit(self):
+               self.open()
                self._db.commit()
 
        def cursor(self):
+               self.open()
                return self._db.cursor(Cursor)
 
        def executescript(self, *args, **kwargs):
+               self.open()
                return self._db.executescript(*args, **kwargs)
 
-       def save(self, path):
-               """
-                       Save a copy of this database to a new one located at path.
-               """
-               self.commit()
 
-               shutil.copy2(self.filename, path)
+class DatabaseLocal(Database):
+       def __init__(self, pakfire, repo):
+               self.repo = repo
 
+               # Generate filename for package database
+               filename = os.path.join(pakfire.path, PACKAGES_DB)
+
+               Database.__init__(self, pakfire, filename)
+
+       def __len__(self):
+               count = 0
 
-class PackageDatabase(Database):
-       def create(self):
                c = self.cursor()
+               c.execute("SELECT COUNT(*) AS count FROM packages")
+               for row in c:
+                       count = row["count"]
+               c.close()
+
+               return count
 
+       def create(self):
+               c = self.cursor()
                c.executescript("""
+                       CREATE TABLE settings(
+                               key                     TEXT,
+                               val                     TEXT
+                       );
+                       INSERT INTO settings(key, val) VALUES('version', '0');
+
                        CREATE TABLE files(
                                name            TEXT,
                                pkg                     INTEGER,
@@ -119,174 +146,144 @@ class PackageDatabase(Database):
                                build_id        TEXT,
                                build_host      TEXT,
                                build_date      TEXT,
-                               build_time      INTEGER
+                               build_time      INTEGER,
+                               installed       INT,
+                               reason          TEXT,
+                               repository      TEXT
+                       );
+
+                       CREATE TABLE scriptlets(
+                               id                      INTEGER PRIMARY KEY,
+                               pkg                     INTEGER,
+                               action          TEXT,
+                               scriptlet       TEXT
+                       );
+
+                       CREATE TABLE triggers(
+                               id                      INTEGER PRIMARY KEY,
+                               pkg                     INTEGER,
+                               dependency      TEXT,
+                               scriptlet       TEXT
                        );
                """)
                # XXX add some indexes here
-
                self.commit()
                c.close()
 
-       def list_packages(self):
-               c = self.cursor()
-               c.execute("SELECT DISTINCT name FROM packages ORDER BY name")
-
-               for pkg in c:
-                       yield pkg["name"]
-
-               c.close()
-
-       def package_exists(self, pkg):
-               return not self.get_id_by_pkg(pkg) is None
+       def add_package(self, pkg, reason=None):
+               logging.debug("Adding package to database: %s" % pkg.friendly_name)
 
-       def get_id_by_pkg(self, pkg):
                c = self.cursor()
 
-               c.execute("SELECT id FROM packages WHERE name = ? AND version = ? AND \
-                       release = ? AND epoch = ? LIMIT 1", (pkg.name, pkg.version, pkg.release, pkg.epoch))
+               try:
+                       c.execute("""
+                               INSERT INTO packages(
+                                       name,
+                                       epoch,
+                                       version,
+                                       release,
+                                       arch,
+                                       groups,
+                                       filename,
+                                       size,
+                                       hash1,
+                                       provides,
+                                       requires,
+                                       conflicts,
+                                       obsoletes,
+                                       license,
+                                       summary,
+                                       description,
+                                       uuid,
+                                       build_id,
+                                       build_host,
+                                       build_date,
+                                       build_time,
+                                       installed,
+                                       repository,
+                                       reason
+                               ) VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
+                               (
+                                       pkg.name,
+                                       pkg.epoch,
+                                       pkg.version,
+                                       pkg.release,
+                                       pkg.arch,
+                                       " ".join(pkg.groups),
+                                       pkg.filename,
+                                       pkg.size,
+                                       pkg.hash1,
+                                       " ".join(pkg.provides),
+                                       " ".join(pkg.requires),
+                                       " ".join(pkg.conflicts),
+                                       " ".join(pkg.obsoletes),
+                                       pkg.license,
+                                       pkg.summary,
+                                       pkg.description,
+                                       pkg.uuid,
+                                       pkg.build_id,
+                                       pkg.build_host,
+                                       pkg.build_date,
+                                       pkg.build_time,
+                                       time.time(),
+                                       pkg.repo.name,
+                                       reason or "",
+                               )
+                       )
 
-               ret = None
-               for i in c:
-                       ret = i["id"]
-                       break
+                       pkg_id = c.lastrowid
 
-               c.close()
+                       c.executemany("INSERT INTO files(name, pkg) VALUES(?, ?)",
+                               ((file, pkg_id) for file in pkg.filelist))
 
-               return ret
+               except:
+                       raise
 
-       def add_package(self, pkg):
-               raise NotImplementedError
+               else:
+                       self.commit()
 
+               c.close()
 
-class RemotePackageDatabase(PackageDatabase):
-       def add_package(self, pkg, reason=None):
-               if self.package_exists(pkg):
-                       logging.debug("Skipping package which already exists in database: %s" % pkg.friendly_name)
-                       return
+       def rem_package(self, pkg):
+               logging.debug("Removing package from database: %s" % pkg.friendly_name)
 
-               logging.debug("Adding package to database: %s" % pkg.friendly_name)
-
-               filename = ""
-               if pkg.repo.local:
-                       # Get the path relatively to the repository.
-                       filename = pkg.filename[len(pkg.repo.path):]
-                       # Strip leading / if any.
-                       if filename.startswith("/"):
-                               filename = filename[1:]
+               assert pkg.uuid
 
+               # Get the ID of the package in the database.
                c = self.cursor()
-               c.execute("""
-                       INSERT INTO packages(
-                               name,
-                               epoch,
-                               version,
-                               release,
-                               arch,
-                               groups,
-                               filename,
-                               size,
-                               hash1,
-                               provides,
-                               requires,
-                               conflicts,
-                               obsoletes,
-                               license,
-                               summary,
-                               description,
-                               uuid,
-                               build_id,
-                               build_host,
-                               build_date,
-                               build_time
-                       ) VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
-                       (
-                               pkg.name,
-                               pkg.epoch,
-                               pkg.version,
-                               pkg.release,
-                               pkg.arch,
-                               " ".join(pkg.groups),
-                               filename,
-                               pkg.size,
-                               pkg.hash1,
-                               " ".join(pkg.provides),
-                               " ".join(pkg.requires),
-                               " ".join(pkg.conflicts),
-                               " ".join(pkg.obsoletes),
-                               pkg.license,
-                               pkg.summary,
-                               pkg.description,
-                               pkg.uuid,
-                               pkg.build_id,
-                               pkg.build_host,
-                               pkg.build_date,
-                               pkg.build_time,
-                       )
-               )
-               self.commit()
-               c.close()
+               c.execute("SELECT id FROM packages WHERE uuid = ? LIMIT 1", (pkg.uuid,))
 
-               pkg_id = self.get_id_by_pkg(pkg)
+               id = None
+               for row in c:
+                       id = row["id"]
+                       break
+               assert id
 
-               c = self.cursor()
-               for file in pkg.filelist:
-                       c.execute("INSERT INTO files(name, pkg) VALUES(?, ?)", (file, pkg_id))
+               # First, delete all files from the database and then delete the pkg itself.
+               c.execute("DELETE FROM files WHERE pkg = ?", (id,))
+               c.execute("DELETE FROM packages WHERE id = ?", (id,))
 
-               self.commit()
                c.close()
-
-               return pkg_id
-
-
-class LocalPackageDatabase(RemotePackageDatabase):
-       def __init__(self, pakfire):
-               # Generate filename for package database
-               filename = os.path.join(pakfire.path, PACKAGES_DB)
-
-               RemotePackageDatabase.__init__(self, pakfire, filename)
-
-       def create(self):
-               RemotePackageDatabase.create(self)
-
-               # Alter the database layout to store additional local information.
-               logging.debug("Altering database table for local information.")
-               c = self.cursor()
-               c.executescript("""
-                       ALTER TABLE packages ADD COLUMN installed INT;
-                       ALTER TABLE packages ADD COLUMN reason TEXT;
-                       ALTER TABLE packages ADD COLUMN repository TEXT;
-                       ALTER TABLE packages ADD COLUMN scriptlet TEXT;
-                       ALTER TABLE packages ADD COLUMN triggers TEXT;
-               """)
                self.commit()
-               c.close()
 
-       def add_package(self, pkg, reason=None):
-               # Insert all the information to the database we have in the remote database
-               pkg_id = RemotePackageDatabase.add_package(self, pkg)
-
-               # then: add some more information
+       @property
+       def packages(self):
                c = self.cursor()
 
-               # Save timestamp when the package was installed.
-               c.execute("UPDATE packages SET installed = ? WHERE id = ?", (time.time(), pkg_id))
-
-               # Add repository information.
-               c.execute("UPDATE packages SET repository = ? WHERE id = ?", (pkg.repo.name, pkg_id))
+               c.execute("SELECT * FROM packages ORDER BY name")
 
-               # Save reason of installation (if any).
-               if reason:
-                       c.execute("UPDATE packages SET reason = ? WHERE id = ?", (reason, pkg_id))
+               for row in c:
+                       yield packages.DatabasePackage(self.pakfire, self.repo, self, row)
 
-               # Update the filename information.
-               c.execute("UPDATE packages SET filename = ? WHERE id = ?", (pkg.filename, pkg_id))
+               c.close()
 
-               # Add the scriptlet to database (needed to update or uninstall packages).
-               c.execute("UPDATE packages SET scriptlet = ? WHERE id = ?", (pkg.scriptlet, pkg_id))
+       def get_package_from_solv(self, solv_pkg):
+               c = self.cursor()
+               c.execute("SELECT * FROM packages WHERE uuid = ? LIMIT 1", (solv_pkg.uuid,))
 
-               # Add triggers to the database.
-               triggers = " ".join(pkg.triggers)
-               c.execute("UPDATE packages SET triggers = ? WHERE id = ?", (triggers, pkg_id))
+               try:
+                       for row in c:
+                               return packages.DatabasePackage(self.pakfire, self.repo, self, row)
 
-               self.commit()
-               c.close()
+               finally:
+                       c.close()