]> git.ipfire.org Git - people/stevee/pypdns.git/blobdiff - backend.py
Code rework of the CLI.
[people/stevee/pypdns.git] / backend.py
index f3efa646ab4e0ea2a034311449a8980e9101d09b..e915fd224212578eadf6e0b4db1d9f265094273c 100644 (file)
 #                                                                             #
 ###############################################################################
 #                                                                             #
+# Basic information about the database layout can be found here:              #
+# http://doc.powerdns.com/gsqlite.html                                        #
+#                                                                             #
 # More details about the database tables and fields can be found here:        #
 # http://wiki.powerdns.com/trac/wiki/fields                                   #
 #                                                                             #
 ###############################################################################
 
 import database
+import sqlite3
+
+from errors import *
 
 DB = "/var/lib/pdns/pdns.db"
 
 # Create the primary DNS class.
-"""Use Database class from imported database module to connect to the PDNS sqlite database."""
 class DNS(object):
+       """
+       Primary DNS class.
+
+       Uses the database class from imported database module.
+       Connects to the PDNS sqlite database.
+       """
        def __init__(self, db):
-               self.db = database.Database(db)
+               # Try to connect to database or raise an exception.
+               try:
+                       self.db = database.Database(db)
+
+               except sqlite3.OperationalError, e:
+                       raise DatabaseException, "Could not open database: %s" % e
+
+
 
        # Get all configured domains.
        def get_domains(self):
+               """
+               Fetch all configured domains.
+               """
+               # Create an empty list.
                domains = []
 
-               """Fetch all configured domains, line by line and add them to the previous created empty list."""
+               # Add fetched domains to the previous created empty list.
                for row in self.db.query("SELECT id FROM domains"):
                        domain = Domain(self, row.id)
                        domains.append(domain)
@@ -46,78 +68,109 @@ class DNS(object):
 
        # Get a domain by it's name.
        def get_domain(self, name):
+               """
+               Get a domain by a given name.
+               """
                row = self.db.get("SELECT id FROM domains WHERE name = ?", name)
 
-               # Only do anything, if there is an existing domain.
-               if row:
-                       domain = Domain(self, row.id)
+               # Check if an id has been returned from database or return None.
+               if not row:
+                       return None
+
+               return Domain(self, row.id)
 
-                       return domain
 
 # Create Domain class.
-"""Use query method from database module to get all requested information about our domain."""
-"""The domain is specified by it's unique id."""
 class Domain(object):
+       """
+       Domain class.
+
+       Uses query method from database module to get requested information
+       from domain.
+
+       The domain is specified by it's unique database id.
+       """
        def __init__(self, dns, domain_id):
                self.dns = dns
                self.id = domain_id
 
+               self.__data = None
+
        @property
        def db(self):
                return self.dns.db
 
+       # Cache.
+       @property
+       def data(self):
+               if self.__data is None:
+                       self.__data = self.db.get("SELECT * FROM domains \
+                               WHERE id = ?", self.id)
+                       assert self.__data
+
+               return self.__data
+
        # Determine the name of the zone by a given id.
        @property
        def name(self):
-               row = self.db.get("SELECT name FROM domains WHERE id = ?", self.id)
-               return row.name
+               return self.data.name
 
-       # Get information of the master nameserver from which the domain should be slaved.
+       # Get information of the master nameserver from which the domain should
+       # be slaved.
        @property
        def master(self):
-               row = self.db.get("SELECT master FROM domains WHERE id = ?", self.id)
-               return row.master
+               return self.data.master
 
        # Fetch data of the last check from the domain.
        @property
        def last_check(self):
-               row = self.db.get("SELECT last_check FROM domains WHERE id = ?", self.id)
-               return row.last_check
+               return self.data.last_check
 
        # Get the type of the domain.
        @property
        def type(self):
-               row = self.db.get("SELECT type FROM domains WHERE id = ?", self.id)
-               return row.type
+               return self.data.type
 
        # Get the last notified serial of a used master domain.
        @property
        def notified_serial(self):
-               row = self.db.get("SELECT notified_serial FROM domains WHERE id = ?", self.id)
-               return row.notified_serial
+               return self.data.notified_serial
 
        # Gain if a certain host is a supermaster for a certain domain name.
        @property
        def account(self):
-               row = self.db.get("SELECT account FROM domains WHERE id = ?", self.id)
-               return row.account
+               return self.data.account
+
+       # Get count of records of a zone. Return true if there is at least one
+       # or false.
+       def has_records(self):
+               count = self.db.get("SELECT COUNT(*) AS num FROM records \
+                       WHERE domain_id = ?", self.id)
+
+               if count.num > 0:
+                       return True
+
+               return False
 
        # Get all records from zone.
        @property 
        def records(self):
-               records = []
+               """
+               Get all records from the zone.
+               """
+               # Fetch records from zone and categorize them into their
+               # different record types.
+               for row in self.db.query("SELECT id, type FROM records \
+                       WHERE domain_id = ?", self.id):
 
-               """Fetch all records from domain, line by line and add them to the previous created empty list."""
-               for row in self.db.query("SELECT id, type FROM records WHERE domain_id = ?", self.id):
                        if row.type == "SOA":
                                record = SOARecord(self, row.id)
                        elif row.type == "A":
                                record = ARecord(self, row.id)
                        else:
-                               record = Record(self, row.id) 
-                       records.append(record)
+                               record = Record(self, row.id)
 
-               return records
+                       yield record
 
        # Get records by a specified type.
        def get_records_by_type(self, type):
@@ -137,82 +190,99 @@ class Domain(object):
 
 
 # Create class for domain records.
-"""It is used to get more details about the configured records."""
-"""The domain is specified by it's unique id."""
 class Record(object):
+       """
+       Record class
+
+       It is used to get details about configured records.
+       The domain and record is's are specified by their unique database id's.
+       """
        def __init__(self, domain, record_id):
                self.domain = domain
                self.id = record_id
 
+               # Cache.
+               self.__data = None
+
        @property
        def db(self):
                return self.domain.db
 
+       @property
+       def data(self):
+               if self.__data is None:
+                       self.__data = self.db.get("SELECT * FROM records \
+                               WHERE id = ?", self.id)
+                       assert self.__data
+
+               return self.__data
+
        # Determine the type of the record.
        @property
        def type(self):
-               row = self.db.get("SELECT type FROM records WHERE id = ?", self.id)
-               return row.type
+               return self.data.type
 
        # Get the configured DNS name of the record.
        @property
        def dnsname(self):
-               row = self.db.get("SELECT name FROM records WHERE id = ?", self.id)
-               return row.name
+               return self.data.name
 
 
        # Fetch content like the address to which the record points.
        @property
        def content(self):
-               row = self.db.get("SELECT content FROM records WHERE id = ?", self.id)
-               return row.content
-
+               return self.data.content
 
        # Get the "Time to live" for the record.
        @property
        def ttl(self):
-               row = self.db.get("SELECT ttl FROM records WHERE id = ?", self.id)
-               return row.ttl
+               return self.data.ttl
 
        # Gain the configured record priority.
        @property
        def priority(self):
-               row = self.db.get("SELECT prio FROM records WHERE id = ?" , self.id)
-               return row.prio
+               return self.data.prio
 
        # Get the change_date.
        @property
        def change_date(self):
-               row = self.db.get("SELECT change_date FROM records WHERE id = ?" , self.id)
-               return row.change_date
+               return self.data.change_date
 
        # Fetch the ordername.
        @property
        def ordername(self):
-               row = self.db.get("SELECT ordername FROM records WHERE id = ?" , self.id)
-               return row.ordername
+               return self.data.ordername
 
        # Gain all information about records authentication.
        @property
        def authentication(self):
-               row = self.db.get("SELECT auth FROM records WHERE id = ?" , self.id)
-               return row.auth
+               return self.data.auth
 
 
 # Create an own class to deal with "SOA" records.
-"""Use splitt() to generate a list of the original content string from the database, and return the requested entries.""" 
 class SOARecord(Record):
+       """
+       SOA Record class.
+       This is an own class to deal with "SOA" records.
+
+       Uses splitt() to generate a list of the content string from the
+       database.
+
+       Returns the requested entries.
+       """
        def __init__(self, domain, record_id):
                Record.__init__(self, domain, record_id)
 
                self.soa_attrs = self.content.split()
 
-               # Check if the content from database is valid (It contains all 7 required information).
+               # Check if the content from database is valid.
+               # (It contains all 7 required information)
                if not len(self.soa_attrs) == 7:
-                       #XXX Add something like an error message or log output.
-                       pass                    
+                       raise InvalidRecordDataException, "Your SOA record \
+                               doesn't contain all required seven elements."
 
-       # Primary NS - the domain name of the name server that was the original source of the data.     
+       # Primary NS - the domain name of the name server that was the
+       # original source of the data.  
        @property
        def mname(self):
                return self.soa_attrs[0]
@@ -222,22 +292,26 @@ class SOARecord(Record):
        def email(self):
                return self.soa_attrs[1]
 
-       # The serial which increases allways after a change on the domain has been made.
+       # The serial which increases allways after a change on the domain has
+       # been made.
        @property
        def serial(self):
                return self.soa_attrs[2]
 
-       # The number of seconds between the time that a secondary name server gets a copy of the domain.
+       # The number of seconds between the time that a secondary name server
+       # gets a copy of the domain.
        @property
        def refresh(self):
                return self.soa_attrs[3]
 
-       # The number of seconds during the next refresh attempt if the previous fails.
+       # The number of seconds during the next refresh attempt if the
+       # previous fails.
        @property
        def retry(self):
                return self.soa_attrs[4]
 
-       # The number of seconds that lets the secondary name server(s) know how long they can hold the information.
+       # The number of seconds that lets the secondary name server(s) know
+       # how long they can hold the information.
        @property
        def expire(self):
                return self.soa_attrs[5]
@@ -251,4 +325,3 @@ class SOARecord(Record):
 
 class ARecord(Record):
        pass
-