From 5f196e087b531ef11e52edfbeff775728614eba6 Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Wed, 11 Feb 2009 04:44:43 +0000 Subject: [PATCH] added DefaultInfoCache --- lib/sqlalchemy/engine/default.py | 122 +++++++++++++++++++++++++++++++ 1 file changed, 122 insertions(+) diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index b719219a5d..5a0044d9f1 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -20,6 +20,128 @@ from sqlalchemy import exc, types as sqltypes AUTOCOMMIT_REGEXP = re.compile(r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)', re.I | re.UNICODE) +class DefaultInfoCache(object): + """Default implementation of InfoCache + + InfoCache provides a means for dialects to cache information obtained for + reflection and a convenient interface for setting and retrieving cached + data. + + """ + + def __init__(self): + self._cache = dict(schemas={}) + self.tables_are_complete = False + self.schemas_are_complete = False + + def clear(self): + """Clear the cache.""" + self._cache = dict(schemas={}) + + def getSchemas(self): + """Return the schemas dict.""" + return self._cache.get('schemas') + + def getSchemaNames(self, check_complete=True): + """Return cached schema names. + + By default, only return them if they're complete. + + """ + if check_complete and self.schemas_are_complete: + return self.getSchemas().keys() + elif not check_complete: + return self.getSchemas().keys() + else: + return None + + def getSchema(self, schemaname, create=False): + """Return cached schema and optionally create it if it does not exist. + + """ + schema = self._cache['schemas'].get(schemaname) + if schema is not None: + return schema + elif create: + return self.addSchema(schemaname) + return None + + def addSchema(self, schemaname): + self._cache['schemas'][schemaname] = dict(tables={}) + return self.getSchema(schemaname) + + def addAllSchemas(self, schemanames): + for schemaname in schemanames: + self.addSchema(schemaname) + self.schemas_are_complete = True + + def getTable(self, tablename, schemaname=None, create=False, + table_type='table'): + """Return cached table and optionally create it if it does not exist. + + + """ + cache = self._cache + schema = self.getSchema(schemaname, create=create) + if schema is None: + return None + if table_type == 'view': + table = schema['views'].get(tablename) + else: + table = schema['tables'].get(tablename) + if table is not None: + return table + elif create: + return self.addTable(tablename, schemaname, table_type=table_type) + return None + + def getTableNames(self, schemaname=None, check_complete=True, + table_type='table'): + """Return cached table names. + + By default, only return them if they're complete. + + """ + if table_type == 'view': + complete = self.views_are_complete + else: + complete = self.tables_are_complete + if check_complete and complete: + return self.getTables(schemaname, table_type=table_type).keys() + elif not check_complete: + return self.getTables(schemaname, table_type=table_type).keys() + else: + return None + + def addTable(self, tablename, schemaname=None, table_type='table'): + schema = self.getSchema(schemaname, create=True) + if table_type == 'table': + schema['tables'][tablename] = dict(columns={}) + else: + schema['views'][tablename] = dict(columns={}) + return self.getTable(tablename, schemaname, table_type=table_type) + + def addAllTables(self, tablenames, schemaname=None, table_type='table'): + for tablename in tablenames: + self.addTable(tablename, schemaname, table_type) + if table_type == 'view': + self.views_are_complete = True + else: + self.tables_are_complete = True + + def getView(self, viewname, schemaname=None, create=False): + return self.getTable(viewname, schemaname, create, 'view') + + def getViewNames(self, schemaname=None, check_complete=True): + return self.getTableNames(schemaname, check_complete, 'view') + + def addView(self, viewname, schemaname=None): + return self.addTable(viewname, schemaname, 'view') + + def addAllViews(self, viewnames, schemaname=None): + return self.addAllTables(viewnames, schemaname, 'view') + + class DefaultDialect(base.Dialect): """Default implementation of Dialect""" -- 2.47.3