]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
moved get_table_names
authorRandall Smith <randall@tnr.cc>
Wed, 4 Mar 2009 05:44:13 +0000 (05:44 +0000)
committerRandall Smith <randall@tnr.cc>
Wed, 4 Mar 2009 05:44:13 +0000 (05:44 +0000)
lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/engine/base.py
test/reflection.py

index ed6160e83eb934ba537143f89341817869eb1f76..d0e0336f7daebaa7c6307892328e0991299933e1 100644 (file)
@@ -25,7 +25,8 @@ so historical dates are fully supported.
 
 import datetime, re, time
 
-from sqlalchemy import sql, schema, exc, pool, DefaultClause
+from sqlalchemy import schema as sa_schema
+from sqlalchemy import sql, exc, pool, DefaultClause
 from sqlalchemy.engine import default
 from sqlalchemy.engine import reflection
 from sqlalchemy import types as sqltypes
@@ -292,14 +293,67 @@ class SQLiteDialect(default.DefaultDialect):
         return (row is not None)
 
     @reflection.cache
-    def get_columns(self, connection, tablename, schemaname=None,
-                                                        info_cache=None):
+    def get_table_names(self, connection, schema=None, **kw):
+        return self.table_names(connection, schema)
+
+    @reflection.cache
+    def get_view_names(self, connection, schema=None, **kw):
+        if schema is not None:
+            qschema = self.identifier_preparer.quote_identifier(schema)
+            master = '%s.sqlite_master' % qschema
+            s = ("SELECT name FROM %s "
+                 "WHERE type='view' ORDER BY name") % (master,)
+            rs = connection.execute(s)
+        else:
+            try:
+                s = ("SELECT name FROM "
+                     " (SELECT * FROM sqlite_master UNION ALL "
+                     "  SELECT * FROM sqlite_temp_master) "
+                     "WHERE type='view' ORDER BY name")
+                rs = connection.execute(s)
+            except exc.DBAPIError:
+                raise
+                s = ("SELECT name FROM sqlite_master "
+                     "WHERE type='view' ORDER BY name")
+                rs = connection.execute(s)
+
+        return [row[0] for row in rs]
+
+    @reflection.cache
+    def get_view_definition(self, connection, view_name, schema=None, **kw):
         quote = self.identifier_preparer.quote_identifier
-        if schemaname is not None:
-            pragma = "PRAGMA %s." % quote(schemaname)
+        if schema is not None:
+            qschema = self.identifier_preparer.quote_identifier(schema)
+            master = '%s.sqlite_master' % qschema
+            s = ("SELECT sql FROM %s WHERE name = '%s'"
+                 "AND type='view'") % (master, view_name)
+            rs = connection.execute(s)
+        else:
+            try:
+                s = ("SELECT sql FROM "
+                     " (SELECT * FROM sqlite_master UNION ALL "
+                     "  SELECT * FROM sqlite_temp_master) "
+                     "WHERE name = '%s' "
+                     "AND type='view'") % view_name
+                rs = connection.execute(s)
+            except exc.DBAPIError:
+                raise
+                s = ("SELECT sql FROM sqlite_master WHERE name = '%s' "
+                     "AND type='view'") % view_name
+                rs = connection.execute(s)
+
+        result = rs.fetchall()
+        if result:
+            return result[0].sql
+
+    @reflection.cache
+    def get_columns(self, connection, table_name, schema=None, **kw):
+        quote = self.identifier_preparer.quote_identifier
+        if schema is not None:
+            pragma = "PRAGMA %s." % quote(schema)
         else:
             pragma = "PRAGMA "
-        qtable = quote(tablename)
+        qtable = quote(table_name)
         c = connection.execute("%stable_info(%s)" % (pragma, qtable))
         found_table = False
         columns = []
@@ -339,14 +393,22 @@ class SQLiteDialect(default.DefaultDialect):
         return columns
 
     @reflection.cache
-    def get_foreign_keys(self, connection, tablename, schemaname=None,
-                                                        info_cache=None):
+    def get_primary_keys(self, connection, table_name, schema=None, **kw):
+        cols = self.get_columns(connection, table_name, schema, **kw)
+        pkeys = []
+        for col in cols:
+            if col['primary_key']:
+                pkeys.append(col['name'])
+        return pkeys
+
+    @reflection.cache
+    def get_foreign_keys(self, connection, table_name, schema=None, **kw):
         quote = self.identifier_preparer.quote_identifier
-        if schemaname is not None:
-            pragma = "PRAGMA %s." % quote(schemaname)
+        if schema is not None:
+            pragma = "PRAGMA %s." % quote(schema)
         else:
             pragma = "PRAGMA "
-        qtable = quote(tablename)
+        qtable = quote(table_name)
         c = connection.execute("%sforeign_key_list(%s)" % (pragma, qtable))
         fkeys = []
         fks = {}
@@ -379,14 +441,39 @@ class SQLiteDialect(default.DefaultDialect):
                 fk['referred_columns'].append(rcol)
         return fkeys
 
-    def get_unique_indexes(self, connection, tablename, schemaname=None,
-                                                            info_cache=None):
+    @reflection.cache
+    def get_indexes(self, connection, table_name, schema=None, **kw):
         quote = self.identifier_preparer.quote_identifier
-        if schemaname is not None:
-            pragma = "PRAGMA %s." % quote(schemaname)
+        if schema is not None:
+            pragma = "PRAGMA %s." % quote(schema)
         else:
             pragma = "PRAGMA "
-        qtable = quote(tablename)
+        qtable = quote(table_name)
+        c = connection.execute("%sindex_list(%s)" % (pragma, qtable))
+        indexes = []
+        while True:
+            row = c.fetchone()
+            if row is None:
+                break
+            indexes.append(dict(name=row[1], column_names=[], unique=row[2]))
+        # loop thru unique indexes to get the column names.
+        for idx in indexes:
+            c = connection.execute("%sindex_info(%s)" % (pragma, idx['name']))
+            cols = idx['column_names']
+            while True:
+                row = c.fetchone()
+                if row is None:
+                    break
+                cols.append(row[2])
+        return indexes
+
+    def get_unique_indexes(self, connection, table_name, schema=None, **kw):
+        quote = self.identifier_preparer.quote_identifier
+        if schema is not None:
+            pragma = "PRAGMA %s." % quote(schema)
+        else:
+            pragma = "PRAGMA "
+        qtable = quote(table_name)
         c = connection.execute("%sindex_list(%s)" % (pragma, qtable))
         unique_indexes = []
         while True:
@@ -408,14 +495,15 @@ class SQLiteDialect(default.DefaultDialect):
 
     def reflecttable(self, connection, table, include_columns):
         preparer = self.identifier_preparer
-        tablename = table.name
-        schemaname = table.schema
+        table_name = table.name
+        schema = table.schema
         found_table = False
-        info_cache = SQLiteInfoCache()
+
+        info_cache = {}
 
         # columns
-        for column in self.get_columns(connection, tablename, schemaname,
-                                                                info_cache):
+        for column in self.get_columns(connection, table_name, schema,
+                                                        info_cache=info_cache):
             name = column['name']
             coltype = column['type']
             nullable = column['nullable']
@@ -425,22 +513,22 @@ class SQLiteDialect(default.DefaultDialect):
             found_table = True
             if include_columns and name not in include_columns:
                 continue
-            table.append_column(schema.Column(name, coltype, primary_key = primary_key, nullable = nullable, *colargs))
+            table.append_column(sa_schema.Column(name, coltype, primary_key = primary_key, nullable = nullable, *colargs))
         if not found_table:
             raise exc.NoSuchTableError(table.name)
 
         # foreign keys
-        for fkey_d in self.get_foreign_keys(connection, tablename, schemaname,
-                                                                   info_cache):
+        for fkey_d in self.get_foreign_keys(connection, table_name, schema,
+                                                        info_cache=info_cache):
 
             rtbl = fkey_d['referred_table']
             rcols = fkey_d['referred_columns']
             lcols = fkey_d['constrained_columns']
             # look up the table based on the given table's engine, not 'self',
             # since it could be a ProxyEngine
-            remotetable = schema.Table(rtbl, table.metadata, autoload=True, autoload_with=connection)
+            remotetable = sa_schema.Table(rtbl, table.metadata, autoload=True, autoload_with=connection)
             refspecs = ["%s.%s" % (rtbl, rcol) for rcol in rcols]
-            table.append_constraint(schema.ForeignKeyConstraint(lcols, refspecs, link_to_name=True))
+            table.append_constraint(sa_schema.ForeignKeyConstraint(lcols, refspecs, link_to_name=True))
         # this doesn't do anything ???
-        unique_indexes = self.get_unique_indexes(connection, tablename, 
-                                                 schemaname, info_cache)
+        unique_indexes = self.get_unique_indexes(connection, table_name, 
+                                    schema, info_cache=info_cache)
index 273c41cf4b9eb4acf4c5c7968a02dde49f8a1ab2..ac1f26630c57af206da7bdd2c00faaaa9438916e 100644 (file)
@@ -239,6 +239,11 @@ class Dialect(object):
 
         raise NotImplementedError()
 
+    def get_table_names(self, connection, schema=None, **kw):
+        """Return a list of table names for `schema`."""
+
+        raise NotImplementedError
+
     def get_view_names(self, connection, schema=None, **kw):
         """Return a list of all view names available in the database.
 
@@ -305,11 +310,6 @@ class Dialect(object):
 
         raise NotImplementedError()
 
-    def get_table_names(self, connection, schema=None):
-        """Return a list of table names for `schema`."""
-
-        raise NotImplementedError
-
     def do_begin(self, connection):
         """Provide an implementation of *connection.begin()*, given a DB-API connection."""
 
index 999f0e2c223accf4df9ce0c9a75fa06e879dc385..ee0d1ec6bfbd16a72a4d4320f6242d2d94b15e2b 100644 (file)
@@ -15,6 +15,8 @@ if 'set' not in dir(__builtins__):
     from sets import Set as set
 
 def getSchema():
+    if testing.against('sqlite'):
+        return None
     if testing.against('oracle'):
         return 'test'
     else:
@@ -88,6 +90,7 @@ def dropViews(con, schema=None):
 
 class ReflectionTest(TestBase):
 
+    @testing.fails_on('sqlite', 'no schema support')
     def test_get_schema_names(self):
         meta = MetaData(testing.db)
         insp = Inspector(meta.bind)
@@ -221,8 +224,11 @@ class ReflectionTest(TestBase):
         try:
             expected_schema = schema
             if schema is None:
-                expected_schema = meta.bind.dialect.get_default_schema_name(
+                try:
+                    expected_schema = meta.bind.dialect.get_default_schema_name(
                                     meta.bind)
+                except NotImplementedError:
+                    expected_schema = None
             # users
             users_fkeys = insp.get_foreign_keys(users.name,
                                                 schema=schema)