]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
moving to simpler cache technique
authorRandall Smith <randall@tnr.cc>
Sat, 28 Feb 2009 06:06:44 +0000 (06:06 +0000)
committerRandall Smith <randall@tnr.cc>
Sat, 28 Feb 2009 06:06:44 +0000 (06:06 +0000)
lib/sqlalchemy/dialects/postgres/base.py
lib/sqlalchemy/engine/reflection.py

index d031e30ae8ff02afdbb96f45a4460bc157f49c21..e9a1f09e335e0af9933762ff5cb2d2e1a4b26247 100644 (file)
@@ -397,21 +397,6 @@ class PGIdentifierPreparer(compiler.IdentifierPreparer):
             value = value[1:-1].replace('""','"')
         return value
 
-class PGInfoCache(reflection.DefaultInfoCache):
-    
-    def __init__(self):
-
-        reflection.DefaultInfoCache.__init__(self)
-
-    def get_table_oid(self, tablename, schemaname=None):
-        table = self.get_table(tablename, schemaname)
-        if table:
-            return table.get('oid')
-
-    def set_table_oid(self, oid, tablename, schemaname=None):
-        table = self.get_table(tablename, schemaname, create=True)
-        table['oid'] = oid
-
 class PGDialect(default.DefaultDialect):
     name = 'postgres'
     supports_alter = True
@@ -432,7 +417,6 @@ class PGDialect(default.DefaultDialect):
     type_compiler = PGTypeCompiler
     preparer = PGIdentifierPreparer
     defaultrunner = PGDefaultRunner
-    info_cache = PGInfoCache
 
 
     def do_begin_twophase(self, connection, xid):
@@ -515,8 +499,7 @@ class PGDialect(default.DefaultDialect):
             raise AssertionError("Could not determine version from string '%s'" % v)
         return tuple([int(x) for x in m.group(1, 2, 3)])
 
-    def _get_table_oid(self, connection, tablename, schemaname=None,
-                       info_cache=None):
+    def _get_table_oid(self, connection, tablename, schemaname=None):
         """Fetch the oid for schemaname.tablename.
 
         Several reflection methods require the table oid.  The idea for using
@@ -525,10 +508,6 @@ class PGDialect(default.DefaultDialect):
 
         """
         table_oid = None
-        if info_cache:
-            table_oid = info_cache.get_table_oid(tablename, schemaname)
-        if table_oid:
-            return table_oid
         if schemaname is not None:
             schema_where_clause = "n.nspname = :schema"
         else:
@@ -555,13 +534,10 @@ class PGDialect(default.DefaultDialect):
         table_oid = c.scalar()
         if table_oid is None:
             raise exc.NoSuchTableError(tablename)
-        # cache it
-        if info_cache:
-            info_cache.set_table_oid(table_oid, tablename, schemaname)
         return table_oid
 
-    @reflection.caches
-    def get_schema_names(self, connection, info_cache=None):
+    @reflection.cache
+    def get_schema_names(self, connection):
         s = """
         SELECT nspname
         FROM pg_namespace
@@ -573,8 +549,8 @@ class PGDialect(default.DefaultDialect):
                         if not row[0].startswith('pg_')]
         return schema_names
 
-    @reflection.caches
-    def get_table_names(self, connection, schemaname=None, info_cache=None):
+    @reflection.cache
+    def get_table_names(self, connection, schemaname=None):
         if schemaname is not None:
             current_schema = schemaname
         else:
@@ -582,8 +558,8 @@ class PGDialect(default.DefaultDialect):
         table_names = self.table_names(connection, current_schema)
         return table_names
 
-    @reflection.caches
-    def get_view_names(self, connection, schemaname=None, info_cache=None):
+    @reflection.cache
+    def get_view_names(self, connection, schemaname=None):
         if schemaname is not None:
             current_schema = schemaname
         else:
@@ -597,9 +573,8 @@ class PGDialect(default.DefaultDialect):
         view_names = [row[0].decode(self.encoding) for row in connection.execute(s)]
         return view_names
 
-    @reflection.caches
-    def get_view_definition(self, connection, viewname, schemaname=None,
-                                                            info_cache=None):
+    @reflection.cache
+    def get_view_definition(self, connection, viewname, schemaname=None):
         if schemaname is not None:
             current_schema = schemaname
         else:
@@ -615,12 +590,10 @@ class PGDialect(default.DefaultDialect):
             view_def = rp.scalar().decode(self.encoding)
             return view_def
 
-    @reflection.caches
-    def get_columns(self, connection, tablename, schemaname=None,
-                    info_cache=None):
+    @reflection.cache
+    def get_columns(self, connection, tablename, schemaname=None):
 
-        table_oid = self._get_table_oid(connection, tablename, schemaname,
-                                        info_cache)
+        table_oid = self._get_table_oid(connection, tablename, schemaname)
         SQL_COLS = """
             SELECT a.attname,
               pg_catalog.format_type(a.atttypid, a.atttypmod),
@@ -704,11 +677,9 @@ class PGDialect(default.DefaultDialect):
             columns.append(column_info)
         return columns
 
-    @reflection.caches
-    def get_primary_keys(self, connection, tablename, schemaname=None,
-                         info_cache=None):
-        table_oid = self._get_table_oid(connection, tablename, schemaname,
-                                        info_cache)
+    @reflection.cache
+    def get_primary_keys(self, connection, tablename, schemaname=None):
+        table_oid = self._get_table_oid(connection, tablename, schemaname)
         PK_SQL = """
           SELECT attname FROM pg_attribute
           WHERE attrelid = (
@@ -722,12 +693,10 @@ class PGDialect(default.DefaultDialect):
         primary_keys = [r[0] for r in c.fetchall()]
         return primary_keys
 
-    @reflection.caches
-    def get_foreign_keys(self, connection, tablename, schemaname=None,
-                         info_cache=None):
+    @reflection.cache
+    def get_foreign_keys(self, connection, tablename, schemaname=None):
         preparer = self.identifier_preparer
-        table_oid = self._get_table_oid(connection, tablename, schemaname,
-                                        info_cache)
+        table_oid = self._get_table_oid(connection, tablename, schemaname)
         FK_SQL = """
           SELECT conname, pg_catalog.pg_get_constraintdef(oid, true) as condef
           FROM  pg_catalog.pg_constraint r
@@ -761,10 +730,9 @@ class PGDialect(default.DefaultDialect):
             fkeys.append(fkey_d)
         return fkeys
 
-    @reflection.caches
-    def get_indexes(self, connection, tablename, schemaname, info_cache=None):
-        table_oid = self._get_table_oid(connection, tablename, schemaname,
-                                        info_cache)
+    @reflection.cache
+    def get_indexes(self, connection, tablename, schemaname):
+        table_oid = self._get_table_oid(connection, tablename, schemaname)
         IDX_SQL = """
           SELECT c.relname, i.indisunique, i.indexprs, i.indpred,
             a.attname
@@ -813,9 +781,7 @@ class PGDialect(default.DefaultDialect):
         if isinstance(tablename, str):
             tablename = tablename.decode(self.encoding)
         # end Py2K
-        info_cache = PGInfoCache()
-        for col_d in self.get_columns(connection, tablename, schemaname,
-                                                                info_cache):
+        for col_d in self.get_columns(connection, tablename, schemaname):
             name = col_d['name']
             coltype = col_d['type']
             nullable = col_d['nullable']
@@ -835,19 +801,16 @@ class PGDialect(default.DefaultDialect):
                 colargs.append(schema.DefaultClause(sql.text(default)))
             table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs))
         # Now we have the table oid cached.
-        table_oid = self._get_table_oid(connection, tablename, schemaname,
-                                        info_cache)
+        table_oid = self._get_table_oid(connection, tablename, schemaname)
         # Primary keys
-        for pk in self.get_primary_keys(connection, tablename, schemaname,
-                                                                    info_cache):
+        for pk in self.get_primary_keys(connection, tablename, schemaname):
             if pk in table.c:
                 col = table.c[pk]
                 table.primary_key.add(col)
                 if col.default is None:
                     col.autoincrement = False
         # Foreign keys
-        fkeys = self.get_foreign_keys(connection, tablename, schemaname,
-                                      info_cache)
+        fkeys = self.get_foreign_keys(connection, tablename, schemaname)
         for fkey_d in fkeys:
             conname = fkey_d['name']
             constrained_columns = fkey_d['constrained_columns']
@@ -868,8 +831,7 @@ class PGDialect(default.DefaultDialect):
             table.append_constraint(schema.ForeignKeyConstraint(constrained_columns, refspec, conname, link_to_name=True))
 
         # Indexes 
-        indexes = self.get_indexes(connection, tablename, schemaname,
-                                   info_cache)
+        indexes = self.get_indexes(connection, tablename, schemaname)
         for index_d in indexes:
             name = index_d['name']
             columns = index_d['column_names']
index 7f8143d6006443aaeb044a959d39de2fc4604bdf..2f7d3021d6c764792dfca072fd5d44432484430e 100644 (file)
@@ -22,6 +22,21 @@ from sqlalchemy import util
 from sqlalchemy.types import TypeEngine
 
 
+##@util.decorator
+def cache(fn):
+    def decorated(self, con, *args, **kw):
+        info_cache = kw.pop('info_cache', None)
+        if info_cache is None:
+            return fn(self, con, *args, **kw)
+        key = (fn.__name__, args, str(kw))
+        ret = info_cache.get(key)
+        if ret is None:
+            ret = fn(self, con, *args, **kw)
+            info_cache[key] = ret
+        return ret
+    return decorated
+
+# keeping this around until all dialects are fixed
 @util.decorator
 def caches(fn, self, con, *args, **kw):
     # what are we caching?
@@ -270,7 +285,11 @@ class Inspector(object):
             self.engine = conn.engine
         else:
             self.engine = conn
-        self.info_cache = self.engine.dialect.info_cache()
+        # fixme. This is just until all dialects are converted
+        if hasattr(self, 'info_cache'):
+            self.info_cache = self.engine.dialect.info_cache()
+        else:
+            self.info_cache = {}
 
     def default_schema_name(self):
         return self.engine.dialect.get_default_schema_name(self.conn)
@@ -282,7 +301,7 @@ class Inspector(object):
         """
         if hasattr(self.engine.dialect, 'get_schema_names'):
             return self.engine.dialect.get_schema_names(self.conn,
-                                                        self.info_cache)
+                                                    info_cache=self.info_cache)
         return []
 
     def get_table_names(self, schemaname=None, order_by=None):
@@ -296,7 +315,7 @@ class Inspector(object):
         """
         if hasattr(self.engine.dialect, 'get_table_names'):
             tnames = self.engine.dialect.get_table_names(self.conn, schemaname,
-                                                       self.info_cache)
+                                                    info_cache=self.info_cache)
         else:
             tnames = self.engine.table_names(schemaname)
         if order_by == 'foreign_key':
@@ -325,7 +344,7 @@ class Inspector(object):
 
         """
         return self.engine.dialect.get_view_names(self.conn, schemaname,
-                                                  self.info_cache)
+                                                  info_cache=self.info_cache)
 
     def get_view_definition(self, view_name, schemaname=None):
         """Return definition for `view_name`.
@@ -334,7 +353,7 @@ class Inspector(object):
 
         """
         return self.engine.dialect.get_view_definition(
-            self.conn, view_name, schemaname, self.info_cache)
+            self.conn, view_name, schemaname, info_cache=self.info_cache)
 
     def get_columns(self, tablename, schemaname=None):
         """Return information about columns in `tablename`.
@@ -379,7 +398,7 @@ class Inspector(object):
 
         pkeys = self.engine.dialect.get_primary_keys(self.conn, tablename,
                                                      schemaname,
-                                                     self.info_cache)
+                                            info_cache=self.info_cache)
 
         return pkeys
 
@@ -406,7 +425,7 @@ class Inspector(object):
 
         fk_defs = self.engine.dialect.get_foreign_keys(self.conn, tablename,
                                                        schemaname,
-                                                       self.info_cache)
+                                                info_cache=self.info_cache)
         for fk_def in fk_defs:
             referred_schema = fk_def['referred_schema']
             # always set the referred_schema.
@@ -435,5 +454,5 @@ class Inspector(object):
 
         indexes = self.engine.dialect.get_indexes(self.conn, tablename,
                                                   schemaname,
-                                                  self.info_cache)
+                                            info_cache=self.info_cache)
         return indexes