From 464a706e2a95ff74ed63d07580361bea86c26038 Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Mon, 2 Feb 2009 06:25:14 +0000 Subject: [PATCH] factored out column reflection from reflecttable --- lib/sqlalchemy/dialects/postgres/base.py | 139 +++++++++++++++-------- 1 file changed, 93 insertions(+), 46 deletions(-) diff --git a/lib/sqlalchemy/dialects/postgres/base.py b/lib/sqlalchemy/dialects/postgres/base.py index bab875d686..2bdab3914f 100644 --- a/lib/sqlalchemy/dialects/postgres/base.py +++ b/lib/sqlalchemy/dialects/postgres/base.py @@ -499,20 +499,64 @@ class PGDialect(default.DefaultDialect): raise AssertionError("Could not determine version from string '%s'" % v) return tuple([int(x) for x in m.group(1, 2, 3)]) - def reflecttable(self, connection, table, include_columns): - preparer = self.identifier_preparer - if table.schema is not None: + def _prepare_info_cache(self, info_cache, tablename, schemaname): + """Add index for schemaname.table_name if it does not exist. + + This is done so that certain keys can be assumed to be present. + + """ + # First, make sure it has the keys we expect. + if info_cache is None: + info_cache = dict(tables={}) + elif 'tables' not in info_cache: + info_cache['tables'] = {} + # Add the table index if needed. + table_index = "%s.%s" % (schemaname, tablename) + if table_index not in info_cache['tables']: + info_cache['tables'][table_index] = {} + return info_cache + + def _get_table_oid(self, connection, tablename, schemaname=None): + """Fetch the oid for schemaname.tablename. + + Several reflection methods require the table oid. The idea for using + this method is that it can be fetched one time and cached for + subsequent calls. + + """ + if schemaname is not None: schema_where_clause = "n.nspname = :schema" - schemaname = table.schema - - # Py2K - if isinstance(schemaname, str): - schemaname = schemaname.decode(self.encoding) - # end Py2K else: schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)" - schemaname = None - + query = """ + SELECT c.oid + FROM pg_catalog.pg_class c + LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace + WHERE (%s) + AND c.relname = :table_name AND c.relkind in ('r','v') + """ % schema_where_clause + s = sql.text(query, bindparams=[ + sql.bindparam('table_name', type_=sqltypes.Unicode), + sql.bindparam('schema', type_=sqltypes.Unicode) + ], + typemap={'oid':sqltypes.Integer} + ) + c = connection.execute(s, table_name=tablename, schema=schemaname) + table_oid = c.scalar() + if table_oid is None: + raise exc.NoSuchTableError(table_name) + return table_oid + + def get_columns(self, connection, tablename, schemaname=None, + info_cache=None): + info_cache = self._prepare_info_cache(info_cache, tablename, schemaname) + # looked for cached table oid + table_index = "%s.%s" % (schemaname, tablename) + table_oid = info_cache['tables'][table_index].get('table_oid') + if table_oid is None: + table_oid = self._get_table_oid(connection, tablename, schemaname) + # cache it + info_cache['tables'][table_index]['table_oid'] = table_oid SQL_COLS = """ SELECT a.attname, pg_catalog.format_type(a.atttypid, a.atttypmod), @@ -521,44 +565,28 @@ class PGDialect(default.DefaultDialect): AS DEFAULT, a.attnotnull, a.attnum, a.attrelid as table_oid FROM pg_catalog.pg_attribute a - WHERE a.attrelid = ( - SELECT c.oid - FROM pg_catalog.pg_class c - LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace - WHERE (%s) - AND c.relname = :table_name AND c.relkind in ('r','v') - ) AND a.attnum > 0 AND NOT a.attisdropped + WHERE a.attrelid = :table_oid + AND a.attnum > 0 AND NOT a.attisdropped ORDER BY a.attnum - """ % schema_where_clause - - s = sql.text(SQL_COLS, bindparams=[sql.bindparam('table_name', type_=sqltypes.Unicode), sql.bindparam('schema', type_=sqltypes.Unicode)], typemap={'attname':sqltypes.Unicode, 'default':sqltypes.Unicode}) - tablename = table.name - # Py2K - if isinstance(tablename, str): - tablename = tablename.decode(self.encoding) - # end Py2K - c = connection.execute(s, table_name=tablename, schema=schemaname) + """ + s = sql.text(SQL_COLS, + bindparams=[sql.bindparam('table_oid', type_=sqltypes.Integer)], + typemap={'attname':sqltypes.Unicode, 'default':sqltypes.Unicode} + ) + c = connection.execute(s, table_oid=table_oid) rows = c.fetchall() - - if not rows: - raise exc.NoSuchTableError(table.name) - domains = self._load_domains(connection) - + # format columns + columns = [] for name, format_type, default, notnull, attnum, table_oid in rows: - if include_columns and name not in include_columns: - continue - ## strip (30) from character varying(30) attype = re.search('([^\([]+)', format_type).group(1) nullable = not notnull is_array = format_type.endswith('[]') - try: charlen = re.search('\(([\d,]+)\)', format_type).group(1) except: charlen = False - numericprec = False numericscale = False if attype == 'numeric': @@ -573,20 +601,17 @@ class PGDialect(default.DefaultDialect): if attype == 'integer': numericprec, numericscale = (32, 0) charlen = False - args = [] for a in (charlen, numericprec, numericscale): if a is None: args.append(None) elif a is not False: args.append(int(a)) - kwargs = {} if attype == 'timestamp with time zone': kwargs['timezone'] = True elif attype == 'timestamp without time zone': kwargs['timezone'] = False - if attype in self.ischema_names: coltype = self.ischema_names[attype] else: @@ -595,14 +620,12 @@ class PGDialect(default.DefaultDialect): if domain['attype'] in self.ischema_names: # A table can't override whether the domain is nullable. nullable = domain['nullable'] - if domain['default'] and not default: # It can, however, override the default value, but can't set it to null. default = domain['default'] coltype = self.ischema_names[domain['attype']] else: coltype = None - if coltype: coltype = coltype(*args, **kwargs) if is_array: @@ -611,21 +634,45 @@ class PGDialect(default.DefaultDialect): util.warn("Did not recognize type '%s' of column '%s'" % (attype, name)) coltype = sqltypes.NULLTYPE - colargs = [] + column_info = dict(name=name, type=coltype, nullable=nullable, + default=default, colargs=colargs) + columns.append(column_info) + return columns + + def reflecttable(self, connection, table, include_columns): + preparer = self.identifier_preparer + schemaname = table.schema + tablename = table.name + # Py2K + if isinstance(schemaname, str): + schemaname = schemaname.decode(self.encoding) + if isinstance(tablename, str): + tablename = tablename.decode(self.encoding) + # end Py2K + info_cache = {} + for col_d in self.get_columns(connection, tablename, schemaname, + info_cache): + name = col_d['name'] + coltype = col_d['type'] + nullable = col_d['nullable'] + default = col_d['default'] + colargs = col_d['colargs'] + if include_columns and name not in include_columns: + continue if default is not None: match = re.search(r"""(nextval\(')([^']+)('.*$)""", default) if match is not None: # the default is related to a Sequence - sch = table.schema + sch = schemaname if '.' not in match.group(2) and sch is not None: # unconditionally quote the schema name. this could # later be enhanced to obey quoting rules / "quote schema" default = match.group(1) + ('"%s"' % sch) + '.' + match.group(2) + match.group(3) colargs.append(schema.DefaultClause(sql.text(default))) table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs)) - - + # Now we have the table oid cached. + table_oid = info_cache['tables']["%s.%s" % (schemaname, tablename)]['table_oid'] # Primary keys PK_SQL = """ SELECT attname FROM pg_attribute -- 2.47.3