git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
refactored reflecttable
author Randall Smith <randall@tnr.cc>
Sun, 15 Mar 2009 05:30:24 +0000 (05:30 +0000)
committer Randall Smith <randall@tnr.cc>
Sun, 15 Mar 2009 05:30:24 +0000 (05:30 +0000)
lib/sqlalchemy/dialects/firebird/base.py

index f00aa963ee6b5260bc007d2a0b460f4432c17c62..630f9da2483f24d591f451d63c84d84dc480695a 100644
@@ -101,8 +101,9 @@ parameter when creating the queries::
 
 import datetime, decimal, re
 
-from sqlalchemy import exc, schema, types as sqltypes, sql, util
-from sqlalchemy.engine import base, default
+from sqlalchemy import schema as sa_schema
+from sqlalchemy import exc, types as sqltypes, sql, util
+from sqlalchemy.engine import base, default, reflection
 
 
 _initialized_kb = False
@@ -403,7 +404,49 @@ class FBDialect(default.DefaultDialect):
         else:
             return False
 
-    def reflecttable(self, connection, table, include_columns):
+    @reflection.cache
+    def get_primary_keys(self, connection, table_name, schema=None, **kw):
+        # Query to extract the PK/FK constrained fields of the given table
+        keyqry = """
+        SELECT se.rdb$field_name AS fname
+        FROM rdb$relation_constraints rc
+             JOIN rdb$index_segments se ON rc.rdb$index_name=se.rdb$index_name
+        WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=?
+        """
+        tablename = self._denormalize_name(table_name)
+        # get primary key fields
+        c = connection.execute(keyqry, ["PRIMARY KEY", tablename])
+        pkfields = [self._normalize_name(r['fname']) for r in c.fetchall()]
+        return pkfields
+
+    @reflection.cache
+    def get_column_sequence(self, connection, table_name, column_name,
+                                                        schema=None, **kw):
+        tablename = self._denormalize_name(table_name)
+        colname = self._denormalize_name(column_name)
+        # Heuristic-query to determine the generator associated to a PK field
+        genqry = """
+        SELECT trigdep.rdb$depended_on_name AS fgenerator
+        FROM rdb$dependencies tabdep
+             JOIN rdb$dependencies trigdep ON (tabdep.rdb$dependent_name=trigdep.rdb$dependent_name
+                                               AND trigdep.rdb$depended_on_type=14
+                                               AND trigdep.rdb$dependent_type=2)
+             JOIN rdb$triggers trig ON (trig.rdb$trigger_name=tabdep.rdb$dependent_name)
+        WHERE tabdep.rdb$depended_on_name=?
+          AND tabdep.rdb$depended_on_type=0
+          AND trig.rdb$trigger_type=1
+          AND tabdep.rdb$field_name=?
+          AND (SELECT count(*)
+               FROM rdb$dependencies trigdep2
+               WHERE trigdep2.rdb$dependent_name = trigdep.rdb$dependent_name) = 2
+        """
+        genc = connection.execute(genqry, [tablename, colname])
+        genr = genc.fetchone()
+        if genr is not None:
+            return dict(name=self._normalize_name(genr['fgenerator']))
+
+    @reflection.cache
+    def get_columns(self, connection, table_name, schema=None, **kw):
         # Query to extract the details of all the fields of the given table
         tblqry = """
         SELECT DISTINCT r.rdb$field_name AS fname,
@@ -420,13 +463,41 @@ class FBDialect(default.DefaultDialect):
         WHERE f.rdb$system_flag=0 AND r.rdb$relation_name=?
         ORDER BY r.rdb$field_position
         """
-        # Query to extract the PK/FK constrained fields of the given table
-        keyqry = """
-        SELECT se.rdb$field_name AS fname
-        FROM rdb$relation_constraints rc
-             JOIN rdb$index_segments se ON rc.rdb$index_name=se.rdb$index_name
-        WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=?
-        """
+        tablename = self._denormalize_name(table_name)
+        # get all of the fields for this table
+        c = connection.execute(tblqry, [tablename])
+        cols = []
+        while True:
+            row = c.fetchone()
+            if row is None:
+                break
+            name = self._normalize_name(row['fname'])
+            # get the data type
+            coltype = ischema_names.get(row['ftype'].rstrip())
+            if coltype is None:
+                util.warn("Did not recognize type '%s' of column '%s'" %
+                          (str(row['ftype']), name))
+                coltype = sqltypes.NULLTYPE
+            else:
+                coltype = coltype(row)
+
+            # does it have a default value?
+            defvalue = None
+            if row['fdefault'] is not None:
+                # the value comes down as "DEFAULT 'value'"
+                assert row['fdefault'].upper().startswith('DEFAULT '), row
+                defvalue = row['fdefault'][8:]
+            col_d = {
+                'name' : name,
+                'type' : coltype,
+                'nullable' :  not bool(row['null_flag']),
+                'default' : defvalue
+            }
+            cols.append(col_d)
+        return cols
+
+    @reflection.cache
+    def get_foreign_keys(self, connection, table_name, schema=None, **kw):
         # Query to extract the details of each UK/FK of the given table
         fkqry = """
         SELECT rc.rdb$constraint_name AS cname,
@@ -441,105 +512,91 @@ class FBDialect(default.DefaultDialect):
         WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=?
         ORDER BY se.rdb$index_name, se.rdb$field_position
         """
-        # Heuristic-query to determine the generator associated to a PK field
-        genqry = """
-        SELECT trigdep.rdb$depended_on_name AS fgenerator
-        FROM rdb$dependencies tabdep
-             JOIN rdb$dependencies trigdep ON (tabdep.rdb$dependent_name=trigdep.rdb$dependent_name
-                                               AND trigdep.rdb$depended_on_type=14
-                                               AND trigdep.rdb$dependent_type=2)
-             JOIN rdb$triggers trig ON (trig.rdb$trigger_name=tabdep.rdb$dependent_name)
-        WHERE tabdep.rdb$depended_on_name=?
-          AND tabdep.rdb$depended_on_type=0
-          AND trig.rdb$trigger_type=1
-          AND tabdep.rdb$field_name=?
-          AND (SELECT count(*)
-               FROM rdb$dependencies trigdep2
-               WHERE trigdep2.rdb$dependent_name = trigdep.rdb$dependent_name) = 2
-        """
+        tablename = self._denormalize_name(table_name)
+        # get the foreign keys
+        c = connection.execute(fkqry, ["FOREIGN KEY", tablename])
+        fks = {}
+        fkeys = []
+        while True:
+            row = c.fetchone()
+            if not row:
+                break
+            cname = self._normalize_name(row['cname'])
+            if cname in fks:
+                fk = fks[cname]
+            else:
+                fk = {
+                    'name' : cname,
+                    'constrained_columns' : [],
+                    'referred_schema' : None,
+                    'referred_table' : None,
+                    'referred_columns' : []
+                }
+                fks[cname] = fk
+                fkeys.append(fk)
+            fk['referred_table'] = self._normalize_name(row['targetrname'])
+            fk['constrained_columns'].append(self._normalize_name(row['fname']))
+            fk['referred_columns'].append(
+                            self._normalize_name(row['targetfname']))
+        return fkeys
 
-        tablename = self._denormalize_name(table.name)
+    def reflecttable(self, connection, table, include_columns):
 
         # get primary key fields
-        c = connection.execute(keyqry, ["PRIMARY KEY", tablename])
-        pkfields = [self._normalize_name(r['fname']) for r in c.fetchall()]
-
-        # get all of the fields for this table
-        c = connection.execute(tblqry, [tablename])
+        pkfields = self.get_primary_keys(connection, table.name)
 
         found_table = False
-        while True:
-            row = c.fetchone()
-            if row is None:
-                break
+        for col_d in self.get_columns(connection, table.name):
             found_table = True
 
-            name = self._normalize_name(row['fname'])
+            name = col_d.get('name')
+            defvalue = col_d.get('default')
+            nullable = col_d.get('nullable')
+            coltype = col_d.get('type')
+
             if include_columns and name not in include_columns:
                 continue
             args = [name]
 
             kw = {}
-            # get the data type
-            coltype = ischema_names.get(row['ftype'].rstrip())
-            if coltype is None:
-                util.warn("Did not recognize type '%s' of column '%s'" %
-                          (str(row['ftype']), name))
-                coltype = sqltypes.NULLTYPE
-            else:
-                coltype = coltype(row)
             args.append(coltype)
 
             # is it a primary key?
             kw['primary_key'] = name in pkfields
 
             # is it nullable?
-            kw['nullable'] = not bool(row['null_flag'])
+            kw['nullable'] = nullable
 
             # does it have a default value?
-            if row['fdefault'] is not None:
-                # the value comes down as "DEFAULT 'value'"
-                assert row['fdefault'].upper().startswith('DEFAULT '), row
-                defvalue = row['fdefault'][8:]
-                args.append(schema.DefaultClause(sql.text(defvalue)))
+            if defvalue is not None:
+                args.append(sa_schema.DefaultClause(sql.text(defvalue)))
 
-            col = schema.Column(*args, **kw)
+            col = sa_schema.Column(*args, **kw)
             if kw['primary_key']:
                 # if the PK is a single field, try to see if its linked to
                 # a sequence thru a trigger
                 if len(pkfields)==1:
-                    genc = connection.execute(genqry, [tablename, row['fname']])
-                    genr = genc.fetchone()
-                    if genr is not None:
-                        col.sequence = schema.Sequence(self._normalize_name(genr['fgenerator']))
-
+                    seq_d = self.get_column_sequence(connection,
+                                            table.name, name)
+                    if seq_d is not None:
+                        col.sequence = sa_schema.Sequence(seq_d['name'])
             table.append_column(col)
 
         if not found_table:
             raise exc.NoSuchTableError(table.name)
 
         # get the foreign keys
-        c = connection.execute(fkqry, ["FOREIGN KEY", tablename])
-        fks = {}
-        while True:
-            row = c.fetchone()
-            if not row:
-                break
-
-            cname = self._normalize_name(row['cname'])
-            try:
-                fk = fks[cname]
-            except KeyError:
-                fks[cname] = fk = ([], [])
-            rname = self._normalize_name(row['targetrname'])
-            schema.Table(rname, table.metadata, autoload=True, autoload_with=connection)
-            fname = self._normalize_name(row['fname'])
-            refspec = rname + '.' + self._normalize_name(row['targetfname'])
-            fk[0].append(fname)
-            fk[1].append(refspec)
-
-        for name, value in fks.iteritems():
-            table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], name=name, link_to_name=True))
+        for fkey_d in self.get_foreign_keys(connection, table.name):
+            cname = fkey_d['name']
+            constrained_columns = fkey_d['constrained_columns']
+            rname = fkey_d['referred_table']
+            referred_columns = fkey_d['referred_columns']
+
+            sa_schema.Table(rname, table.metadata, autoload=True, autoload_with=connection)
+            refspec = [rname + '.' + refcol
+                                for refcol in referred_columns]
+            table.append_constraint(sa_schema.ForeignKeyConstraint(
+                constrained_columns, refspec, name=cname, link_to_name=True))
 
     def do_execute(self, cursor, statement, parameters, **kwargs):
         # kinterbase does not accept a None, but wants an empty list
@@ -765,4 +822,4 @@ dialect.schemagenerator = FBSchemaGenerator
 dialect.schemadropper = FBSchemaDropper
 dialect.defaultrunner = FBDefaultRunner
 dialect.preparer = FBIdentifierPreparer
-dialect.execution_ctx_cls = FBExecutionContext
\ No newline at end of file
+dialect.execution_ctx_cls = FBExecutionContext
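
For context, a minimal usage sketch of the per-aspect reflection methods this patch introduces, assuming the 0.6-era Inspector facade in sqlalchemy.engine.reflection and a placeholder Firebird DSN (the connection string, database file, and table name below are illustrative, not part of the patch):

    # Sketch: the Inspector dispatches to the dialect's get_columns /
    # get_primary_keys / get_foreign_keys added in this commit, so results
    # can be read without building a Table object first.
    from sqlalchemy import create_engine
    from sqlalchemy.engine import reflection

    engine = create_engine('firebird://sysdba:masterkey@localhost/employee.fdb')
    insp = reflection.Inspector.from_engine(engine)

    print insp.get_columns('employee')        # list of column dicts
    print insp.get_primary_keys('employee')   # list of PK column names
    print insp.get_foreign_keys('employee')   # list of FK constraint dicts

The old monolithic reflecttable() remains, but now composes these same getters, so Table(..., autoload=True) and the Inspector share one code path per aspect.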