]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Inspector hits bind.connect() when invoked to ensure
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 11 Jun 2010 15:38:06 +0000 (11:38 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 11 Jun 2010 15:38:06 +0000 (11:38 -0400)
initialize has been called.  the internal name ".conn"
is changed to ".bind", since that's what it is.

CHANGES
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/engine/reflection.py
test/engine/test_reflection.py

diff --git a/CHANGES b/CHANGES
index 53d2a8e0549b4465b8e09df8aeeab11c2d6ce226..d0be776fe998f40bf96eb59876d6117aca25346a 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -10,6 +10,10 @@ CHANGES
     flag when used with TypeDecorators or other adaption
     scenarios.
 
+  - Inspector hits bind.connect() when invoked to ensure
+    initialize has been called.  the internal name ".conn"
+    is changed to ".bind", since that's what it is.
+    
 - firebird
   - Fixed incorrect signature in do_execute(), error 
     introduced in 0.6.1. [ticket:1823]
index aa2048d4be31fe63ec0c544d517990053d3515f2..76d1122e8d66c267e0067e4a338de2c2aa66a74a 100644 (file)
@@ -556,7 +556,7 @@ class PGInspector(reflection.Inspector):
     def get_table_oid(self, table_name, schema=None):
         """Return the oid from `table_name` and `schema`."""
 
-        return self.dialect.get_table_oid(self.conn, table_name, schema,
+        return self.dialect.get_table_oid(self.bind, table_name, schema,
                                           info_cache=self.info_cache)
 
 class CreateEnumType(schema._CreateDropBase):
index c66c1645b2efe8417cd46a338fc41dc1ea6e12e5..4a3643a416b9435a69ac9277c5013f3df66abc15 100644 (file)
@@ -49,18 +49,23 @@ class Inspector(object):
     provides higher level functions for accessing database schema information.
     """
 
-    def __init__(self, conn):
+    def __init__(self, bind):
         """Initialize the instance.
 
-        :param conn: a :class:`~sqlalchemy.engine.base.Connectable`
+        :param bind: a :class:`~sqlalchemy.engine.base.Connectable`
         """
 
-        self.conn = conn
+        # ensure initialized
+        bind.connect()
+        
+        # this might not be a connection, it could be an engine.
+        self.bind = bind
+        
         # set the engine
-        if hasattr(conn, 'engine'):
-            self.engine = conn.engine
+        if hasattr(bind, 'engine'):
+            self.engine = bind.engine
         else:
-            self.engine = conn
+            self.engine = bind
         self.dialect = self.engine.dialect
         self.info_cache = {}
 
@@ -79,7 +84,7 @@ class Inspector(object):
         """
 
         if hasattr(self.dialect, 'get_schema_names'):
-            return self.dialect.get_schema_names(self.conn,
+            return self.dialect.get_schema_names(self.bind,
                                                     info_cache=self.info_cache)
         return []
 
@@ -95,7 +100,7 @@ class Inspector(object):
         """
 
         if hasattr(self.dialect, 'get_table_names'):
-            tnames = self.dialect.get_table_names(self.conn,
+            tnames = self.dialect.get_table_names(self.bind,
             schema,
                                                     info_cache=self.info_cache)
         else:
@@ -121,7 +126,7 @@ class Inspector(object):
 
     def get_table_options(self, table_name, schema=None, **kw):
         if hasattr(self.dialect, 'get_table_options'):
-            return self.dialect.get_table_options(self.conn, table_name, schema,
+            return self.dialect.get_table_options(self.bind, table_name, schema,
                                                   info_cache=self.info_cache,
                                                   **kw)
         return {}
@@ -132,7 +137,7 @@ class Inspector(object):
         :param schema: Optional, retrieve names from a non-default schema.
         """
 
-        return self.dialect.get_view_names(self.conn, schema,
+        return self.dialect.get_view_names(self.bind, schema,
                                                   info_cache=self.info_cache)
 
     def get_view_definition(self, view_name, schema=None):
@@ -142,7 +147,7 @@ class Inspector(object):
         """
 
         return self.dialect.get_view_definition(
-            self.conn, view_name, schema, info_cache=self.info_cache)
+            self.bind, view_name, schema, info_cache=self.info_cache)
 
     def get_columns(self, table_name, schema=None, **kw):
         """Return information about columns in `table_name`.
@@ -166,7 +171,7 @@ class Inspector(object):
           dict containing optional column attributes
         """
 
-        col_defs = self.dialect.get_columns(self.conn, table_name, schema,
+        col_defs = self.dialect.get_columns(self.bind, table_name, schema,
                                             info_cache=self.info_cache,
                                             **kw)
         for col_def in col_defs:
@@ -183,7 +188,7 @@ class Inspector(object):
         primary key information as a list of column names.
         """
 
-        pkeys = self.dialect.get_primary_keys(self.conn, table_name, schema,
+        pkeys = self.dialect.get_primary_keys(self.bind, table_name, schema,
                                               info_cache=self.info_cache,
                                               **kw)
 
@@ -202,7 +207,7 @@ class Inspector(object):
           optional name of the primary key constraint.
 
         """
-        pkeys = self.dialect.get_pk_constraint(self.conn, table_name, schema,
+        pkeys = self.dialect.get_pk_constraint(self.bind, table_name, schema,
                                               info_cache=self.info_cache,
                                               **kw)
 
@@ -236,7 +241,7 @@ class Inspector(object):
 
         """
 
-        fk_defs = self.dialect.get_foreign_keys(self.conn, table_name, schema,
+        fk_defs = self.dialect.get_foreign_keys(self.bind, table_name, schema,
                                                 info_cache=self.info_cache,
                                                 **kw)
         return fk_defs
@@ -260,14 +265,14 @@ class Inspector(object):
           other options passed to the dialect's get_indexes() method.
         """
 
-        indexes = self.dialect.get_indexes(self.conn, table_name,
+        indexes = self.dialect.get_indexes(self.bind, table_name,
                                                   schema,
                                             info_cache=self.info_cache, **kw)
         return indexes
 
     def reflecttable(self, table, include_columns):
 
-        dialect = self.conn.dialect
+        dialect = self.bind.dialect
 
         # MySQL dialect does this.  Applicable with other dialects?
         if hasattr(dialect, '_connection_charset') \
@@ -362,7 +367,7 @@ class Inspector(object):
             if referred_schema is not None:
                 sa_schema.Table(referred_table, table.metadata,
                                 autoload=True, schema=referred_schema,
-                                autoload_with=self.conn,
+                                autoload_with=self.bind,
                                 **reflection_options
                                 )
                 for column in referred_columns:
@@ -370,7 +375,7 @@ class Inspector(object):
                         [referred_schema, referred_table, column]))
             else:
                 sa_schema.Table(referred_table, table.metadata, autoload=True,
-                                autoload_with=self.conn,
+                                autoload_with=self.bind,
                                 **reflection_options
                                 )
                 for column in referred_columns:
index 4b1cb4652d796636470119aebe75998a7f6e76cf..be53b05492c00801962a52c80cf42c3d0212f5aa 100644 (file)
@@ -1035,6 +1035,12 @@ class ComponentReflectionTest(TestBase):
         
         self.assert_('test_schema' in insp.get_schema_names())
 
+    def test_dialect_initialize(self):
+        engine = engines.testing_engine()
+        assert not hasattr(engine.dialect, 'default_schema_name')
+        insp = Inspector(engine)
+        assert hasattr(engine.dialect, 'default_schema_name')
+        
     def test_get_default_schema_name(self):
         insp = Inspector(testing.db)
         eq_(insp.default_schema_name, testing.db.dialect.default_schema_name)