]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
added DefaultInfoCache
authorRandall Smith <randall@tnr.cc>
Wed, 11 Feb 2009 04:44:43 +0000 (04:44 +0000)
committerRandall Smith <randall@tnr.cc>
Wed, 11 Feb 2009 04:44:43 +0000 (04:44 +0000)
lib/sqlalchemy/engine/default.py

index b719219a5df794c2c20b8af18d4c0bb5a047d7b7..5a0044d9f154583c2ec1288b649425f13ad2d4db 100644 (file)
@@ -20,6 +20,128 @@ from sqlalchemy import exc, types as sqltypes
 AUTOCOMMIT_REGEXP = re.compile(r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)',
                                re.I | re.UNICODE)
 
+class DefaultInfoCache(object):
+    """Default implementation of InfoCache
+
+    InfoCache provides a means for dialects to cache information obtained for
+    reflection and a convenient interface for setting and retrieving cached
+    data.
+
+    """
+    
+    def __init__(self):
+        self._cache = dict(schemas={})
+        self.tables_are_complete = False
+        self.schemas_are_complete = False
+
+    def clear(self):
+        """Clear the cache."""
+        self._cache = dict(schemas={})
+
+    def getSchemas(self):
+        """Return the schemas dict."""
+        return self._cache.get('schemas')
+
+    def getSchemaNames(self, check_complete=True):
+        """Return cached schema names.
+
+        By default, only return them if they're complete.
+
+        """
+        if check_complete and self.schemas_are_complete:
+            return self.getSchemas().keys()
+        elif not check_complete:
+            return self.getSchemas().keys()
+        else:
+            return None
+
+    def getSchema(self, schemaname, create=False):
+        """Return cached schema and optionally create it if it does not exist.
+
+        """
+        schema = self._cache['schemas'].get(schemaname)
+        if schema is not None:
+            return schema
+        elif create:
+            return self.addSchema(schemaname)
+        return None
+
+    def addSchema(self, schemaname):
+        self._cache['schemas'][schemaname] = dict(tables={})
+        return self.getSchema(schemaname)
+
+    def addAllSchemas(self, schemanames):
+        for schemaname in schemanames:
+            self.addSchema(schemaname)
+        self.schemas_are_complete = True
+
+    def getTable(self, tablename, schemaname=None, create=False,
+                                                        table_type='table'):
+        """Return cached table and optionally create it if it does not exist.
+
+
+        """
+        cache = self._cache
+        schema = self.getSchema(schemaname, create=create)
+        if schema is None:
+            return None
+        if table_type == 'view':
+            table = schema['views'].get(tablename)
+        else:
+            table = schema['tables'].get(tablename)
+        if table is not None:
+            return table
+        elif create:
+            return self.addTable(tablename, schemaname, table_type=table_type)
+        return None
+
+    def getTableNames(self, schemaname=None, check_complete=True,
+                                                        table_type='table'):
+        """Return cached table names.
+
+        By default, only return them if they're complete.
+
+        """
+        if table_type == 'view':
+            complete = self.views_are_complete
+        else:
+            complete = self.tables_are_complete
+        if check_complete and complete:
+            return self.getTables(schemaname, table_type=table_type).keys()
+        elif not check_complete:
+            return self.getTables(schemaname, table_type=table_type).keys()
+        else:
+            return None
+
+    def addTable(self, tablename, schemaname=None, table_type='table'):
+        schema = self.getSchema(schemaname, create=True)
+        if table_type == 'table':
+            schema['tables'][tablename] = dict(columns={})
+        else:
+            schema['views'][tablename] = dict(columns={})
+        return self.getTable(tablename, schemaname, table_type=table_type)
+
+    def addAllTables(self, tablenames, schemaname=None, table_type='table'):
+        for tablename in tablenames:
+            self.addTable(tablename, schemaname, table_type)
+        if table_type == 'view':
+            self.views_are_complete = True
+        else:
+            self.tables_are_complete = True
+            
+    def getView(self, viewname, schemaname=None, create=False):
+        return self.getTable(viewname, schemaname, create, 'view')
+
+    def getViewNames(self, schemaname=None, check_complete=True):
+        return self.getTableNames(schemaname, check_complete, 'view')
+
+    def addView(self, viewname, schemaname=None):
+        return self.addTable(viewname, schemaname, 'view')
+
+    def addAllViews(self, viewnames, schemaname=None):
+        return self.addAllTables(viewnames, schemaname, 'view')
+
+
 class DefaultDialect(base.Dialect):
     """Default implementation of Dialect"""