]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
using util.decorator and adding *kw to reflection method signatures
authorRandall Smith <randall@tnr.cc>
Sun, 1 Mar 2009 04:05:19 +0000 (04:05 +0000)
committerRandall Smith <randall@tnr.cc>
Sun, 1 Mar 2009 04:05:19 +0000 (04:05 +0000)
lib/sqlalchemy/dialects/postgres/base.py
lib/sqlalchemy/engine/reflection.py

index e9a1f09e335e0af9933762ff5cb2d2e1a4b26247..6817ce0bcbf97e5a097463038309f9ccc2492e9f 100644 (file)
@@ -537,7 +537,7 @@ class PGDialect(default.DefaultDialect):
         return table_oid
 
     @reflection.cache
-    def get_schema_names(self, connection):
+    def get_schema_names(self, connection, **kw):
         s = """
         SELECT nspname
         FROM pg_namespace
@@ -550,7 +550,7 @@ class PGDialect(default.DefaultDialect):
         return schema_names
 
     @reflection.cache
-    def get_table_names(self, connection, schemaname=None):
+    def get_table_names(self, connection, schemaname=None, **kw):
         if schemaname is not None:
             current_schema = schemaname
         else:
@@ -559,7 +559,7 @@ class PGDialect(default.DefaultDialect):
         return table_names
 
     @reflection.cache
-    def get_view_names(self, connection, schemaname=None):
+    def get_view_names(self, connection, schemaname=None, **kw):
         if schemaname is not None:
             current_schema = schemaname
         else:
@@ -574,7 +574,7 @@ class PGDialect(default.DefaultDialect):
         return view_names
 
     @reflection.cache
-    def get_view_definition(self, connection, viewname, schemaname=None):
+    def get_view_definition(self, connection, viewname, schemaname=None, **kw):
         if schemaname is not None:
             current_schema = schemaname
         else:
@@ -591,7 +591,7 @@ class PGDialect(default.DefaultDialect):
             return view_def
 
     @reflection.cache
-    def get_columns(self, connection, tablename, schemaname=None):
+    def get_columns(self, connection, tablename, schemaname=None, **kw):
 
         table_oid = self._get_table_oid(connection, tablename, schemaname)
         SQL_COLS = """
@@ -678,7 +678,7 @@ class PGDialect(default.DefaultDialect):
         return columns
 
     @reflection.cache
-    def get_primary_keys(self, connection, tablename, schemaname=None):
+    def get_primary_keys(self, connection, tablename, schemaname=None, **kw):
         table_oid = self._get_table_oid(connection, tablename, schemaname)
         PK_SQL = """
           SELECT attname FROM pg_attribute
@@ -694,7 +694,7 @@ class PGDialect(default.DefaultDialect):
         return primary_keys
 
     @reflection.cache
-    def get_foreign_keys(self, connection, tablename, schemaname=None):
+    def get_foreign_keys(self, connection, tablename, schemaname=None, **kw):
         preparer = self.identifier_preparer
         table_oid = self._get_table_oid(connection, tablename, schemaname)
         FK_SQL = """
@@ -731,7 +731,7 @@ class PGDialect(default.DefaultDialect):
         return fkeys
 
     @reflection.cache
-    def get_indexes(self, connection, tablename, schemaname):
+    def get_indexes(self, connection, tablename, schemaname, **kw):
         table_oid = self._get_table_oid(connection, tablename, schemaname)
         IDX_SQL = """
           SELECT c.relname, i.indisunique, i.indexprs, i.indpred,
@@ -775,13 +775,15 @@ class PGDialect(default.DefaultDialect):
         preparer = self.identifier_preparer
         schemaname = table.schema
         tablename = table.name
+        info_cache = {}
         # Py2K
         if isinstance(schemaname, str):
             schemaname = schemaname.decode(self.encoding)
         if isinstance(tablename, str):
             tablename = tablename.decode(self.encoding)
         # end Py2K
-        for col_d in self.get_columns(connection, tablename, schemaname):
+        for col_d in self.get_columns(connection, tablename, schemaname,
+                                      info_cache=info_cache):
             name = col_d['name']
             coltype = col_d['type']
             nullable = col_d['nullable']
@@ -803,14 +805,16 @@ class PGDialect(default.DefaultDialect):
         # Now we have the table oid cached.
         table_oid = self._get_table_oid(connection, tablename, schemaname)
         # Primary keys
-        for pk in self.get_primary_keys(connection, tablename, schemaname):
+        for pk in self.get_primary_keys(connection, tablename, schemaname,
+                                        info_cache=info_cache):
             if pk in table.c:
                 col = table.c[pk]
                 table.primary_key.add(col)
                 if col.default is None:
                     col.autoincrement = False
         # Foreign keys
-        fkeys = self.get_foreign_keys(connection, tablename, schemaname)
+        fkeys = self.get_foreign_keys(connection, tablename, schemaname,
+                                      info_cache=info_cache)
         for fkey_d in fkeys:
             conname = fkey_d['name']
             constrained_columns = fkey_d['constrained_columns']
@@ -831,7 +835,8 @@ class PGDialect(default.DefaultDialect):
             table.append_constraint(schema.ForeignKeyConstraint(constrained_columns, refspec, conname, link_to_name=True))
 
         # Indexes 
-        indexes = self.get_indexes(connection, tablename, schemaname)
+        indexes = self.get_indexes(connection, tablename, schemaname,
+                                   info_cache=info_cache)
         for index_d in indexes:
             name = index_d['name']
             columns = index_d['column_names']
index 2f7d3021d6c764792dfca072fd5d44432484430e..2e1ec9f50df0a6ed4916cf6de8863bc174cd3e1c 100644 (file)
@@ -22,19 +22,17 @@ from sqlalchemy import util
 from sqlalchemy.types import TypeEngine
 
 
-##@util.decorator
-def cache(fn):
-    def decorated(self, con, *args, **kw):
-        info_cache = kw.pop('info_cache', None)
-        if info_cache is None:
-            return fn(self, con, *args, **kw)
-        key = (fn.__name__, args, str(kw))
-        ret = info_cache.get(key)
-        if ret is None:
-            ret = fn(self, con, *args, **kw)
-            info_cache[key] = ret
-        return ret
-    return decorated
+@util.decorator
+def cache(fn, self, con, *args, **kw):
+    info_cache = kw.pop('info_cache', None)
+    if info_cache is None:
+        return fn(self, con, *args, **kw)
+    key = (fn.__name__, args, str(kw))
+    ret = info_cache.get(key)
+    if ret is None:
+        ret = fn(self, con, *args, **kw)
+        info_cache[key] = ret
+    return ret
 
 # keeping this around until all dialects are fixed
 @util.decorator