]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- dialect.get_default_schema_name(connection) is now
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 3 Nov 2009 17:35:13 +0000 (17:35 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 3 Nov 2009 17:35:13 +0000 (17:35 +0000)
public via dialect.default_schema_name.
[ticket:1571]

CHANGES
lib/sqlalchemy/dialects/maxdb/base.py
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/sybase/base.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py

diff --git a/CHANGES b/CHANGES
index 6625e0af71b63337f9bd2dcb8f960dfdceeafbbe..548edeef23a6a142fe89abf828efe5d3ff6ea9ed 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -266,6 +266,8 @@ CHANGES
   - deprecated or removed
       * result.last_inserted_ids() is deprecated.  Use 
         result.inserted_primary_key
+      * dialect.get_default_schema_name(connection) is now
+        public via dialect.default_schema_name.
             
 - schema
     - the `__contains__()` method of `MetaData` now accepts
index c02a9e204400c18715a9aa3e8e030afdb3ebc96e..d1c0191ed60a7933d76b860ba57900991e667223 100644 (file)
@@ -844,14 +844,9 @@ class MaxDBDialect(default.DefaultDialect):
         # COMMIT/ROLLBACK so omitting it should be relatively ok.
         pass
 
-    def get_default_schema_name(self, connection):
-        try:
-            return self._default_schema_name
-        except AttributeError:
-            name = self.identifier_preparer._normalize_name(
+    def _get_default_schema_name(self, connection):
+        return self.identifier_preparer._normalize_name(
                 connection.execute('SELECT CURRENT_SCHEMA FROM DUAL').scalar())
-            self._default_schema_name = name
-            return name
 
     def has_table(self, connection, table_name, schema=None):
         denormalize = self.identifier_preparer._denormalize_name
index ae9834a3934779ee1515a8a1f170bb256cc025e2..129125ca73a6155d7d4535aeb5caf8711627db97 100644 (file)
@@ -1167,9 +1167,6 @@ class MSDialect(default.DefaultDialect):
         if self.server_version_info >= MS_2005_VERSION and 'implicit_returning' not in self.__dict__:
             self.implicit_returning = True
         
-    def get_default_schema_name(self, connection):
-        return self.default_schema_name
-        
     def _get_default_schema_name(self, connection):
         user_name = connection.scalar("SELECT user_name() as user_name;")
         if user_name is not None:
index 6be6934de5661d8f68fad7f15e58d4d186df9ee0..fd36bc9edfebff80a7d20becc3bcc33916cfb52b 100644 (file)
@@ -1740,7 +1740,7 @@ class MySQLDialect(default.DefaultDialect):
     def _extract_error_code(self, exception):
         raise NotImplementedError()
     
-    def get_default_schema_name(self, connection):
+    def _get_default_schema_name(self, connection):
         return connection.execute('SELECT DATABASE()').scalar()
 
     def table_names(self, connection, schema):
@@ -1801,7 +1801,7 @@ class MySQLDialect(default.DefaultDialect):
     @reflection.cache
     def get_table_names(self, connection, schema=None, **kw):
         if schema is None:
-            schema = self.get_default_schema_name(connection)
+            schema = self.default_schema_name
         if self.server_version_info < (5, 0, 2):
             return self.table_names(connection, schema)
         charset = self._connection_charset
@@ -1817,7 +1817,7 @@ class MySQLDialect(default.DefaultDialect):
         if self.server_version_info < (5, 0, 2):
             raise NotImplementedError
         if schema is None:
-            schema = self.get_default_schema_name(connection)
+            schema = self.default_schema_name
         if self.server_version_info < (5, 0, 2):
             return self.table_names(connection, schema)
         charset = self._connection_charset
@@ -1863,7 +1863,7 @@ class MySQLDialect(default.DefaultDialect):
             if not ref_schema:
                 if default_schema is None:
                     default_schema = \
-                        connection.dialect.get_default_schema_name(connection)
+                        connection.dialect.default_schema_name
                 if schema == default_schema:
                     ref_schema = schema
 
index d02eb4984aeda6812b4fd28546c8172a8dfe5e54..7b141ed2357c46294ea0463d6ed4c6612b3fd1cc 100644 (file)
@@ -535,7 +535,7 @@ class OracleDialect(default.DefaultDialect):
 
     def has_table(self, connection, table_name, schema=None):
         if not schema:
-            schema = self.get_default_schema_name(connection)
+            schema = self.default_schema_name
         cursor = connection.execute(
             sql.text("SELECT table_name FROM all_tables "
                      "WHERE table_name = :name AND owner = :schema_name"),
@@ -544,7 +544,7 @@ class OracleDialect(default.DefaultDialect):
 
     def has_sequence(self, connection, sequence_name, schema=None):
         if not schema:
-            schema = self.get_default_schema_name(connection)
+            schema = self.default_schema_name
         cursor = connection.execute(
             sql.text("SELECT sequence_name FROM all_sequences "
                      "WHERE sequence_name = :name AND sequence_owner = :schema_name"),
@@ -568,7 +568,7 @@ class OracleDialect(default.DefaultDialect):
         else:
             return name.encode(self.encoding)
 
-    def get_default_schema_name(self, connection):
+    def _get_default_schema_name(self, connection):
         return self.normalize_name(connection.execute('SELECT USER FROM DUAL').scalar())
 
     def table_names(self, connection, schema):
@@ -638,7 +638,7 @@ class OracleDialect(default.DefaultDialect):
         if not dblink:
             dblink = ''
         if not owner:
-            owner = self.denormalize_name(schema or self.get_default_schema_name(connection))
+            owner = self.denormalize_name(schema or self.default_schema_name)
         return (actual_name, owner, dblink, synonym)
 
     @reflection.cache
@@ -649,12 +649,12 @@ class OracleDialect(default.DefaultDialect):
 
     @reflection.cache
     def get_table_names(self, connection, schema=None, **kw):
-        schema = self.denormalize_name(schema or self.get_default_schema_name(connection))
+        schema = self.denormalize_name(schema or self.default_schema_name)
         return self.table_names(connection, schema)
 
     @reflection.cache
     def get_view_names(self, connection, schema=None, **kw):
-        schema = self.denormalize_name(schema or self.get_default_schema_name(connection))
+        schema = self.denormalize_name(schema or self.default_schema_name)
         s = sql.text("SELECT view_name FROM all_views WHERE owner = :owner")
         cursor = connection.execute(s, owner=self.denormalize_name(schema))
         return [self.normalize_name(row[0]) for row in cursor]
index c41d5b359e290b863caa6a97038d2bb89ba775b9..bc8cff90525340c6acae3038c77c2ecce5d43e14 100644 (file)
@@ -602,7 +602,7 @@ class PGDialect(default.DefaultDialect):
         resultset = connection.execute(sql.text("SELECT gid FROM pg_prepared_xacts"))
         return [row[0] for row in resultset]
 
-    def get_default_schema_name(self, connection):
+    def _get_default_schema_name(self, connection):
         return connection.scalar("select current_schema()")
 
     def has_table(self, connection, table_name, schema=None):
@@ -761,7 +761,7 @@ class PGDialect(default.DefaultDialect):
         if schema is not None:
             current_schema = schema
         else:
-            current_schema = self.get_default_schema_name(connection)
+            current_schema = self.default_schema_name
         table_names = self.table_names(connection, current_schema)
         return table_names
 
@@ -770,7 +770,7 @@ class PGDialect(default.DefaultDialect):
         if schema is not None:
             current_schema = schema
         else:
-            current_schema = self.get_default_schema_name(connection)
+            current_schema = self.default_schema_name
         s = """
         SELECT relname
         FROM pg_class c
@@ -789,7 +789,7 @@ class PGDialect(default.DefaultDialect):
         if schema is not None:
             current_schema = schema
         else:
-            current_schema = self.get_default_schema_name(connection)
+            current_schema = self.default_schema_name
         s = """
         SELECT definition FROM pg_views
         WHERE schemaname = :schema
@@ -953,7 +953,7 @@ class PGDialect(default.DefaultDialect):
             constrained_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s*', constrained_columns)]
             if referred_schema:
                 referred_schema = preparer._unquote_identifier(referred_schema)
-            elif schema is not None and schema == self.get_default_schema_name(connection):
+            elif schema is not None and schema == self.default_schema_name:
                 # no schema (i.e. its the default schema), and the table we're
                 # reflecting has the default schema explicit, then use that.
                 # i.e. try to use the user's conventions
index a15a1fd5bef34e96a444f420779eb6233918479c..6fc42c312ecbcba45d895210ce81e6e45fadc506 100644 (file)
@@ -348,7 +348,8 @@ class SybaseDialect(default.DefaultDialect):
     def last_inserted_ids(self):
         return self.context.last_inserted_ids
 
-    def get_default_schema_name(self, connection):
+    def _get_default_schema_name(self, connection):
+        # TODO
         return self.schema_name
 
     def table_names(self, connection, schema):
@@ -370,7 +371,7 @@ class SybaseDialect(default.DefaultDialect):
         if table.schema is not None:
             current_schema = table.schema
         else:
-            current_schema = self.get_default_schema_name(connection)
+            current_schema = self.default_schema_name
 
         s = sql.select([columns, domains], tables.c.table_name==table.name, from_obj=[columns.join(tables).join(domains)], order_by=[columns.c.column_id])
 
index e541399e8e27e7286ca49cea5eeabcb705769294..fac25b69993cc34de33190e453e0389a9ba4a5c9 100644 (file)
@@ -65,11 +65,14 @@ class Dialect(object):
 
     server_version_info
       a tuple containing a version number for the DB backend in use.
-      This value is only available for supporting dialects, and only for
-      a dialect that's been associated with a connection pool via
-      create_engine() or otherwise had its ``initialize()`` method called
-      with a conneciton.
-
+      This value is only available for supporting dialects, and is
+      typically populated during the initial connection to the database.
+    
+    default_schema_name
+     the name of the default schema.  This value is only available for
+     supporting dialects, and is typically populated during the
+     initial connection to the database.
+     
     execution_ctx_cls
       a :class:`ExecutionContext` class used to handle statement execution
 
@@ -327,10 +330,23 @@ class Dialect(object):
 
         raise NotImplementedError()
 
-    def get_default_schema_name(self, connection):
-        """Return the string name of the currently selected schema given a :class:`~sqlalchemy.engine.Connection`.
+    def _get_server_version_info(self, connection):
+        """Retrieve the server version info from the given connection.
+        
+        This is used by the default implementation to populate the
+        "server_version_info" attribute and is called exactly
+        once upon first connect.
+        
+        """
+
+        raise NotImplementedError()
         
-        DEPRECATED.  moving this towards dialect.default_schema_name (not complete).
+    def _get_default_schema_name(self, connection):
+        """Return the string name of the currently selected schema from the given connection.
+
+        This is used by the default implementation to populate the
+        "default_schema_name" attribute and is called exactly
+        once upon first connect.
         
         """
 
@@ -1427,7 +1443,7 @@ class Engine(Connectable):
             conn = connection
         if not schema:
             try:
-                schema =  self.dialect.get_default_schema_name(conn)
+                schema =  self.dialect.default_schema_name
             except NotImplementedError:
                 pass
         try:
index 37afa84ec1dcfa30d8c5ffba647810d168e890bb..a7f6dc4faeef17e7a355c6988795c4b3e699b5c7 100644 (file)
@@ -101,10 +101,15 @@ class DefaultDialect(base.Dialect):
         #self.returns_unicode_strings = True
 
     def initialize(self, connection):
-        if hasattr(self, '_get_server_version_info'):
+        try:
             self.server_version_info = self._get_server_version_info(connection)
-        if hasattr(self, '_get_default_schema_name'):
+        except NotImplementedError:
+            self.server_version_info = None
+        try:
             self.default_schema_name = self._get_default_schema_name(connection)
+        except NotImplementedError:
+            self.default_schema_name = None
+
         # Py2K
         self.returns_unicode_strings = self._check_unicode_returns(connection)
         # end Py2K