]> git.ipfire.org Git - people/stevee/pypdns.git/commitdiff
Add cache for domain and record objects.
authorStefan Schantl <stefan.schantl@ipfire.org>
Wed, 3 Oct 2012 19:37:04 +0000 (21:37 +0200)
committerStefan Schantl <stefan.schantl@ipfire.org>
Wed, 3 Oct 2012 19:37:04 +0000 (21:37 +0200)
backend.py

index 3911bef87557018deabba485ba9c7b73b2fcfc9b..1679281435a776568675259f92254bd2b4cf137a 100644 (file)
 ###############################################################################
 
 import database
-import errors
 import sqlite3
 
+from errors import *
+
 DB = "/var/lib/pdns/pdns.db"
 
 # Create the primary DNS class.
@@ -46,7 +47,7 @@ class DNS(object):
                        self.db = database.Database(db)
 
                except sqlite3.OperationalError, e:
-                       raise errors.DatabaseException, "Could not open database: %s" % e
+                       raise DatabaseException, "Could not open database: %s" % e
 
 
 
@@ -72,11 +73,12 @@ class DNS(object):
                """
                row = self.db.get("SELECT id FROM domains WHERE name = ?", name)
 
-               # Only do anything, if there is an existing domain.
-               if row:
-                       domain = Domain(self, row.id)
+               # Check if an id has been returned from database, or return None.
+               if not row:
+                       return None
+
+               return Domain(self, row.id)
 
-                       return domain
 
 # Create Domain class.
 class Domain(object):
@@ -90,45 +92,50 @@ class Domain(object):
                self.dns = dns
                self.id = domain_id
 
+               self.__data = None
+
        @property
        def db(self):
                return self.dns.db
 
+       # Cache.
+       @property
+       def data(self):
+               if self.__data is None:
+                       self.__data = self.db.get("SELECT * FROM domains WHERE id = ?", self.id)
+                       assert self.__data
+
+               return self.__data
+
        # Determine the name of the zone by a given id.
        @property
        def name(self):
-               row = self.db.get("SELECT name FROM domains WHERE id = ?", self.id)
-               return row.name
+               return self.data.name
 
        # Get information of the master nameserver from which the domain should be slaved.
        @property
        def master(self):
-               row = self.db.get("SELECT master FROM domains WHERE id = ?", self.id)
-               return row.master
+               return self.data.master
 
        # Fetch data of the last check from the domain.
        @property
        def last_check(self):
-               row = self.db.get("SELECT last_check FROM domains WHERE id = ?", self.id)
-               return row.last_check
+               return self.data.last_check
 
        # Get the type of the domain.
        @property
        def type(self):
-               row = self.db.get("SELECT type FROM domains WHERE id = ?", self.id)
-               return row.type
+               return self.data.type
 
        # Get the last notified serial of a used master domain.
        @property
        def notified_serial(self):
-               row = self.db.get("SELECT notified_serial FROM domains WHERE id = ?", self.id)
-               return row.notified_serial
+               return self.data.notified_serial
 
        # Gain if a certain host is a supermaster for a certain domain name.
        @property
        def account(self):
-               row = self.db.get("SELECT account FROM domains WHERE id = ?", self.id)
-               return row.account
+               return self.data.account
 
        # Get all records from zone.
        @property 
@@ -136,9 +143,6 @@ class Domain(object):
                """
                Get all records from the zone.
                """
-               # Create an empty list.
-               records = []
-
                # Fetch records from zone and categorize them into their different record types.
                for row in self.db.query("SELECT id, type FROM records WHERE domain_id = ?", self.id):
                        if row.type == "SOA":
@@ -146,10 +150,9 @@ class Domain(object):
                        elif row.type == "A":
                                record = ARecord(self, row.id)
                        else:
-                               record = Record(self, row.id) 
-                       records.append(record)
+                               record = Record(self, row.id)
 
-               return records
+                       yield record
 
        # Get records by a specified type.
        def get_records_by_type(self, type):
@@ -180,59 +183,61 @@ class Record(object):
                self.domain = domain
                self.id = record_id
 
+               # Cache.
+               self.__data = None
+
        @property
        def db(self):
                return self.domain.db
 
+       @property
+       def data(self):
+               if self.__data is None:
+                       self.__data = self.db.get("SELECT * FROM records WHERE id = ?", self.id)
+                       assert self.__data
+
+               return self.__data
+
        # Determine the type of the record.
        @property
        def type(self):
-               row = self.db.get("SELECT type FROM records WHERE id = ?", self.id)
-               return row.type
+               return self.data.type
 
        # Get the configured DNS name of the record.
        @property
        def dnsname(self):
-               row = self.db.get("SELECT name FROM records WHERE id = ?", self.id)
-               return row.name
+               return self.data.name
 
 
        # Fetch content like the address to which the record points.
        @property
        def content(self):
-               row = self.db.get("SELECT content FROM records WHERE id = ?", self.id)
-               return row.content
-
+               return self.data.content
 
        # Get the "Time to live" for the record.
        @property
        def ttl(self):
-               row = self.db.get("SELECT ttl FROM records WHERE id = ?", self.id)
-               return row.ttl
+               return self.data.ttl
 
        # Gain the configured record priority.
        @property
        def priority(self):
-               row = self.db.get("SELECT prio FROM records WHERE id = ?" , self.id)
-               return row.prio
+               return self.data.prio
 
        # Get the change_date.
        @property
        def change_date(self):
-               row = self.db.get("SELECT change_date FROM records WHERE id = ?" , self.id)
-               return row.change_date
+               return self.data.change_date
 
        # Fetch the ordername.
        @property
        def ordername(self):
-               row = self.db.get("SELECT ordername FROM records WHERE id = ?" , self.id)
-               return row.ordername
+               return self.data.ordername
 
        # Gain all information about records authentication.
        @property
        def authentication(self):
-               row = self.db.get("SELECT auth FROM records WHERE id = ?" , self.id)
-               return row.auth
+               return self.data.auth
 
 
 # Create an own class to deal with "SOA" records.