raise AssertionError("Could not determine version from string '%s'" % v)
return tuple([int(x) for x in m.group(1, 2, 3)])
- def reflecttable(self, connection, table, include_columns):
- preparer = self.identifier_preparer
- if table.schema is not None:
+ def _prepare_info_cache(self, info_cache, tablename, schemaname):
+ """Add index for schemaname.table_name if it does not exist.
+
+ This is done so that certain keys can be assumed to be present.
+
+ """
+ # First, make sure it has the keys we expect.
+ if info_cache is None:
+ info_cache = dict(tables={})
+ elif 'tables' not in info_cache:
+ info_cache['tables'] = {}
+ # Add the table index if needed.
+ table_index = "%s.%s" % (schemaname, tablename)
+ if table_index not in info_cache['tables']:
+ info_cache['tables'][table_index] = {}
+ return info_cache
+
+ def _get_table_oid(self, connection, tablename, schemaname=None):
+ """Fetch the oid for schemaname.tablename.
+
+ Several reflection methods require the table oid. The idea for using
+ this method is that it can be fetched one time and cached for
+ subsequent calls.
+
+ """
+ if schemaname is not None:
schema_where_clause = "n.nspname = :schema"
- schemaname = table.schema
-
- # Py2K
- if isinstance(schemaname, str):
- schemaname = schemaname.decode(self.encoding)
- # end Py2K
else:
schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)"
- schemaname = None
-
+ query = """
+ SELECT c.oid
+ FROM pg_catalog.pg_class c
+ LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
+ WHERE (%s)
+ AND c.relname = :table_name AND c.relkind in ('r','v')
+ """ % schema_where_clause
+ s = sql.text(query, bindparams=[
+ sql.bindparam('table_name', type_=sqltypes.Unicode),
+ sql.bindparam('schema', type_=sqltypes.Unicode)
+ ],
+ typemap={'oid':sqltypes.Integer}
+ )
+ c = connection.execute(s, table_name=tablename, schema=schemaname)
+ table_oid = c.scalar()
+ if table_oid is None:
+ raise exc.NoSuchTableError(table_name)
+ return table_oid
+
+ def get_columns(self, connection, tablename, schemaname=None,
+ info_cache=None):
+ info_cache = self._prepare_info_cache(info_cache, tablename, schemaname)
+ # looked for cached table oid
+ table_index = "%s.%s" % (schemaname, tablename)
+ table_oid = info_cache['tables'][table_index].get('table_oid')
+ if table_oid is None:
+ table_oid = self._get_table_oid(connection, tablename, schemaname)
+ # cache it
+ info_cache['tables'][table_index]['table_oid'] = table_oid
SQL_COLS = """
SELECT a.attname,
pg_catalog.format_type(a.atttypid, a.atttypmod),
AS DEFAULT,
a.attnotnull, a.attnum, a.attrelid as table_oid
FROM pg_catalog.pg_attribute a
- WHERE a.attrelid = (
- SELECT c.oid
- FROM pg_catalog.pg_class c
- LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
- WHERE (%s)
- AND c.relname = :table_name AND c.relkind in ('r','v')
- ) AND a.attnum > 0 AND NOT a.attisdropped
+ WHERE a.attrelid = :table_oid
+ AND a.attnum > 0 AND NOT a.attisdropped
ORDER BY a.attnum
- """ % schema_where_clause
-
- s = sql.text(SQL_COLS, bindparams=[sql.bindparam('table_name', type_=sqltypes.Unicode), sql.bindparam('schema', type_=sqltypes.Unicode)], typemap={'attname':sqltypes.Unicode, 'default':sqltypes.Unicode})
- tablename = table.name
- # Py2K
- if isinstance(tablename, str):
- tablename = tablename.decode(self.encoding)
- # end Py2K
- c = connection.execute(s, table_name=tablename, schema=schemaname)
+ """
+ s = sql.text(SQL_COLS,
+ bindparams=[sql.bindparam('table_oid', type_=sqltypes.Integer)],
+ typemap={'attname':sqltypes.Unicode, 'default':sqltypes.Unicode}
+ )
+ c = connection.execute(s, table_oid=table_oid)
rows = c.fetchall()
-
- if not rows:
- raise exc.NoSuchTableError(table.name)
-
domains = self._load_domains(connection)
-
+ # format columns
+ columns = []
for name, format_type, default, notnull, attnum, table_oid in rows:
- if include_columns and name not in include_columns:
- continue
-
## strip (30) from character varying(30)
attype = re.search('([^\([]+)', format_type).group(1)
nullable = not notnull
is_array = format_type.endswith('[]')
-
try:
charlen = re.search('\(([\d,]+)\)', format_type).group(1)
except:
charlen = False
-
numericprec = False
numericscale = False
if attype == 'numeric':
if attype == 'integer':
numericprec, numericscale = (32, 0)
charlen = False
-
args = []
for a in (charlen, numericprec, numericscale):
if a is None:
args.append(None)
elif a is not False:
args.append(int(a))
-
kwargs = {}
if attype == 'timestamp with time zone':
kwargs['timezone'] = True
elif attype == 'timestamp without time zone':
kwargs['timezone'] = False
-
if attype in self.ischema_names:
coltype = self.ischema_names[attype]
else:
if domain['attype'] in self.ischema_names:
# A table can't override whether the domain is nullable.
nullable = domain['nullable']
-
if domain['default'] and not default:
# It can, however, override the default value, but can't set it to null.
default = domain['default']
coltype = self.ischema_names[domain['attype']]
else:
coltype = None
-
if coltype:
coltype = coltype(*args, **kwargs)
if is_array:
util.warn("Did not recognize type '%s' of column '%s'" %
(attype, name))
coltype = sqltypes.NULLTYPE
-
colargs = []
+ column_info = dict(name=name, type=coltype, nullable=nullable,
+ default=default, colargs=colargs)
+ columns.append(column_info)
+ return columns
+
+ def reflecttable(self, connection, table, include_columns):
+ preparer = self.identifier_preparer
+ schemaname = table.schema
+ tablename = table.name
+ # Py2K
+ if isinstance(schemaname, str):
+ schemaname = schemaname.decode(self.encoding)
+ if isinstance(tablename, str):
+ tablename = tablename.decode(self.encoding)
+ # end Py2K
+ info_cache = {}
+ for col_d in self.get_columns(connection, tablename, schemaname,
+ info_cache):
+ name = col_d['name']
+ coltype = col_d['type']
+ nullable = col_d['nullable']
+ default = col_d['default']
+ colargs = col_d['colargs']
+ if include_columns and name not in include_columns:
+ continue
if default is not None:
match = re.search(r"""(nextval\(')([^']+)('.*$)""", default)
if match is not None:
# the default is related to a Sequence
- sch = table.schema
+ sch = schemaname
if '.' not in match.group(2) and sch is not None:
# unconditionally quote the schema name. this could
# later be enhanced to obey quoting rules / "quote schema"
default = match.group(1) + ('"%s"' % sch) + '.' + match.group(2) + match.group(3)
colargs.append(schema.DefaultClause(sql.text(default)))
table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs))
-
-
+ # Now we have the table oid cached.
+ table_oid = info_cache['tables']["%s.%s" % (schemaname, tablename)]['table_oid']
# Primary keys
PK_SQL = """
SELECT attname FROM pg_attribute