git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
r/m information_schema from pg
author Jonathan Ellis <jbellis@gmail.com>
Sun, 29 Jul 2007 04:21:09 +0000 (04:21 +0000)
committer Jonathan Ellis <jbellis@gmail.com>
Sun, 29 Jul 2007 04:21:09 +0000 (04:21 +0000)
lib/sqlalchemy/databases/postgres.py

index 697ca8d15a5796a08aee0d38a6501f4601c39354..548494ff2cee6ac0681e2e43a454cc5022a5e9ab 100644
@@ -9,7 +9,6 @@ import re, random, warnings, operator
 from sqlalchemy import sql, schema, ansisql, exceptions
 from sqlalchemy.engine import base, default
 import sqlalchemy.types as sqltypes
-from sqlalchemy.databases import information_schema as ischema
 from decimal import Decimal
 
 try:
@@ -210,11 +209,10 @@ class PGExecutionContext(default.DefaultExecutionContext):
         super(PGExecutionContext, self).post_exec()
         
 class PGDialect(ansisql.ANSIDialect):
-    def __init__(self, use_oids=False, use_information_schema=False, server_side_cursors=False, **kwargs):
+    def __init__(self, use_oids=False, server_side_cursors=False, **kwargs):
         ansisql.ANSIDialect.__init__(self, default_paramstyle='pyformat', **kwargs)
         self.use_oids = use_oids
         self.server_side_cursors = server_side_cursors
-        self.use_information_schema = use_information_schema
         self.paramstyle = 'pyformat'
         
     def dbapi(cls):
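
A minimal sketch of the constructor change above, using hypothetical caller code
(normally the dialect is constructed for you by create_engine(), which forwards
dialect keyword arguments such as use_oids):

    from sqlalchemy.databases import postgres

    # still accepted after this commit
    dialect = postgres.PGDialect(use_oids=False, server_side_cursors=False)

    # use_information_schema is no longer part of the signature; the
    # information_schema-based reflection path it used to select is gone
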
@@ -336,173 +334,176 @@ class PGDialect(ansisql.ANSIDialect):
             return False
         
     def table_names(self, connection, schema):
-        return ischema.table_names(connection, schema)
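+        # Ask pg_catalog directly: relkind = 'r' limits the result to ordinary
+        # tables, and the subselect matches the relation's namespace (schema).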
+        s = """
+        SELECT relname 
+        FROM pg_class c
+        WHERE relkind = 'r'
+          AND '%(schema)s' = (select nspname from pg_namespace n where n.oid = c.relnamespace)
+        """ % locals()
+        return [row[0] for row in connection.execute(s)]
 
     def reflecttable(self, connection, table, include_columns):
-        if self.use_information_schema:
-            ischema.reflecttable(connection, table, include_columns, ischema_names)
+        preparer = self.identifier_preparer
+        if table.schema is not None:
+            schema_where_clause = "n.nspname = :schema"
         else:
-            preparer = self.identifier_preparer
-            if table.schema is not None:
-                schema_where_clause = "n.nspname = :schema"
-            else:
-                schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)"
-
-            ## information schema in pg suffers from too many permissions' restrictions
-            ## let us find out at the pg way what is needed...
-
-            SQL_COLS = """
-                SELECT a.attname,
-                  pg_catalog.format_type(a.atttypid, a.atttypmod),
-                  (SELECT substring(d.adsrc for 128) FROM pg_catalog.pg_attrdef d
-                   WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum AND a.atthasdef)
-                  AS DEFAULT,
-                  a.attnotnull, a.attnum, a.attrelid as table_oid
-                FROM pg_catalog.pg_attribute a
-                WHERE a.attrelid = (
-                    SELECT c.oid
-                    FROM pg_catalog.pg_class c
-                         LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
-                         WHERE (%s)
-                         AND c.relname = :table_name AND c.relkind in ('r','v')
-                ) AND a.attnum > 0 AND NOT a.attisdropped
-                ORDER BY a.attnum
-            """ % schema_where_clause
-
-            s = sql.text(SQL_COLS, bindparams=[sql.bindparam('table_name', type_=sqltypes.Unicode), sql.bindparam('schema', type_=sqltypes.Unicode)], typemap={'attname':sqltypes.Unicode})
-            c = connection.execute(s, table_name=table.name,
-                                      schema=table.schema)
-            rows = c.fetchall()
-
-            if not rows:
-                raise exceptions.NoSuchTableError(table.name)
-
-            domains = self._load_domains(connection)
+            schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)"
+
+        ## The information_schema views in PG are hampered by permissions
+        ## restrictions, so query the pg_catalog tables directly instead.
+
+        SQL_COLS = """
+            SELECT a.attname,
+              pg_catalog.format_type(a.atttypid, a.atttypmod),
+              (SELECT substring(d.adsrc for 128) FROM pg_catalog.pg_attrdef d
+               WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum AND a.atthasdef)
+              AS DEFAULT,
+              a.attnotnull, a.attnum, a.attrelid as table_oid
+            FROM pg_catalog.pg_attribute a
+            WHERE a.attrelid = (
+                SELECT c.oid
+                FROM pg_catalog.pg_class c
+                     LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
+                     WHERE (%s)
+                     AND c.relname = :table_name AND c.relkind in ('r','v')
+            ) AND a.attnum > 0 AND NOT a.attisdropped
+            ORDER BY a.attnum
+        """ % schema_where_clause
+
+        s = sql.text(SQL_COLS, bindparams=[sql.bindparam('table_name', type_=sqltypes.Unicode), sql.bindparam('schema', type_=sqltypes.Unicode)], typemap={'attname':sqltypes.Unicode})
+        c = connection.execute(s, table_name=table.name,
+                                  schema=table.schema)
+        rows = c.fetchall()
+
+        if not rows:
+            raise exceptions.NoSuchTableError(table.name)
+
+        domains = self._load_domains(connection)
+        
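+        # each row is (attname, formatted type, default expr, notnull, attnum, table oid)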
+        for name, format_type, default, notnull, attnum, table_oid in rows:
+            if include_columns and name not in include_columns:
+                continue
             
-            for name, format_type, default, notnull, attnum, table_oid in rows:
-                if include_columns and name not in include_columns:
-                    continue
-                
-                ## strip (30) from character varying(30)
-                attype = re.search('([^\([]+)', format_type).group(1)
-                nullable = not notnull
-                is_array = format_type.endswith('[]')
-
-                try:
-                    charlen = re.search('\(([\d,]+)\)', format_type).group(1)
-                except:
-                    charlen = False
-
-                numericprec = False
-                numericscale = False
-                if attype == 'numeric':
-                    if charlen is False:
-                        numericprec, numericscale = (None, None)
-                    else:
-                        numericprec, numericscale = charlen.split(',')
-                    charlen = False
-                if attype == 'double precision':
-                    numericprec, numericscale = (53, False)
-                    charlen = False
-                if attype == 'integer':
-                    numericprec, numericscale = (32, 0)
-                    charlen = False
-
-                args = []
-                for a in (charlen, numericprec, numericscale):
-                    if a is None:
-                        args.append(None)
-                    elif a is not False:
-                        args.append(int(a))
-
-                kwargs = {}
-                if attype == 'timestamp with time zone':
-                    kwargs['timezone'] = True
-                elif attype == 'timestamp without time zone':
-                    kwargs['timezone'] = False
-
-                if attype in ischema_names:
-                    coltype = ischema_names[attype]
-                else:
-                    if attype in domains:
-                        domain = domains[attype]
-                        if domain['attype'] in ischema_names:
-                            # A table can't override whether the domain is nullable.
-                            nullable = domain['nullable']
-
-                            if domain['default'] and not default:
-                                # It can, however, override the default value, but can't set it to null.
-                                default = domain['default']
-                            coltype = ischema_names[domain['attype']]
-                    else:
-                        coltype=None
-
-                if coltype:
-                    coltype = coltype(*args, **kwargs)
-                    if is_array:
-                        coltype = PGArray(coltype)
+            ## strip (30) from character varying(30)
+            attype = re.search('([^\([]+)', format_type).group(1)
+            nullable = not notnull
+            is_array = format_type.endswith('[]')
+
+            try:
+                # pull the length/precision digits, e.g. '30' from 'character varying(30)'
+                charlen = re.search(r'\(([\d,]+)\)', format_type).group(1)
+            except AttributeError:
+                # re.search() found no parenthesized length in the type string
+                charlen = False
+
+            numericprec = False
+            numericscale = False
+            if attype == 'numeric':
+                if charlen is False:
+                    numericprec, numericscale = (None, None)
                 else:
-                    warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (attype, name)))
-                    coltype = sqltypes.NULLTYPE
-
-                colargs= []
-                if default is not None:
-                    match = re.search(r"""(nextval\(')([^']+)('.*$)""", default)
-                    if match is not None:
-                        # the default is related to a Sequence
-                        sch = table.schema
-                        if '.' not in match.group(2) and sch is not None:
-                            default = match.group(1) + sch + '.' + match.group(2) + match.group(3)
-                    colargs.append(schema.PassiveDefault(sql.text(default)))
-                table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs))
-
-
-            # Primary keys
-            PK_SQL = """
-              SELECT attname FROM pg_attribute
-              WHERE attrelid = (
-                 SELECT indexrelid FROM pg_index i
-                 WHERE i.indrelid = :table
-                 AND i.indisprimary = 't')
-              ORDER BY attnum
-            """
-            t = sql.text(PK_SQL, typemap={'attname':sqltypes.Unicode})
-            c = connection.execute(t, table=table_oid)
-            for row in c.fetchall():
-                pk = row[0]
-                table.primary_key.add(table.c[pk])
-
-            # Foreign keys
-            FK_SQL = """
-              SELECT conname, pg_catalog.pg_get_constraintdef(oid, true) as condef
-              FROM  pg_catalog.pg_constraint r
-              WHERE r.conrelid = :table AND r.contype = 'f'
-              ORDER BY 1
-            """
-
-            t = sql.text(FK_SQL, typemap={'conname':sqltypes.Unicode, 'condef':sqltypes.Unicode})
-            c = connection.execute(t, table=table_oid)
-            for conname, condef in c.fetchall():
-                m = re.search('FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)', condef).groups()
-                (constrained_columns, referred_schema, referred_table, referred_columns) = m
-                constrained_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s*', constrained_columns)]
-                if referred_schema:
-                    referred_schema = preparer._unquote_identifier(referred_schema)
-                referred_table = preparer._unquote_identifier(referred_table)
-                referred_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s', referred_columns)]
-
-                refspec = []
-                if referred_schema is not None:
-                    schema.Table(referred_table, table.metadata, autoload=True, schema=referred_schema,
-                                autoload_with=connection)
-                    for column in referred_columns:
-                        refspec.append(".".join([referred_schema, referred_table, column]))
+                    numericprec, numericscale = charlen.split(',')
+                charlen = False
+            if attype == 'double precision':
+                numericprec, numericscale = (53, False)
+                charlen = False
+            if attype == 'integer':
+                numericprec, numericscale = (32, 0)
+                charlen = False
+
+            args = []
+            for a in (charlen, numericprec, numericscale):
+                if a is None:
+                    args.append(None)
+                elif a is not False:
+                    args.append(int(a))
+
+            kwargs = {}
+            if attype == 'timestamp with time zone':
+                kwargs['timezone'] = True
+            elif attype == 'timestamp without time zone':
+                kwargs['timezone'] = False
+
+            if attype in ischema_names:
+                coltype = ischema_names[attype]
+            else:
+                if attype in domains:
+                    domain = domains[attype]
+                    if domain['attype'] in ischema_names:
+                        # A table can't override whether the domain is nullable.
+                        nullable = domain['nullable']
+
+                        if domain['default'] and not default:
+                            # It can, however, override the default value, but can't set it to null.
+                            default = domain['default']
+                        coltype = ischema_names[domain['attype']]
                 else:
-                    schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection)
-                    for column in referred_columns:
-                        refspec.append(".".join([referred_table, column]))
+                    coltype = None
+
+            if coltype:
+                coltype = coltype(*args, **kwargs)
+                if is_array:
+                    coltype = PGArray(coltype)
+            else:
+                warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (attype, name)))
+                coltype = sqltypes.NULLTYPE
+
+            colargs = []
+            if default is not None:
+                match = re.search(r"""(nextval\(')([^']+)('.*$)""", default)
+                if match is not None:
+                    # the default is related to a Sequence
+                    sch = table.schema
+                    if '.' not in match.group(2) and sch is not None:
+                        default = match.group(1) + sch + '.' + match.group(2) + match.group(3)
+                colargs.append(schema.PassiveDefault(sql.text(default)))
+            table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs))
+
+
+        # Primary keys
+        PK_SQL = """
+          SELECT attname FROM pg_attribute
+          WHERE attrelid = (
+             SELECT indexrelid FROM pg_index i
+             WHERE i.indrelid = :table
+             AND i.indisprimary = 't')
+          ORDER BY attnum
+        """
+        t = sql.text(PK_SQL, typemap={'attname':sqltypes.Unicode})
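+        # table_oid was captured from the column query above; every row of
+        # that result carries the same oid for this table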
+        c = connection.execute(t, table=table_oid)
+        for row in c.fetchall():
+            pk = row[0]
+            table.primary_key.add(table.c[pk])
+
+        # Foreign keys
+        FK_SQL = """
+          SELECT conname, pg_catalog.pg_get_constraintdef(oid, true) as condef
+          FROM  pg_catalog.pg_constraint r
+          WHERE r.conrelid = :table AND r.contype = 'f'
+          ORDER BY 1
+        """
+
+        t = sql.text(FK_SQL, typemap={'conname':sqltypes.Unicode, 'condef':sqltypes.Unicode})
+        c = connection.execute(t, table=table_oid)
+        for conname, condef in c.fetchall():
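+            # condef has the form: FOREIGN KEY (cols) REFERENCES [schema.]table(refcols)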
+            m = re.search('FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)', condef).groups()
+            (constrained_columns, referred_schema, referred_table, referred_columns) = m
+            constrained_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s*', constrained_columns)]
+            if referred_schema:
+                referred_schema = preparer._unquote_identifier(referred_schema)
+            referred_table = preparer._unquote_identifier(referred_table)
+            referred_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s*', referred_columns)]
+
+            refspec = []
+            if referred_schema is not None:
+                schema.Table(referred_table, table.metadata, autoload=True, schema=referred_schema,
+                            autoload_with=connection)
+                for column in referred_columns:
+                    refspec.append(".".join([referred_schema, referred_table, column]))
+            else:
+                schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection)
+                for column in referred_columns:
+                    refspec.append(".".join([referred_table, column]))
 
-                table.append_constraint(schema.ForeignKeyConstraint(constrained_columns, refspec, conname))
+            table.append_constraint(schema.ForeignKeyConstraint(constrained_columns, refspec, conname))
                 
     def _load_domains(self, connection):
         ## Load data types for domains:
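
After this change every PostgreSQL table reflection goes through the pg_catalog
queries shown above; no information_schema code path is left to select. A short
usage sketch (hypothetical connection URL; create_engine, MetaData, Table and
autoload_with are the API names used elsewhere in this diff and in SQLAlchemy of
that period):

    from sqlalchemy import MetaData, Table, create_engine

    engine = create_engine('postgres://scott:tiger@localhost/test')
    meta = MetaData()

    # reflecttable() above runs under the hood, populating columns,
    # primary key and foreign keys from pg_catalog
    users = Table('users', meta, autoload=True, autoload_with=engine)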