From 27781263d93e4e35c8c39ed163fa827506cfb193 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 11 Jun 2010 11:38:06 -0400 Subject: [PATCH] - Inspector hits bind.connect() when invoked to ensure initialize has been called. the internal name ".conn" is changed to ".bind", since that's what it is. --- CHANGES | 4 ++ lib/sqlalchemy/dialects/postgresql/base.py | 2 +- lib/sqlalchemy/engine/reflection.py | 43 ++++++++++++---------- test/engine/test_reflection.py | 6 +++ 4 files changed, 35 insertions(+), 20 deletions(-) diff --git a/CHANGES b/CHANGES index 53d2a8e054..d0be776fe9 100644 --- a/CHANGES +++ b/CHANGES @@ -10,6 +10,10 @@ CHANGES flag when used with TypeDecorators or other adaption scenarios. + - Inspector hits bind.connect() when invoked to ensure + initialize has been called. the internal name ".conn" + is changed to ".bind", since that's what it is. + - firebird - Fixed incorrect signature in do_execute(), error introduced in 0.6.1. [ticket:1823] diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index aa2048d4be..76d1122e8d 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -556,7 +556,7 @@ class PGInspector(reflection.Inspector): def get_table_oid(self, table_name, schema=None): """Return the oid from `table_name` and `schema`.""" - return self.dialect.get_table_oid(self.conn, table_name, schema, + return self.dialect.get_table_oid(self.bind, table_name, schema, info_cache=self.info_cache) class CreateEnumType(schema._CreateDropBase): diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py index c66c1645b2..4a3643a416 100644 --- a/lib/sqlalchemy/engine/reflection.py +++ b/lib/sqlalchemy/engine/reflection.py @@ -49,18 +49,23 @@ class Inspector(object): provides higher level functions for accessing database schema information. """ - def __init__(self, conn): + def __init__(self, bind): """Initialize the instance. - :param conn: a :class:`~sqlalchemy.engine.base.Connectable` + :param bind: a :class:`~sqlalchemy.engine.base.Connectable` """ - self.conn = conn + # ensure initialized + bind.connect() + + # this might not be a connection, it could be an engine. + self.bind = bind + # set the engine - if hasattr(conn, 'engine'): - self.engine = conn.engine + if hasattr(bind, 'engine'): + self.engine = bind.engine else: - self.engine = conn + self.engine = bind self.dialect = self.engine.dialect self.info_cache = {} @@ -79,7 +84,7 @@ class Inspector(object): """ if hasattr(self.dialect, 'get_schema_names'): - return self.dialect.get_schema_names(self.conn, + return self.dialect.get_schema_names(self.bind, info_cache=self.info_cache) return [] @@ -95,7 +100,7 @@ class Inspector(object): """ if hasattr(self.dialect, 'get_table_names'): - tnames = self.dialect.get_table_names(self.conn, + tnames = self.dialect.get_table_names(self.bind, schema, info_cache=self.info_cache) else: @@ -121,7 +126,7 @@ class Inspector(object): def get_table_options(self, table_name, schema=None, **kw): if hasattr(self.dialect, 'get_table_options'): - return self.dialect.get_table_options(self.conn, table_name, schema, + return self.dialect.get_table_options(self.bind, table_name, schema, info_cache=self.info_cache, **kw) return {} @@ -132,7 +137,7 @@ class Inspector(object): :param schema: Optional, retrieve names from a non-default schema. """ - return self.dialect.get_view_names(self.conn, schema, + return self.dialect.get_view_names(self.bind, schema, info_cache=self.info_cache) def get_view_definition(self, view_name, schema=None): @@ -142,7 +147,7 @@ class Inspector(object): """ return self.dialect.get_view_definition( - self.conn, view_name, schema, info_cache=self.info_cache) + self.bind, view_name, schema, info_cache=self.info_cache) def get_columns(self, table_name, schema=None, **kw): """Return information about columns in `table_name`. @@ -166,7 +171,7 @@ class Inspector(object): dict containing optional column attributes """ - col_defs = self.dialect.get_columns(self.conn, table_name, schema, + col_defs = self.dialect.get_columns(self.bind, table_name, schema, info_cache=self.info_cache, **kw) for col_def in col_defs: @@ -183,7 +188,7 @@ class Inspector(object): primary key information as a list of column names. """ - pkeys = self.dialect.get_primary_keys(self.conn, table_name, schema, + pkeys = self.dialect.get_primary_keys(self.bind, table_name, schema, info_cache=self.info_cache, **kw) @@ -202,7 +207,7 @@ class Inspector(object): optional name of the primary key constraint. """ - pkeys = self.dialect.get_pk_constraint(self.conn, table_name, schema, + pkeys = self.dialect.get_pk_constraint(self.bind, table_name, schema, info_cache=self.info_cache, **kw) @@ -236,7 +241,7 @@ class Inspector(object): """ - fk_defs = self.dialect.get_foreign_keys(self.conn, table_name, schema, + fk_defs = self.dialect.get_foreign_keys(self.bind, table_name, schema, info_cache=self.info_cache, **kw) return fk_defs @@ -260,14 +265,14 @@ class Inspector(object): other options passed to the dialect's get_indexes() method. """ - indexes = self.dialect.get_indexes(self.conn, table_name, + indexes = self.dialect.get_indexes(self.bind, table_name, schema, info_cache=self.info_cache, **kw) return indexes def reflecttable(self, table, include_columns): - dialect = self.conn.dialect + dialect = self.bind.dialect # MySQL dialect does this. Applicable with other dialects? if hasattr(dialect, '_connection_charset') \ @@ -362,7 +367,7 @@ class Inspector(object): if referred_schema is not None: sa_schema.Table(referred_table, table.metadata, autoload=True, schema=referred_schema, - autoload_with=self.conn, + autoload_with=self.bind, **reflection_options ) for column in referred_columns: @@ -370,7 +375,7 @@ class Inspector(object): [referred_schema, referred_table, column])) else: sa_schema.Table(referred_table, table.metadata, autoload=True, - autoload_with=self.conn, + autoload_with=self.bind, **reflection_options ) for column in referred_columns: diff --git a/test/engine/test_reflection.py b/test/engine/test_reflection.py index 4b1cb4652d..be53b05492 100644 --- a/test/engine/test_reflection.py +++ b/test/engine/test_reflection.py @@ -1035,6 +1035,12 @@ class ComponentReflectionTest(TestBase): self.assert_('test_schema' in insp.get_schema_names()) + def test_dialect_initialize(self): + engine = engines.testing_engine() + assert not hasattr(engine.dialect, 'default_schema_name') + insp = Inspector(engine) + assert hasattr(engine.dialect, 'default_schema_name') + def test_get_default_schema_name(self): insp = Inspector(testing.db) eq_(insp.default_schema_name, testing.db.dialect.default_schema_name) -- 2.47.2