value = value[1:-1].replace('""','"')
return value
+class PGInfoCache(default.DefaultInfoCache):
+
+ def __init__(self):
+
+ default.DefaultInfoCache.__init__(self)
+
+ def getTableOID(self, tablename, schemaname=None):
+ table = self.getTable(tablename, schemaname)
+ if table:
+ return table.get('oid')
+
+ def setTableOID(self, oid, tablename, schemaname=None):
+ table = self.getTable(tablename, schemaname, create=True)
+ table['oid'] = oid
+
class PGDialect(default.DefaultDialect):
name = 'postgres'
supports_alter = True
raise AssertionError("Could not determine version from string '%s'" % v)
return tuple([int(x) for x in m.group(1, 2, 3)])
- def _prepare_info_cache(self, info_cache, tablename, schemaname):
- """Add index for schemaname.table_name if it does not exist.
-
- This is done so that certain keys can be assumed to be present.
-
- """
- # First, make sure it has the keys we expect.
- if info_cache is None:
- info_cache = dict(tables={})
- elif 'tables' not in info_cache:
- info_cache['tables'] = {}
- # Add the table index if needed.
- table_index = "%s.%s" % (schemaname, tablename)
- if table_index not in info_cache['tables']:
- info_cache['tables'][table_index] = {}
- return info_cache
-
def _get_table_oid(self, connection, tablename, schemaname=None,
info_cache=None):
"""Fetch the oid for schemaname.tablename.
subsequent calls.
"""
- info_cache = self._prepare_info_cache(info_cache, tablename, schemaname)
- # If it's in info_cache, juse use that.
- table_index = "%s.%s" % (schemaname, tablename)
- table_oid = info_cache['tables'][table_index].get('table_oid')
+ table_oid = None
+ if info_cache:
+ table_oid = info_cache.getTableOID(tablename, schemaname)
if table_oid:
return table_oid
if schemaname is not None:
if table_oid is None:
raise exc.NoSuchTableError(table_name)
# cache it
- info_cache['tables'][table_index]['table_oid'] = table_oid
+ if info_cache:
+ info_cache.setTableOID(table_oid, tablename, schemaname)
return table_oid
def get_schema_names(self, connection, info_cache=None):
+ if info_cache:
+ schema_names = info_cache.getSchemaNames()
+ if schema_names is not None:
+ return schema_names
s = """
SELECT nspname
FROM pg_namespace
"""
rp = connection.execute(s)
# what about system tables?
- return [row[0].decode(self.encoding) for row in rp \
- if not row[0].startswith('pg_')]
+ schema_names = [row[0].decode(self.encoding) for row in rp \
+ if not row[0].startswith('pg_')]
+ if info_cache:
+ info_cache.addAllSchemas(schema_names)
+ return schema_names
def get_table_names(self, connection, schemaname=None, info_cache=None):
if schemaname is not None:
current_schema = schemaname
else:
current_schema = self.get_default_schema_name(connection)
- return self.table_names(connection, current_schema)
+ if info_cache:
+ table_names = info_cache.getTableNames(current_schema)
+ if table_names is not None:
+ return table_names
+ table_names = self.table_names(connection, current_schema)
+ if info_cache:
+ info_cache.addAllTables(table_names, current_schema)
+ return table_names
def get_view_names(self, connection, schemaname=None, info_cache=None):
if schemaname is not None:
current_schema = schemaname
else:
current_schema = self.get_default_schema_name(connection)
+ if info_cache:
+ view_names = info_cache.getViewNames(current_schema)
+ if view_names is not None:
+ return view_names
s = """
SELECT relname
FROM pg_class c
WHERE relkind = 'v'
AND '%(schema)s' = (select nspname from pg_namespace n where n.oid = c.relnamespace)
""" % dict(schema=current_schema)
- return [row[0].decode(self.encoding) for row in connection.execute(s)]
+ view_names = [row[0].decode(self.encoding) for row in connection.execute(s)]
+ if info_cache:
+ info_cache.addAllViews(view_names, schemaname)
+ return view_names
def get_view_definition(self, connection, viewname, schemaname=None,
info_cache=None):
current_schema = schemaname
else:
current_schema = self.get_default_schema_name(connection)
+ if info_cache:
+ view = info_cache.getView(viewname, current_schema)
+ if view.get('definition'):
+ return view['definition']
s = """
SELECT definition FROM pg_views
WHERE schemaname = :schemaname
rp = connection.execute(sql.text(s),
viewname=viewname, schemaname=current_schema)
if rp:
- return rp.scalar().decode(self.encoding)
+ view_def = rp.scalar().decode(self.encoding)
+ if info_cache:
+ view = info_cache.getView(viewname, current_schema,
+ create=True)
+ view['definition'] = view_def
+ return view_def
def get_columns(self, connection, tablename, schemaname=None,
info_cache=None):
- info_cache = self._prepare_info_cache(info_cache, tablename, schemaname)
+ if info_cache:
+ table_cache = info_cache.getTable(tablename, schemaname)
+ if table_cache and 'columns' in table_cache.keys():
+ return table_cache.get('columns')
table_oid = self._get_table_oid(connection, tablename, schemaname,
info_cache)
SQL_COLS = """
column_info = dict(name=name, type=coltype, nullable=nullable,
default=default, colargs=colargs)
columns.append(column_info)
+ if info_cache:
+ table_cache = info_cache.getTable(tablename, schemaname,
+ create=True)
+ table_cache['columns'] = columns
return columns
def get_primary_keys(self, connection, tablename, schemaname=None,
info_cache=None):
- info_cache = self._prepare_info_cache(info_cache, tablename, schemaname)
+ if info_cache:
+ table_cache = info_cache.getTable(tablename, schemaname)
+ if table_cache and 'primary_keys' in table_cache.keys():
+ return table_cache.get('primary_keys')
table_oid = self._get_table_oid(connection, tablename, schemaname,
info_cache)
PK_SQL = """
"""
t = sql.text(PK_SQL, typemap={'attname':sqltypes.Unicode})
c = connection.execute(t, table_oid=table_oid)
- return [r[0] for r in c.fetchall()]
- for row in c.fetchall():
- pk = row[0]
- if pk in table.c:
- col = table.c[pk]
- table.primary_key.add(col)
- if col.default is None:
- col.autoincrement = False
+ primary_keys = [r[0] for r in c.fetchall()]
+ if info_cache:
+ table_cache = info_cache.getTable(tablename, schemaname,
+ create=True)
+ table_cache['primary_keys'] = primary_keys
+ return primary_keys
def get_foreign_keys(self, connection, tablename, schemaname=None,
info_cache=None):
+ if info_cache:
+ table_cache = info_cache.getTable(tablename, schemaname)
+ if table_cache and 'foreign_keys' in table_cache.keys():
+ return table_cache.get('foreign_keys')
preparer = self.identifier_preparer
- info_cache = self._prepare_info_cache(info_cache, tablename, schemaname)
table_oid = self._get_table_oid(connection, tablename, schemaname,
info_cache)
FK_SQL = """
'referred_columns' : referred_columns
}
fkeys.append(fkey_d)
+ if info_cache:
+ table_cache = info_cache.getTable(tablename, schemaname,
+ create=True)
+ table_cache['foreign_keys'] = fkeys
return fkeys
def get_indexes(self, connection, tablename, schemaname, info_cache=None):
- info_cache = self._prepare_info_cache(info_cache, tablename, schemaname)
+ if info_cache:
+ table_cache = info_cache.getTable(tablename, schemaname)
+ if table_cache and 'indexes' in table_cache.keys():
+ return table_cache.get('indexes')
table_oid = self._get_table_oid(connection, tablename, schemaname,
info_cache)
IDX_SQL = """
index_d['name'] = idx_name
index_d['column_names'].append(col)
index_d['unique'] = unique
+ if info_cache:
+ table_cache = info_cache.getTable(tablename, schemaname,
+ create=True)
+ table_cache['indexes'] = indexes
return indexes
def reflecttable(self, connection, table, include_columns):
if isinstance(tablename, str):
tablename = tablename.decode(self.encoding)
# end Py2K
- info_cache = {}
+ info_cache = PGInfoCache()
for col_d in self.get_columns(connection, tablename, schemaname,
info_cache):
name = col_d['name']