From: Mike Bayer Date: Tue, 3 Nov 2009 17:35:13 +0000 (+0000) Subject: - dialect.get_default_schema_name(connection) is now X-Git-Tag: rel_0_6beta1~193 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=4b532e2;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - dialect.get_default_schema_name(connection) is now public via dialect.default_schema_name. [ticket:1571] --- diff --git a/CHANGES b/CHANGES index 6625e0af71..548edeef23 100644 --- a/CHANGES +++ b/CHANGES @@ -266,6 +266,8 @@ CHANGES - deprecated or removed * result.last_inserted_ids() is deprecated. Use result.inserted_primary_key + * dialect.get_default_schema_name(connection) is now + public via dialect.default_schema_name. - schema - the `__contains__()` method of `MetaData` now accepts diff --git a/lib/sqlalchemy/dialects/maxdb/base.py b/lib/sqlalchemy/dialects/maxdb/base.py index c02a9e2044..d1c0191ed6 100644 --- a/lib/sqlalchemy/dialects/maxdb/base.py +++ b/lib/sqlalchemy/dialects/maxdb/base.py @@ -844,14 +844,9 @@ class MaxDBDialect(default.DefaultDialect): # COMMIT/ROLLBACK so omitting it should be relatively ok. pass - def get_default_schema_name(self, connection): - try: - return self._default_schema_name - except AttributeError: - name = self.identifier_preparer._normalize_name( + def _get_default_schema_name(self, connection): + return self.identifier_preparer._normalize_name( connection.execute('SELECT CURRENT_SCHEMA FROM DUAL').scalar()) - self._default_schema_name = name - return name def has_table(self, connection, table_name, schema=None): denormalize = self.identifier_preparer._denormalize_name diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index ae9834a393..129125ca73 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -1167,9 +1167,6 @@ class MSDialect(default.DefaultDialect): if self.server_version_info >= MS_2005_VERSION and 'implicit_returning' not in self.__dict__: self.implicit_returning = True - def get_default_schema_name(self, connection): - return self.default_schema_name - def _get_default_schema_name(self, connection): user_name = connection.scalar("SELECT user_name() as user_name;") if user_name is not None: diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 6be6934de5..fd36bc9edf 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1740,7 +1740,7 @@ class MySQLDialect(default.DefaultDialect): def _extract_error_code(self, exception): raise NotImplementedError() - def get_default_schema_name(self, connection): + def _get_default_schema_name(self, connection): return connection.execute('SELECT DATABASE()').scalar() def table_names(self, connection, schema): @@ -1801,7 +1801,7 @@ class MySQLDialect(default.DefaultDialect): @reflection.cache def get_table_names(self, connection, schema=None, **kw): if schema is None: - schema = self.get_default_schema_name(connection) + schema = self.default_schema_name if self.server_version_info < (5, 0, 2): return self.table_names(connection, schema) charset = self._connection_charset @@ -1817,7 +1817,7 @@ class MySQLDialect(default.DefaultDialect): if self.server_version_info < (5, 0, 2): raise NotImplementedError if schema is None: - schema = self.get_default_schema_name(connection) + schema = self.default_schema_name if self.server_version_info < (5, 0, 2): return self.table_names(connection, schema) charset = self._connection_charset @@ -1863,7 +1863,7 @@ class MySQLDialect(default.DefaultDialect): if not ref_schema: if default_schema is None: default_schema = \ - connection.dialect.get_default_schema_name(connection) + connection.dialect.default_schema_name if schema == default_schema: ref_schema = schema diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index d02eb4984a..7b141ed235 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -535,7 +535,7 @@ class OracleDialect(default.DefaultDialect): def has_table(self, connection, table_name, schema=None): if not schema: - schema = self.get_default_schema_name(connection) + schema = self.default_schema_name cursor = connection.execute( sql.text("SELECT table_name FROM all_tables " "WHERE table_name = :name AND owner = :schema_name"), @@ -544,7 +544,7 @@ class OracleDialect(default.DefaultDialect): def has_sequence(self, connection, sequence_name, schema=None): if not schema: - schema = self.get_default_schema_name(connection) + schema = self.default_schema_name cursor = connection.execute( sql.text("SELECT sequence_name FROM all_sequences " "WHERE sequence_name = :name AND sequence_owner = :schema_name"), @@ -568,7 +568,7 @@ class OracleDialect(default.DefaultDialect): else: return name.encode(self.encoding) - def get_default_schema_name(self, connection): + def _get_default_schema_name(self, connection): return self.normalize_name(connection.execute('SELECT USER FROM DUAL').scalar()) def table_names(self, connection, schema): @@ -638,7 +638,7 @@ class OracleDialect(default.DefaultDialect): if not dblink: dblink = '' if not owner: - owner = self.denormalize_name(schema or self.get_default_schema_name(connection)) + owner = self.denormalize_name(schema or self.default_schema_name) return (actual_name, owner, dblink, synonym) @reflection.cache @@ -649,12 +649,12 @@ class OracleDialect(default.DefaultDialect): @reflection.cache def get_table_names(self, connection, schema=None, **kw): - schema = self.denormalize_name(schema or self.get_default_schema_name(connection)) + schema = self.denormalize_name(schema or self.default_schema_name) return self.table_names(connection, schema) @reflection.cache def get_view_names(self, connection, schema=None, **kw): - schema = self.denormalize_name(schema or self.get_default_schema_name(connection)) + schema = self.denormalize_name(schema or self.default_schema_name) s = sql.text("SELECT view_name FROM all_views WHERE owner = :owner") cursor = connection.execute(s, owner=self.denormalize_name(schema)) return [self.normalize_name(row[0]) for row in cursor] diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index c41d5b359e..bc8cff9052 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -602,7 +602,7 @@ class PGDialect(default.DefaultDialect): resultset = connection.execute(sql.text("SELECT gid FROM pg_prepared_xacts")) return [row[0] for row in resultset] - def get_default_schema_name(self, connection): + def _get_default_schema_name(self, connection): return connection.scalar("select current_schema()") def has_table(self, connection, table_name, schema=None): @@ -761,7 +761,7 @@ class PGDialect(default.DefaultDialect): if schema is not None: current_schema = schema else: - current_schema = self.get_default_schema_name(connection) + current_schema = self.default_schema_name table_names = self.table_names(connection, current_schema) return table_names @@ -770,7 +770,7 @@ class PGDialect(default.DefaultDialect): if schema is not None: current_schema = schema else: - current_schema = self.get_default_schema_name(connection) + current_schema = self.default_schema_name s = """ SELECT relname FROM pg_class c @@ -789,7 +789,7 @@ class PGDialect(default.DefaultDialect): if schema is not None: current_schema = schema else: - current_schema = self.get_default_schema_name(connection) + current_schema = self.default_schema_name s = """ SELECT definition FROM pg_views WHERE schemaname = :schema @@ -953,7 +953,7 @@ class PGDialect(default.DefaultDialect): constrained_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s*', constrained_columns)] if referred_schema: referred_schema = preparer._unquote_identifier(referred_schema) - elif schema is not None and schema == self.get_default_schema_name(connection): + elif schema is not None and schema == self.default_schema_name: # no schema (i.e. its the default schema), and the table we're # reflecting has the default schema explicit, then use that. # i.e. try to use the user's conventions diff --git a/lib/sqlalchemy/dialects/sybase/base.py b/lib/sqlalchemy/dialects/sybase/base.py index a15a1fd5be..6fc42c312e 100644 --- a/lib/sqlalchemy/dialects/sybase/base.py +++ b/lib/sqlalchemy/dialects/sybase/base.py @@ -348,7 +348,8 @@ class SybaseDialect(default.DefaultDialect): def last_inserted_ids(self): return self.context.last_inserted_ids - def get_default_schema_name(self, connection): + def _get_default_schema_name(self, connection): + # TODO return self.schema_name def table_names(self, connection, schema): @@ -370,7 +371,7 @@ class SybaseDialect(default.DefaultDialect): if table.schema is not None: current_schema = table.schema else: - current_schema = self.get_default_schema_name(connection) + current_schema = self.default_schema_name s = sql.select([columns, domains], tables.c.table_name==table.name, from_obj=[columns.join(tables).join(domains)], order_by=[columns.c.column_id]) diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index e541399e8e..fac25b6999 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -65,11 +65,14 @@ class Dialect(object): server_version_info a tuple containing a version number for the DB backend in use. - This value is only available for supporting dialects, and only for - a dialect that's been associated with a connection pool via - create_engine() or otherwise had its ``initialize()`` method called - with a conneciton. - + This value is only available for supporting dialects, and is + typically populated during the initial connection to the database. + + default_schema_name + the name of the default schema. This value is only available for + supporting dialects, and is typically populated during the + initial connection to the database. + execution_ctx_cls a :class:`ExecutionContext` class used to handle statement execution @@ -327,10 +330,23 @@ class Dialect(object): raise NotImplementedError() - def get_default_schema_name(self, connection): - """Return the string name of the currently selected schema given a :class:`~sqlalchemy.engine.Connection`. + def _get_server_version_info(self, connection): + """Retrieve the server version info from the given connection. + + This is used by the default implementation to populate the + "server_version_info" attribute and is called exactly + once upon first connect. + + """ + + raise NotImplementedError() - DEPRECATED. moving this towards dialect.default_schema_name (not complete). + def _get_default_schema_name(self, connection): + """Return the string name of the currently selected schema from the given connection. + + This is used by the default implementation to populate the + "default_schema_name" attribute and is called exactly + once upon first connect. """ @@ -1427,7 +1443,7 @@ class Engine(Connectable): conn = connection if not schema: try: - schema = self.dialect.get_default_schema_name(conn) + schema = self.dialect.default_schema_name except NotImplementedError: pass try: diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 37afa84ec1..a7f6dc4fae 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -101,10 +101,15 @@ class DefaultDialect(base.Dialect): #self.returns_unicode_strings = True def initialize(self, connection): - if hasattr(self, '_get_server_version_info'): + try: self.server_version_info = self._get_server_version_info(connection) - if hasattr(self, '_get_default_schema_name'): + except NotImplementedError: + self.server_version_info = None + try: self.default_schema_name = self._get_default_schema_name(connection) + except NotImplementedError: + self.default_schema_name = None + # Py2K self.returns_unicode_strings = self._check_unicode_returns(connection) # end Py2K