]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Corrections to MSSQL Date/Time types; generalized server_version_info to a create_eng...
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 25 Jan 2009 19:50:21 +0000 (19:50 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 25 Jan 2009 19:50:21 +0000 (19:50 +0000)
13 files changed:
06CHANGES
lib/sqlalchemy/connectors/pyodbc.py
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/mssql/pymssql.py
lib/sqlalchemy/dialects/mssql/pyodbc.py
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/mysql/mysqldb.py
lib/sqlalchemy/dialects/mysql/pyodbc.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/engine/strategies.py
test/sql/testtypes.py
test/testlib/testing.py

index 17a0a50f23a9ab7e63818451b35d8e672d7ca46c..61637907d8ce0f6205650fd29c3699ca888b0c4e 100644 (file)
--- a/06CHANGES
+++ b/06CHANGES
@@ -9,7 +9,20 @@
       code structure.
     
 - dialect refactor
-
+    - server_version_info becomes a static attribute.
+    - create_engine() now establishes an initial connection immediately upon
+      creation, which is passed to the dialect to determine connection properties.
+      
+- mysql
+    - all the _detect_XXX() functions now run once underneath dialect.initialize()
+    
 - new dialects
     - pg8000
-    - pyodbc+mysql
\ No newline at end of file
+    - pyodbc+mysql
+    
+- mssql
+    - the "has_window_funcs" flag is removed.  LIMIT/OFFSET usage will use ROW NUMBER as always,
+      and if on an older version of SQL Server, the operation fails.  The behavior is exactly
+      the same except the error is raised by SQL server instead of the dialect, and no
+      flag setting is required to enable it.
+    - using new dialect.initialize() feature to set up version-dependent behavior.
\ No newline at end of file
index b94e5a40744a1e27f7ceb26a7cbfaade96a9f696..4f8d6d517f7751f181fa4a4d6c9f2117662d2961 100644 (file)
@@ -68,8 +68,8 @@ class PyODBCConnector(Connector):
         else:
             return False
 
-    def _server_version_info(self, dbapi_con):
-        """Convert a pyodbc SQL_DBMS_VER string into a tuple."""
+    def _get_server_version_info(self, connection):
+        dbapi_con = connection.connection
         version = []
         r = re.compile('[.\-]')
         for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)):
index 9f6ac48d35f824483a632cc92a2f8f3661ed6262..1964b6ddc5dad954ed9d012f6632c87de5773077 100644 (file)
@@ -116,11 +116,6 @@ arguments on the URL, or as keyword argument to
   pyodbc this defaults to ``True`` if the version of pyodbc being
   used supports it.
 
-* *has_window_funcs* - indicates whether or not window functions
-  (LIMIT and OFFSET) are supported on the version of MSSQL being
-  used. If you're running MSSQL 2005 or later turn this on to get
-  OFFSET support. Defaults to ``False``.
-
 * *max_identifier_length* - allows you to se the maximum length of
   identfiers supported by the database. Defaults to 128. For pymssql
   the default is 30.
@@ -182,10 +177,9 @@ will yield::
 
     SELECT TOP n
 
-If the ``has_window_funcs`` flag is set then LIMIT with OFFSET
-support is available through the ``ROW_NUMBER OVER`` construct. This
-construct requires an ``ORDER BY`` to be specified as well and is
-only available on MSSQL 2005 and later.
+If using SQL Server 2005 or above, LIMIT with OFFSET
+support is available through the ``ROW_NUMBER OVER`` construct. 
+For versions below 2005, LIMIT with OFFSET usage will fail.
 
 Nullability
 -----------
@@ -206,13 +200,12 @@ If ``nullable`` is ``True`` or ``False`` then the column will be
 
 Date / Time Handling
 --------------------
-For MSSQL versions that support the ``DATE`` and ``TIME`` types
-(MSSQL 2008+) the data type is used. For versions that do not
-support the ``DATE`` and ``TIME`` types a ``DATETIME`` type is used
-instead and the MSSQL dialect handles converting the results
-properly. This means ``Date()`` and ``Time()`` are fully supported
-on all versions of MSSQL. If you do not desire this behavior then
-do not use the ``Date()`` or ``Time()`` types.
+DATE and TIME are supported.   Bind parameters are converted
+to datetime.datetime() objects as required by most MSSQL drivers,
+and results are processed from strings if needed.
+The DATE and TIME types are not available for MSSQL 2005 and
+previous - if a server version below 2008 is detected, DDL
+for these types will be issued as DATETIME.
 
 Compatibility Levels
 --------------------
@@ -234,7 +227,7 @@ Known Issues
   does **not** work around
 
 """
-import datetime, decimal, inspect, operator, sys
+import datetime, decimal, inspect, operator, sys, re
 
 from sqlalchemy import sql, schema, exc, util
 from sqlalchemy.sql import compiler, expression, operators as sql_operators, functions as sql_functions
@@ -242,6 +235,9 @@ from sqlalchemy.engine import default, base
 from sqlalchemy import types as sqltypes
 from decimal import Decimal as _python_Decimal
 
+MS_2008_VERSION = (10,)
+#MS_2005_VERSION = ??
+#MS_2000_VERSION = ??
 
 MSSQL_RESERVED_WORDS = set(['function'])
 
@@ -308,20 +304,65 @@ class MSReal(sqltypes.Float):
 class MSTinyInteger(sqltypes.Integer):
     __visit_name__ = 'TINYINT'
 
+# MSSQL DATE/TIME types have varied behavior, sometimes returning
+# strings.  MSDate/MSTime check for everything, and always
+# filter bind parameters into datetime objects (required by pyodbc,
+# not sure about other dialects).
+
+class MSDate(sqltypes.Date):
+    def bind_processor(self, dialect):
+        def process(value):
+            if type(value) == datetime.date:
+                return datetime.datetime(value.year, value.month, value.day)
+            else:
+                return value
+        return process
+
+    _reg = re.compile(r"(\d+)-(\d+)-(\d+)")
+    def result_processor(self, dialect):
+        def process(value):
+            if isinstance(value, datetime.datetime):
+                return value.date()
+            elif isinstance(value, basestring):
+                return datetime.date(*[int(x or 0) for x in self._reg.match(value).groups()])
+            else:
+                return value
+        return process
+    
 class MSTime(sqltypes.Time):
     def __init__(self, precision=None, **kwargs):
         self.precision = precision
         super(MSTime, self).__init__()
 
+    __zero_date = datetime.date(1900, 1, 1)
+
+    def bind_processor(self, dialect):
+        def process(value):
+            if isinstance(value, datetime.datetime):
+                value = datetime.datetime.combine(self.__zero_date, value.time())
+            elif isinstance(value, datetime.time):
+                value = datetime.datetime.combine(self.__zero_date, value)
+            return value
+        return process
+
+    _reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?")
+    def result_processor(self, dialect):
+        def process(value):
+            if isinstance(value, datetime.datetime):
+                return value.time()
+            elif isinstance(value, basestring):
+                return datetime.time(*[int(x or 0) for x in self._reg.match(value).groups()])
+            else:
+                return value
+        return process
 
 class MSDateTime(sqltypes.DateTime):
     def bind_processor(self, dialect):
-        # most DBAPIs allow a datetime.date object
-        # as a datetime.
         def process(value):
-            if type(value) is datetime.date:
+            if type(value) == datetime.date:
                 return datetime.datetime(value.year, value.month, value.day)
-            return value
+            else:
+                return value
         return process
     
 class MSSmallDateTime(MSDateTime):
@@ -339,53 +380,6 @@ class MSDateTimeOffset(sqltypes.TypeEngine):
     def __init__(self, precision=None, **kwargs):
         self.precision = precision
 
-class MSDateTimeAsDate(sqltypes.TypeDecorator):
-    """ This is an implementation of the Date type for versions of MSSQL that
-    do not support that specific type. In order to make it work a ``DATETIME``
-    column specification is used and the results get converted back to just
-    the date portion.
-
-    """
-
-    impl = sqltypes.DateTime
-
-    def process_bind_param(self, value, dialect):
-        if type(value) is datetime.date:
-            return datetime.datetime(value.year, value.month, value.day)
-        return value
-
-    def process_result_value(self, value, dialect):
-        if type(value) is datetime.datetime:
-            return value.date()
-        return value
-
-class MSDateTimeAsTime(sqltypes.TypeDecorator):
-    """ This is an implementation of the Time type for versions of MSSQL that
-    do not support that specific type. In order to make it work a ``DATETIME``
-    column specification is used and the results get converted back to just
-    the time portion.
-
-    """
-
-    __zero_date = datetime.date(1900, 1, 1)
-
-    impl = sqltypes.DateTime
-
-    def process_bind_param(self, value, dialect):
-        if type(value) is datetime.datetime:
-            value = datetime.datetime.combine(self.__zero_date, value.time())
-        elif type(value) is datetime.time:
-            value = datetime.datetime.combine(self.__zero_date, value)
-        return value
-
-    def process_result_value(self, value, dialect):
-        if type(value) is datetime.datetime:
-            return value.time()
-        elif type(value) is datetime.date:
-            return datetime.time(0, 0, 0)
-        return value
-
-
 class _StringType(object):
     """Base for MSSQL string types."""
 
@@ -672,15 +666,13 @@ class MSTypeCompiler(compiler.GenericTypeCompiler):
         return self._extend("NVARCHAR", type_)
 
     def visit_date(self, type_):
-        # psudocode
-        if self.dialect.version <= 10:
+        if self.dialect.server_version_info < MS_2008_VERSION:
             return self.visit_DATETIME(type_)
         else:
             return self.visit_DATE(type_)
 
     def visit_time(self, type_):
-        # psudocode
-        if self.dialect.version <= 10:
+        if self.dialect.server_version_info < MS_2008_VERSION:
             return self.visit_DATETIME(type_)
         else:
             return self.visit_TIME(type_)
@@ -791,6 +783,7 @@ colspecs = {
     sqltypes.Unicode : MSNVarchar,
     sqltypes.Numeric : MSNumeric,
     sqltypes.DateTime : MSDateTime,
+    sqltypes.Date : MSDate,
     sqltypes.Time : MSTime,
     sqltypes.String : MSString,
     sqltypes.Boolean : MSBoolean,
@@ -861,9 +854,6 @@ class MSSQLCompiler(compiler.SQLCompiler):
             if select._limit:
                 if not select._offset:
                     s += "TOP %s " % (select._limit,)
-                else:
-                    if not self.dialect.has_window_funcs:
-                        raise exc.InvalidRequestError('MSSQL does not support LIMIT with an offset')
             return s
         return compiler.SQLCompiler.get_select_precolumns(self, select)
 
@@ -876,7 +866,7 @@ class MSSQLCompiler(compiler.SQLCompiler):
         so tries to wrap it in a subquery with ``row_number()`` criterion.
 
         """
-        if self.dialect.has_window_funcs and not getattr(select, '_mssql_visit', None) and select._offset:
+        if not getattr(select, '_mssql_visit', None) and select._offset:
             # to use ROW_NUMBER(), an ORDER BY is required.
             orderby = self.process(select._order_by_clause)
             if not orderby:
@@ -1061,7 +1051,6 @@ class MSDialect(default.DefaultDialect):
     execution_ctx_cls = MSExecutionContext
     text_as_varchar = False
     use_scope_identity = False
-    has_window_funcs = False
     max_identifier_length = 128
     schema_name = "dbo"
     colspecs = colspecs
@@ -1069,6 +1058,8 @@ class MSDialect(default.DefaultDialect):
 
     supports_unicode_binds = True
 
+    server_version_info = ()
+    
     statement_compiler = MSSQLCompiler
     ddl_compiler = MSDDLCompiler
     type_compiler = MSTypeCompiler
@@ -1077,35 +1068,19 @@ class MSDialect(default.DefaultDialect):
     def __init__(self,
                  auto_identity_insert=True, query_timeout=None,
                  use_scope_identity=False,
-                 has_window_funcs=False, max_identifier_length=None,
+                 max_identifier_length=None,
                  schema_name="dbo", **opts):
         self.auto_identity_insert = bool(auto_identity_insert)
         self.query_timeout = int(query_timeout or 0)
         self.schema_name = schema_name
 
         self.use_scope_identity = bool(use_scope_identity)
-        self.has_window_funcs =  bool(has_window_funcs)
         self.max_identifier_length = int(max_identifier_length or 0) or 128
         super(MSDialect, self).__init__(**opts)
-
-    @base.connection_memoize(('mssql', 'server_version_info'))
-    def server_version_info(self, connection):
-        """A tuple of the database server version.
-
-        Formats the remote server version as a tuple of version values,
-        e.g. ``(9, 0, 1399)``.  If there are strings in the version number
-        they will be in the tuple too, so don't count on these all being
-        ``int`` values.
-
-        This is a fast check that does not require a round trip.  It is also
-        cached per-Connection.
-        """
-        return connection.dialect._server_version_info(connection.connection)
-
-    def _server_version_info(self, dbapi_con):
-        """Return a tuple of the database's version number."""
-        raise NotImplementedError()
-
+    
+    def initialize(self, connection):
+        self.server_version_info = self._get_server_version_info(connection)
+    
     def do_begin(self, connection):
         cursor = connection.cursor()
         cursor.execute("SET IMPLICIT_TRANSACTIONS OFF")
index b7b775899ef42c6f1fc7ad115338ba3c55bef8c4..475cc398af26a4bc75932b42af372c291eb817a0 100644 (file)
@@ -7,11 +7,6 @@ class MSDialect_pymssql(MSDialect):
     max_identifier_length = 30
     driver = 'pymssql'
 
-    # TODO: shouldnt this be based on server version <10 like pyodbc does ?
-    colspecs = MSSQLDialect.colspecs.copy()
-    colspecs[sqltypes.Date] = MSDateTimeAsDate
-    colspecs[sqltypes.Time] = MSDateTimeAsTime
-    
     @classmethod
     def import_dbapi(cls):
         import pymssql as module
index 5ff730c3f247c388e4a29afcb524a2910b01f02d..1b67cc04c41319c55336182127184dadab99efa2 100644 (file)
@@ -1,4 +1,4 @@
-from sqlalchemy.dialects.mssql.base import MSExecutionContext, MSDialect, MSDateTimeAsDate, MSDateTimeAsTime
+from sqlalchemy.dialects.mssql.base import MSExecutionContext, MSDialect
 from sqlalchemy.connectors.pyodbc import PyODBCConnector
 from sqlalchemy import types as sqltypes
 
@@ -14,14 +14,13 @@ class MSExecutionContext_pyodbc(MSExecutionContext):
 
     def post_exec(self):
         if self.HASIDENT and not self.IINSERT and self.dialect.use_scope_identity and not self.executemany:
-            import pyodbc
             # Fetch the last inserted id from the manipulated statement
             # We may have to skip over a number of result sets with no data (due to triggers, etc.)
             while True:
                 try:
                     row = self.cursor.fetchone()
                     break
-                except pyodbc.Error, e:
+                except self.dialect.dbapi.Error, e:
                     self.cursor.nextset()
             self._last_inserted_ids = [int(row[0])]
         else:
@@ -43,11 +42,6 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect):
         self.description_encoding = description_encoding
         self.use_scope_identity = self.dbapi and hasattr(self.dbapi.Cursor, 'nextset')
         
-        if self.server_version_info < (10,):
-            self.colspecs = MSDialect.colspecs.copy()
-            self.colspecs[sqltypes.Date] = MSDateTimeAsDate
-            self.colspecs[sqltypes.Time] = MSDateTimeAsTime
-
     def is_disconnect(self, e):
         if isinstance(e, self.dbapi.ProgrammingError):
             return "The cursor's connection has been closed." in str(e) or 'Attempt to use a closed connection.' in str(e)
index bb6b7ab75f8f86c16ed43a49f601254266252aee..412c4125addeea8cacac50d98b90a28f61a29bcf 100644 (file)
@@ -1697,7 +1697,6 @@ class MySQLDialect(default.DefaultDialect):
     ischema_names = ischema_names
     
     def __init__(self, use_ansiquotes=None, **kwargs):
-        self.use_ansiquotes = use_ansiquotes
         default.DefaultDialect.__init__(self, **kwargs)
 
     def do_executemany(self, cursor, statement, parameters, context=None):
@@ -1716,7 +1715,7 @@ class MySQLDialect(default.DefaultDialect):
         try:
             connection.commit()
         except:
-            if self._server_version_info(connection) < (3, 23, 15):
+            if self.server_version_info < (3, 23, 15):
                 args = sys.exc_info()[1].args
                 if args and args[0] == 1064:
                     return
@@ -1728,7 +1727,7 @@ class MySQLDialect(default.DefaultDialect):
         try:
             connection.rollback()
         except:
-            if self._server_version_info(connection) < (3, 23, 15):
+            if self.server_version_info < (3, 23, 15):
                 args = sys.exc_info()[1].args
                 if args and args[0] == 1064:
                     return
@@ -1786,8 +1785,7 @@ class MySQLDialect(default.DefaultDialect):
     def table_names(self, connection, schema):
         """Return a Unicode SHOW TABLES from a given schema."""
 
-        charset = self._detect_charset(connection)
-        self._autoset_identifier_style(connection)
+        charset = self._server_charset
         rp = connection.execute("SHOW TABLES FROM %s" %
             self.identifier_preparer.quote_identifier(schema))
         return [row[0] for row in self._compat_fetchall(rp, charset=charset)]
@@ -1803,7 +1801,6 @@ class MySQLDialect(default.DefaultDialect):
         # full_name = self.identifier_preparer.format_table(table,
         #                                                   use_schema=True)
 
-        self._autoset_identifier_style(connection)
 
         full_name = '.'.join(self.identifier_preparer._quote_free_identifiers(
             schema, table_name))
@@ -1823,36 +1820,30 @@ class MySQLDialect(default.DefaultDialect):
         finally:
             if rs:
                 rs.close()
-
-    @engine_base.connection_memoize(('mysql', 'server_version_info'))
-    def server_version_info(self, connection):
-        """A tuple of the database server version.
-
-        Formats the remote server version as a tuple of version values,
-        e.g. ``(5, 0, 44)``.  If there are strings in the version number
-        they will be in the tuple too, so don't count on these all being
-        ``int`` values.
-
-        This is a fast check that does not require a round trip.  It is also
-        cached per-Connection.
-        """
-
-        # TODO: do we need to bypass ConnectionFairy here?  other calls
-        # to this seem to not do that.
-        return self._server_version_info(connection.connection.connection)
-
+    
+    def initialize(self, connection):
+        self.server_version_info = self._get_server_version_info(connection)
+        self._server_charset = self._detect_charset(connection)
+        self._server_casing = self._detect_casing(connection)
+        self._server_collations = self._detect_collations(connection)
+        self._server_ansiquotes = self._detect_ansiquotes(connection)
+        if self._server_ansiquotes:
+            self.preparer = MySQLANSIIdentifierPreparer
+        else:
+            self.preparer = MySQLIdentifierPreparer
+        self.identifier_preparer = self.preparer(self)
+        
     def reflecttable(self, connection, table, include_columns):
         """Load column definitions from the server."""
 
-        charset = self._detect_charset(connection)
-        self._autoset_identifier_style(connection)
+        charset = self._server_charset
 
         try:
             reflector = self.reflector
         except AttributeError:
             preparer = self.identifier_preparer
-            if (self.server_version_info(connection) < (4, 1) and
-                self.use_ansiquotes):
+            if (self.server_version_info < (4, 1) and
+                self._server_use_ansiquotes):
                 # ANSI_QUOTES doesn't affect SHOW CREATE TABLE on < 4.1
                 preparer = MySQLIdentifierPreparer(self)
 
@@ -1864,15 +1855,15 @@ class MySQLDialect(default.DefaultDialect):
             columns = self._describe_table(connection, table, charset)
             sql = reflector._describe_to_create(table, columns)
 
-        self._adjust_casing(connection, table)
+        self._adjust_casing(table)
 
         return reflector.reflect(connection, table, sql, charset,
                                  only=include_columns)
 
-    def _adjust_casing(self, connection, table, charset=None):
+    def _adjust_casing(self, table, charset=None):
         """Adjust Table name to the server case sensitivity, if needed."""
 
-        casing = self._detect_casing(connection)
+        casing = self._server_casing
 
         # For winxx database hosts.  TODO: is this really needed?
         if casing == 1 and table.name != table.name.lower():
@@ -1892,7 +1883,7 @@ class MySQLDialect(default.DefaultDialect):
         """
         # http://dev.mysql.com/doc/refman/5.0/en/name-case-sensitivity.html
 
-        charset = self._detect_charset(connection)
+        charset = self._server_charset
         row = self._compat_fetchone(connection.execute(
             "SHOW VARIABLES LIKE 'lower_case_table_names'"),
                                charset=charset)
@@ -1909,8 +1900,6 @@ class MySQLDialect(default.DefaultDialect):
                 cs = int(row[1])
             row.close()
         return cs
-    _detect_casing = engine_base.connection_memoize(
-        ('mysql', 'lower_case_table_names'))(_detect_casing)
 
     def _detect_collations(self, connection):
         """Pull the active COLLATIONS list from the server.
@@ -1919,49 +1908,21 @@ class MySQLDialect(default.DefaultDialect):
         """
 
         collations = {}
-        if self.server_version_info(connection) < (4, 1, 0):
+        if self.server_version_info < (4, 1, 0):
             pass
         else:
-            charset = self._detect_charset(connection)
+            charset = self._server_charset
             rs = connection.execute('SHOW COLLATION')
             for row in self._compat_fetchall(rs, charset):
                 collations[row[0]] = row[1]
         return collations
-    _detect_collations = engine_base.connection_memoize(
-        ('mysql', 'collations'))(_detect_collations)
 
-    def use_ansiquotes(self, useansi):
-        self._use_ansiquotes = useansi
-        if useansi:
-            self.preparer = MySQLANSIIdentifierPreparer
-        else:
-            self.preparer = MySQLIdentifierPreparer
-        # icky
-        if hasattr(self, 'identifier_preparer'):
-            self.identifier_preparer = self.preparer(self)
-        if hasattr(self, 'reflector'):
-            del self.reflector
-
-    use_ansiquotes = property(lambda s: s._use_ansiquotes, use_ansiquotes,
-                              doc="True if ANSI_QUOTES is in effect.")
-
-    def _autoset_identifier_style(self, connection, charset=None):
-        """Detect and adjust for the ANSI_QUOTES sql mode.
-
-        If the dialect's use_ansiquotes is unset, query the server's sql mode
-        and reset the identifier style.
-
-        Note that this currently *only* runs during reflection.  Ideally this
-        would run the first time a connection pool connects to the database,
-        but the infrastructure for that is not yet in place.
-        """
-
-        if self.use_ansiquotes is not None:
-            return
+    def _detect_ansiquotes(self, connection):
+        """Detect and adjust for the ANSI_QUOTES sql mode."""
 
         row = self._compat_fetchone(
             connection.execute("SHOW VARIABLES LIKE 'sql_mode'"),
-                               charset=charset)
+                               charset=self._server_charset)
         if not row:
             mode = ''
         else:
@@ -1971,7 +1932,7 @@ class MySQLDialect(default.DefaultDialect):
                 mode_no = int(mode)
                 mode = (mode_no | 4 == mode_no) and 'ANSI_QUOTES' or ''
 
-        self.use_ansiquotes = 'ANSI_QUOTES' in mode
+        return 'ANSI_QUOTES' in mode
 
     def _show_create_table(self, connection, table, charset=None,
                            full_name=None):
index b077774ea51d2c6d8987a1e876d2b292bc9c756a..c947dc2fbaad339c2116423095ddd0336fed7f14 100644 (file)
@@ -96,9 +96,8 @@ class MySQL_mysqldb(MySQLDialect):
     def do_ping(self, connection):
         connection.ping()
 
-    def _server_version_info(self, dbapi_con):
-        """Convert a MySQL-python server_info string into a tuple."""
-
+    def _get_server_version_info(self,connection):
+        dbapi_con = connection.connection
         version = []
         r = re.compile('[.\-]')
         for n in r.split(dbapi_con.get_server_info()):
@@ -114,7 +113,6 @@ class MySQL_mysqldb(MySQLDialect):
         except AttributeError:
             return None
 
-    @engine_base.connection_memoize(('mysql', 'charset'))
     def _detect_charset(self, connection):
         """Sniff out the character set in use for connection results."""
 
@@ -124,7 +122,7 @@ class MySQL_mysqldb(MySQLDialect):
 
         # Note: MySQL-python 1.2.1c7 seems to ignore changes made
         # on a connection via set_character_set()
-        if self.server_version_info(connection) < (4, 1, 0):
+        if self.server_version_info < (4, 1, 0):
             try:
                 return connection.connection.character_set_name()
             except AttributeError:
index 3b9b373610d8055a08f77160b79ec7d0c90a5eb5..426b23cfdf25fb9b5e62f2259c40a9b9bf418465 100644 (file)
@@ -21,7 +21,6 @@ class MySQL_pyodbc(PyODBCConnector, MySQLDialect):
         MySQLDialect.__init__(self, **kw)
         PyODBCConnector.__init__(self, **kw)
 
-    @engine_base.connection_memoize(('mysql', 'charset'))
     def _detect_charset(self, connection):
         """Sniff out the character set in use for connection results."""
 
index f3acc28597358a4c9a62c3740f27a963a06298f7..f0432e16dba3b0cb6787ebb7126e1ab4e9645d2e 100644 (file)
@@ -67,6 +67,13 @@ class Dialect(object):
       a :class:`~Compiled` class used to compile DDL
       statements
 
+    server_version_info
+      a tuple containing a version number for the DB backend in use.
+      This value is only available for supporting dialects, and only for 
+      a dialect that's been associated with a connection pool via
+      create_engine() or otherwise had its ``initialize()`` method called
+      with a conneciton.
+
     execution_ctx_cls
       a :class:`ExecutionContext` class used to handle statement execution
 
@@ -114,6 +121,7 @@ class Dialect(object):
 
     supports_default_values
       Indicates if the construct ``INSERT INTO tablename DEFAULT VALUES`` is supported
+      
     """
 
     def create_connect_args(self, url):
@@ -141,10 +149,14 @@ class Dialect(object):
         raise NotImplementedError()
 
 
-    def server_version_info(self, connection):
-        """Return a tuple of the database's version number."""
-
-        raise NotImplementedError()
+    def initialize(self, connection):
+        """Called during strategized creation of the dialect with a connection.
+        
+        Allows dialects to configure options based on server version info or
+        other properties.
+        
+        """
+        pass
 
     def reflecttable(self, connection, table, include_columns=None):
         """Load table description from the database.
index 1f602eb6d3e8fcf0eafe666bd69a92dfa02a74ff..beec145604362c61ddd7382ce36fa72df1ca78a4 100644 (file)
@@ -68,7 +68,7 @@ class DefaultDialect(base.Dialect):
             raise exc.ArgumentError("Label length of %d is greater than this dialect's maximum identifier length of %d" % (label_length, self.max_identifier_length))
         self.label_length = label_length
         self.description_encoding = getattr(self, 'description_encoding', encoding)
-
+    
     def type_descriptor(self, typeobj):
         """Provide a database-specific ``TypeEngine`` object, given
         the generic object which comes from the types module.
index b1261da0a88be3015217e097f359cb1f55fba437..b763997d42ce243f8e2fd821c9bfd4ace5824a28 100644 (file)
@@ -119,7 +119,14 @@ class DefaultEngineStrategy(EngineStrategy):
                                     dialect.__class__.__name__,
                                     pool.__class__.__name__,
                                     engineclass.__name__))
-        return engineclass(pool, dialect, u, **engine_args)
+                                    
+        engine = engineclass(pool, dialect, u, **engine_args)
+        conn = engine.connect()
+        try:
+            dialect.initialize(conn)
+        finally:
+            conn.close()
+        return engine
 
     def pool_threadlocal(self):
         raise NotImplementedError()
index 29ed49d073b33c8750a03b13b94d6d1e96e15daa..40ad8814babf74ea4557034ee75bf7693b8d70f5 100644 (file)
@@ -574,7 +574,6 @@ class DateTest(TestBase, AssertsExecutionResults):
 
         db = testing.db
         if testing.against('oracle'):
-            import sqlalchemy.databases.oracle as oracle
             insert_data =  [
                     (7, 'jack',
                      datetime.datetime(2005, 11, 10, 0, 0),
@@ -666,14 +665,12 @@ class DateTest(TestBase, AssertsExecutionResults):
             "select user_datetime from query_users_with_date",
             typemap={'user_datetime':DateTime}).execute().fetchall()
 
-        print repr(x)
         self.assert_(isinstance(x[0][0], datetime.datetime))
 
         x = testing.db.text(
             "select * from query_users_with_date where user_datetime=:somedate",
             bindparams=[bindparam('somedate', type_=types.DateTime)]).execute(
             somedate=datetime.datetime(2005, 11, 10, 11, 52, 35)).fetchall()
-        print repr(x)
 
     def testdate2(self):
         meta = MetaData(testing.db)
index 959c246b84f42e7a66229011952e96c30cabd874..d9df784524bcc044498d6da0f49b1d45e2e9f41a 100644 (file)
@@ -288,7 +288,7 @@ def _server_version(bind=None):
 
     if bind is None:
         bind = config.db
-    return bind.dialect.server_version_info(bind.contextual_connect())
+    return getattr(bind.dialect, 'server_version_info', ())
 
 def skip_if(predicate, reason=None):
     """Skip a test if predicate is true."""
@@ -454,8 +454,7 @@ def against(*queries):
             if not db_spec(name)(config.db):
                 continue
 
-            have = config.db.dialect.server_version_info(
-                config.db.contextual_connect())
+            have = _server_version()
 
             oper = hasattr(op, '__call__') and op or _ops[op]
             if oper(have, spec):