From 15495de0edce3aa232991e97986b0a70f4107caa Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 25 Jan 2009 19:50:21 +0000 Subject: [PATCH] Corrections to MSSQL Date/Time types; generalized server_version_info to a create_engine() pre-step --- 06CHANGES | 17 ++- lib/sqlalchemy/connectors/pyodbc.py | 4 +- lib/sqlalchemy/dialects/mssql/base.py | 171 ++++++++++------------- lib/sqlalchemy/dialects/mssql/pymssql.py | 5 - lib/sqlalchemy/dialects/mssql/pyodbc.py | 10 +- lib/sqlalchemy/dialects/mysql/base.py | 97 ++++--------- lib/sqlalchemy/dialects/mysql/mysqldb.py | 8 +- lib/sqlalchemy/dialects/mysql/pyodbc.py | 1 - lib/sqlalchemy/engine/base.py | 20 ++- lib/sqlalchemy/engine/default.py | 2 +- lib/sqlalchemy/engine/strategies.py | 9 +- test/sql/testtypes.py | 3 - test/testlib/testing.py | 5 +- 13 files changed, 151 insertions(+), 201 deletions(-) diff --git a/06CHANGES b/06CHANGES index 17a0a50f23..61637907d8 100644 --- a/06CHANGES +++ b/06CHANGES @@ -9,7 +9,20 @@ code structure. - dialect refactor - + - server_version_info becomes a static attribute. + - create_engine() now establishes an initial connection immediately upon + creation, which is passed to the dialect to determine connection properties. + +- mysql + - all the _detect_XXX() functions now run once underneath dialect.initialize() + - new dialects - pg8000 - - pyodbc+mysql \ No newline at end of file + - pyodbc+mysql + +- mssql + - the "has_window_funcs" flag is removed. LIMIT/OFFSET usage will use ROW NUMBER as always, + and if on an older version of SQL Server, the operation fails. The behavior is exactly + the same except the error is raised by SQL server instead of the dialect, and no + flag setting is required to enable it. + - using new dialect.initialize() feature to set up version-dependent behavior. \ No newline at end of file diff --git a/lib/sqlalchemy/connectors/pyodbc.py b/lib/sqlalchemy/connectors/pyodbc.py index b94e5a4074..4f8d6d517f 100644 --- a/lib/sqlalchemy/connectors/pyodbc.py +++ b/lib/sqlalchemy/connectors/pyodbc.py @@ -68,8 +68,8 @@ class PyODBCConnector(Connector): else: return False - def _server_version_info(self, dbapi_con): - """Convert a pyodbc SQL_DBMS_VER string into a tuple.""" + def _get_server_version_info(self, connection): + dbapi_con = connection.connection version = [] r = re.compile('[.\-]') for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)): diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 9f6ac48d35..1964b6ddc5 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -116,11 +116,6 @@ arguments on the URL, or as keyword argument to pyodbc this defaults to ``True`` if the version of pyodbc being used supports it. -* *has_window_funcs* - indicates whether or not window functions - (LIMIT and OFFSET) are supported on the version of MSSQL being - used. If you're running MSSQL 2005 or later turn this on to get - OFFSET support. Defaults to ``False``. - * *max_identifier_length* - allows you to se the maximum length of identfiers supported by the database. Defaults to 128. For pymssql the default is 30. @@ -182,10 +177,9 @@ will yield:: SELECT TOP n -If the ``has_window_funcs`` flag is set then LIMIT with OFFSET -support is available through the ``ROW_NUMBER OVER`` construct. This -construct requires an ``ORDER BY`` to be specified as well and is -only available on MSSQL 2005 and later. +If using SQL Server 2005 or above, LIMIT with OFFSET +support is available through the ``ROW_NUMBER OVER`` construct. +For versions below 2005, LIMIT with OFFSET usage will fail. Nullability ----------- @@ -206,13 +200,12 @@ If ``nullable`` is ``True`` or ``False`` then the column will be Date / Time Handling -------------------- -For MSSQL versions that support the ``DATE`` and ``TIME`` types -(MSSQL 2008+) the data type is used. For versions that do not -support the ``DATE`` and ``TIME`` types a ``DATETIME`` type is used -instead and the MSSQL dialect handles converting the results -properly. This means ``Date()`` and ``Time()`` are fully supported -on all versions of MSSQL. If you do not desire this behavior then -do not use the ``Date()`` or ``Time()`` types. +DATE and TIME are supported. Bind parameters are converted +to datetime.datetime() objects as required by most MSSQL drivers, +and results are processed from strings if needed. +The DATE and TIME types are not available for MSSQL 2005 and +previous - if a server version below 2008 is detected, DDL +for these types will be issued as DATETIME. Compatibility Levels -------------------- @@ -234,7 +227,7 @@ Known Issues does **not** work around """ -import datetime, decimal, inspect, operator, sys +import datetime, decimal, inspect, operator, sys, re from sqlalchemy import sql, schema, exc, util from sqlalchemy.sql import compiler, expression, operators as sql_operators, functions as sql_functions @@ -242,6 +235,9 @@ from sqlalchemy.engine import default, base from sqlalchemy import types as sqltypes from decimal import Decimal as _python_Decimal +MS_2008_VERSION = (10,) +#MS_2005_VERSION = ?? +#MS_2000_VERSION = ?? MSSQL_RESERVED_WORDS = set(['function']) @@ -308,20 +304,65 @@ class MSReal(sqltypes.Float): class MSTinyInteger(sqltypes.Integer): __visit_name__ = 'TINYINT' +# MSSQL DATE/TIME types have varied behavior, sometimes returning +# strings. MSDate/MSTime check for everything, and always +# filter bind parameters into datetime objects (required by pyodbc, +# not sure about other dialects). + +class MSDate(sqltypes.Date): + def bind_processor(self, dialect): + def process(value): + if type(value) == datetime.date: + return datetime.datetime(value.year, value.month, value.day) + else: + return value + return process + + _reg = re.compile(r"(\d+)-(\d+)-(\d+)") + def result_processor(self, dialect): + def process(value): + if isinstance(value, datetime.datetime): + return value.date() + elif isinstance(value, basestring): + return datetime.date(*[int(x or 0) for x in self._reg.match(value).groups()]) + else: + return value + return process + class MSTime(sqltypes.Time): def __init__(self, precision=None, **kwargs): self.precision = precision super(MSTime, self).__init__() + __zero_date = datetime.date(1900, 1, 1) + + def bind_processor(self, dialect): + def process(value): + if isinstance(value, datetime.datetime): + value = datetime.datetime.combine(self.__zero_date, value.time()) + elif isinstance(value, datetime.time): + value = datetime.datetime.combine(self.__zero_date, value) + return value + return process + + _reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?") + def result_processor(self, dialect): + def process(value): + if isinstance(value, datetime.datetime): + return value.time() + elif isinstance(value, basestring): + return datetime.time(*[int(x or 0) for x in self._reg.match(value).groups()]) + else: + return value + return process class MSDateTime(sqltypes.DateTime): def bind_processor(self, dialect): - # most DBAPIs allow a datetime.date object - # as a datetime. def process(value): - if type(value) is datetime.date: + if type(value) == datetime.date: return datetime.datetime(value.year, value.month, value.day) - return value + else: + return value return process class MSSmallDateTime(MSDateTime): @@ -339,53 +380,6 @@ class MSDateTimeOffset(sqltypes.TypeEngine): def __init__(self, precision=None, **kwargs): self.precision = precision -class MSDateTimeAsDate(sqltypes.TypeDecorator): - """ This is an implementation of the Date type for versions of MSSQL that - do not support that specific type. In order to make it work a ``DATETIME`` - column specification is used and the results get converted back to just - the date portion. - - """ - - impl = sqltypes.DateTime - - def process_bind_param(self, value, dialect): - if type(value) is datetime.date: - return datetime.datetime(value.year, value.month, value.day) - return value - - def process_result_value(self, value, dialect): - if type(value) is datetime.datetime: - return value.date() - return value - -class MSDateTimeAsTime(sqltypes.TypeDecorator): - """ This is an implementation of the Time type for versions of MSSQL that - do not support that specific type. In order to make it work a ``DATETIME`` - column specification is used and the results get converted back to just - the time portion. - - """ - - __zero_date = datetime.date(1900, 1, 1) - - impl = sqltypes.DateTime - - def process_bind_param(self, value, dialect): - if type(value) is datetime.datetime: - value = datetime.datetime.combine(self.__zero_date, value.time()) - elif type(value) is datetime.time: - value = datetime.datetime.combine(self.__zero_date, value) - return value - - def process_result_value(self, value, dialect): - if type(value) is datetime.datetime: - return value.time() - elif type(value) is datetime.date: - return datetime.time(0, 0, 0) - return value - - class _StringType(object): """Base for MSSQL string types.""" @@ -672,15 +666,13 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): return self._extend("NVARCHAR", type_) def visit_date(self, type_): - # psudocode - if self.dialect.version <= 10: + if self.dialect.server_version_info < MS_2008_VERSION: return self.visit_DATETIME(type_) else: return self.visit_DATE(type_) def visit_time(self, type_): - # psudocode - if self.dialect.version <= 10: + if self.dialect.server_version_info < MS_2008_VERSION: return self.visit_DATETIME(type_) else: return self.visit_TIME(type_) @@ -791,6 +783,7 @@ colspecs = { sqltypes.Unicode : MSNVarchar, sqltypes.Numeric : MSNumeric, sqltypes.DateTime : MSDateTime, + sqltypes.Date : MSDate, sqltypes.Time : MSTime, sqltypes.String : MSString, sqltypes.Boolean : MSBoolean, @@ -861,9 +854,6 @@ class MSSQLCompiler(compiler.SQLCompiler): if select._limit: if not select._offset: s += "TOP %s " % (select._limit,) - else: - if not self.dialect.has_window_funcs: - raise exc.InvalidRequestError('MSSQL does not support LIMIT with an offset') return s return compiler.SQLCompiler.get_select_precolumns(self, select) @@ -876,7 +866,7 @@ class MSSQLCompiler(compiler.SQLCompiler): so tries to wrap it in a subquery with ``row_number()`` criterion. """ - if self.dialect.has_window_funcs and not getattr(select, '_mssql_visit', None) and select._offset: + if not getattr(select, '_mssql_visit', None) and select._offset: # to use ROW_NUMBER(), an ORDER BY is required. orderby = self.process(select._order_by_clause) if not orderby: @@ -1061,7 +1051,6 @@ class MSDialect(default.DefaultDialect): execution_ctx_cls = MSExecutionContext text_as_varchar = False use_scope_identity = False - has_window_funcs = False max_identifier_length = 128 schema_name = "dbo" colspecs = colspecs @@ -1069,6 +1058,8 @@ class MSDialect(default.DefaultDialect): supports_unicode_binds = True + server_version_info = () + statement_compiler = MSSQLCompiler ddl_compiler = MSDDLCompiler type_compiler = MSTypeCompiler @@ -1077,35 +1068,19 @@ class MSDialect(default.DefaultDialect): def __init__(self, auto_identity_insert=True, query_timeout=None, use_scope_identity=False, - has_window_funcs=False, max_identifier_length=None, + max_identifier_length=None, schema_name="dbo", **opts): self.auto_identity_insert = bool(auto_identity_insert) self.query_timeout = int(query_timeout or 0) self.schema_name = schema_name self.use_scope_identity = bool(use_scope_identity) - self.has_window_funcs = bool(has_window_funcs) self.max_identifier_length = int(max_identifier_length or 0) or 128 super(MSDialect, self).__init__(**opts) - - @base.connection_memoize(('mssql', 'server_version_info')) - def server_version_info(self, connection): - """A tuple of the database server version. - - Formats the remote server version as a tuple of version values, - e.g. ``(9, 0, 1399)``. If there are strings in the version number - they will be in the tuple too, so don't count on these all being - ``int`` values. - - This is a fast check that does not require a round trip. It is also - cached per-Connection. - """ - return connection.dialect._server_version_info(connection.connection) - - def _server_version_info(self, dbapi_con): - """Return a tuple of the database's version number.""" - raise NotImplementedError() - + + def initialize(self, connection): + self.server_version_info = self._get_server_version_info(connection) + def do_begin(self, connection): cursor = connection.cursor() cursor.execute("SET IMPLICIT_TRANSACTIONS OFF") diff --git a/lib/sqlalchemy/dialects/mssql/pymssql.py b/lib/sqlalchemy/dialects/mssql/pymssql.py index b7b775899e..475cc398af 100644 --- a/lib/sqlalchemy/dialects/mssql/pymssql.py +++ b/lib/sqlalchemy/dialects/mssql/pymssql.py @@ -7,11 +7,6 @@ class MSDialect_pymssql(MSDialect): max_identifier_length = 30 driver = 'pymssql' - # TODO: shouldnt this be based on server version <10 like pyodbc does ? - colspecs = MSSQLDialect.colspecs.copy() - colspecs[sqltypes.Date] = MSDateTimeAsDate - colspecs[sqltypes.Time] = MSDateTimeAsTime - @classmethod def import_dbapi(cls): import pymssql as module diff --git a/lib/sqlalchemy/dialects/mssql/pyodbc.py b/lib/sqlalchemy/dialects/mssql/pyodbc.py index 5ff730c3f2..1b67cc04c4 100644 --- a/lib/sqlalchemy/dialects/mssql/pyodbc.py +++ b/lib/sqlalchemy/dialects/mssql/pyodbc.py @@ -1,4 +1,4 @@ -from sqlalchemy.dialects.mssql.base import MSExecutionContext, MSDialect, MSDateTimeAsDate, MSDateTimeAsTime +from sqlalchemy.dialects.mssql.base import MSExecutionContext, MSDialect from sqlalchemy.connectors.pyodbc import PyODBCConnector from sqlalchemy import types as sqltypes @@ -14,14 +14,13 @@ class MSExecutionContext_pyodbc(MSExecutionContext): def post_exec(self): if self.HASIDENT and not self.IINSERT and self.dialect.use_scope_identity and not self.executemany: - import pyodbc # Fetch the last inserted id from the manipulated statement # We may have to skip over a number of result sets with no data (due to triggers, etc.) while True: try: row = self.cursor.fetchone() break - except pyodbc.Error, e: + except self.dialect.dbapi.Error, e: self.cursor.nextset() self._last_inserted_ids = [int(row[0])] else: @@ -43,11 +42,6 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect): self.description_encoding = description_encoding self.use_scope_identity = self.dbapi and hasattr(self.dbapi.Cursor, 'nextset') - if self.server_version_info < (10,): - self.colspecs = MSDialect.colspecs.copy() - self.colspecs[sqltypes.Date] = MSDateTimeAsDate - self.colspecs[sqltypes.Time] = MSDateTimeAsTime - def is_disconnect(self, e): if isinstance(e, self.dbapi.ProgrammingError): return "The cursor's connection has been closed." in str(e) or 'Attempt to use a closed connection.' in str(e) diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index bb6b7ab75f..412c4125ad 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1697,7 +1697,6 @@ class MySQLDialect(default.DefaultDialect): ischema_names = ischema_names def __init__(self, use_ansiquotes=None, **kwargs): - self.use_ansiquotes = use_ansiquotes default.DefaultDialect.__init__(self, **kwargs) def do_executemany(self, cursor, statement, parameters, context=None): @@ -1716,7 +1715,7 @@ class MySQLDialect(default.DefaultDialect): try: connection.commit() except: - if self._server_version_info(connection) < (3, 23, 15): + if self.server_version_info < (3, 23, 15): args = sys.exc_info()[1].args if args and args[0] == 1064: return @@ -1728,7 +1727,7 @@ class MySQLDialect(default.DefaultDialect): try: connection.rollback() except: - if self._server_version_info(connection) < (3, 23, 15): + if self.server_version_info < (3, 23, 15): args = sys.exc_info()[1].args if args and args[0] == 1064: return @@ -1786,8 +1785,7 @@ class MySQLDialect(default.DefaultDialect): def table_names(self, connection, schema): """Return a Unicode SHOW TABLES from a given schema.""" - charset = self._detect_charset(connection) - self._autoset_identifier_style(connection) + charset = self._server_charset rp = connection.execute("SHOW TABLES FROM %s" % self.identifier_preparer.quote_identifier(schema)) return [row[0] for row in self._compat_fetchall(rp, charset=charset)] @@ -1803,7 +1801,6 @@ class MySQLDialect(default.DefaultDialect): # full_name = self.identifier_preparer.format_table(table, # use_schema=True) - self._autoset_identifier_style(connection) full_name = '.'.join(self.identifier_preparer._quote_free_identifiers( schema, table_name)) @@ -1823,36 +1820,30 @@ class MySQLDialect(default.DefaultDialect): finally: if rs: rs.close() - - @engine_base.connection_memoize(('mysql', 'server_version_info')) - def server_version_info(self, connection): - """A tuple of the database server version. - - Formats the remote server version as a tuple of version values, - e.g. ``(5, 0, 44)``. If there are strings in the version number - they will be in the tuple too, so don't count on these all being - ``int`` values. - - This is a fast check that does not require a round trip. It is also - cached per-Connection. - """ - - # TODO: do we need to bypass ConnectionFairy here? other calls - # to this seem to not do that. - return self._server_version_info(connection.connection.connection) - + + def initialize(self, connection): + self.server_version_info = self._get_server_version_info(connection) + self._server_charset = self._detect_charset(connection) + self._server_casing = self._detect_casing(connection) + self._server_collations = self._detect_collations(connection) + self._server_ansiquotes = self._detect_ansiquotes(connection) + if self._server_ansiquotes: + self.preparer = MySQLANSIIdentifierPreparer + else: + self.preparer = MySQLIdentifierPreparer + self.identifier_preparer = self.preparer(self) + def reflecttable(self, connection, table, include_columns): """Load column definitions from the server.""" - charset = self._detect_charset(connection) - self._autoset_identifier_style(connection) + charset = self._server_charset try: reflector = self.reflector except AttributeError: preparer = self.identifier_preparer - if (self.server_version_info(connection) < (4, 1) and - self.use_ansiquotes): + if (self.server_version_info < (4, 1) and + self._server_use_ansiquotes): # ANSI_QUOTES doesn't affect SHOW CREATE TABLE on < 4.1 preparer = MySQLIdentifierPreparer(self) @@ -1864,15 +1855,15 @@ class MySQLDialect(default.DefaultDialect): columns = self._describe_table(connection, table, charset) sql = reflector._describe_to_create(table, columns) - self._adjust_casing(connection, table) + self._adjust_casing(table) return reflector.reflect(connection, table, sql, charset, only=include_columns) - def _adjust_casing(self, connection, table, charset=None): + def _adjust_casing(self, table, charset=None): """Adjust Table name to the server case sensitivity, if needed.""" - casing = self._detect_casing(connection) + casing = self._server_casing # For winxx database hosts. TODO: is this really needed? if casing == 1 and table.name != table.name.lower(): @@ -1892,7 +1883,7 @@ class MySQLDialect(default.DefaultDialect): """ # http://dev.mysql.com/doc/refman/5.0/en/name-case-sensitivity.html - charset = self._detect_charset(connection) + charset = self._server_charset row = self._compat_fetchone(connection.execute( "SHOW VARIABLES LIKE 'lower_case_table_names'"), charset=charset) @@ -1909,8 +1900,6 @@ class MySQLDialect(default.DefaultDialect): cs = int(row[1]) row.close() return cs - _detect_casing = engine_base.connection_memoize( - ('mysql', 'lower_case_table_names'))(_detect_casing) def _detect_collations(self, connection): """Pull the active COLLATIONS list from the server. @@ -1919,49 +1908,21 @@ class MySQLDialect(default.DefaultDialect): """ collations = {} - if self.server_version_info(connection) < (4, 1, 0): + if self.server_version_info < (4, 1, 0): pass else: - charset = self._detect_charset(connection) + charset = self._server_charset rs = connection.execute('SHOW COLLATION') for row in self._compat_fetchall(rs, charset): collations[row[0]] = row[1] return collations - _detect_collations = engine_base.connection_memoize( - ('mysql', 'collations'))(_detect_collations) - def use_ansiquotes(self, useansi): - self._use_ansiquotes = useansi - if useansi: - self.preparer = MySQLANSIIdentifierPreparer - else: - self.preparer = MySQLIdentifierPreparer - # icky - if hasattr(self, 'identifier_preparer'): - self.identifier_preparer = self.preparer(self) - if hasattr(self, 'reflector'): - del self.reflector - - use_ansiquotes = property(lambda s: s._use_ansiquotes, use_ansiquotes, - doc="True if ANSI_QUOTES is in effect.") - - def _autoset_identifier_style(self, connection, charset=None): - """Detect and adjust for the ANSI_QUOTES sql mode. - - If the dialect's use_ansiquotes is unset, query the server's sql mode - and reset the identifier style. - - Note that this currently *only* runs during reflection. Ideally this - would run the first time a connection pool connects to the database, - but the infrastructure for that is not yet in place. - """ - - if self.use_ansiquotes is not None: - return + def _detect_ansiquotes(self, connection): + """Detect and adjust for the ANSI_QUOTES sql mode.""" row = self._compat_fetchone( connection.execute("SHOW VARIABLES LIKE 'sql_mode'"), - charset=charset) + charset=self._server_charset) if not row: mode = '' else: @@ -1971,7 +1932,7 @@ class MySQLDialect(default.DefaultDialect): mode_no = int(mode) mode = (mode_no | 4 == mode_no) and 'ANSI_QUOTES' or '' - self.use_ansiquotes = 'ANSI_QUOTES' in mode + return 'ANSI_QUOTES' in mode def _show_create_table(self, connection, table, charset=None, full_name=None): diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py index b077774ea5..c947dc2fba 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqldb.py +++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py @@ -96,9 +96,8 @@ class MySQL_mysqldb(MySQLDialect): def do_ping(self, connection): connection.ping() - def _server_version_info(self, dbapi_con): - """Convert a MySQL-python server_info string into a tuple.""" - + def _get_server_version_info(self,connection): + dbapi_con = connection.connection version = [] r = re.compile('[.\-]') for n in r.split(dbapi_con.get_server_info()): @@ -114,7 +113,6 @@ class MySQL_mysqldb(MySQLDialect): except AttributeError: return None - @engine_base.connection_memoize(('mysql', 'charset')) def _detect_charset(self, connection): """Sniff out the character set in use for connection results.""" @@ -124,7 +122,7 @@ class MySQL_mysqldb(MySQLDialect): # Note: MySQL-python 1.2.1c7 seems to ignore changes made # on a connection via set_character_set() - if self.server_version_info(connection) < (4, 1, 0): + if self.server_version_info < (4, 1, 0): try: return connection.connection.character_set_name() except AttributeError: diff --git a/lib/sqlalchemy/dialects/mysql/pyodbc.py b/lib/sqlalchemy/dialects/mysql/pyodbc.py index 3b9b373610..426b23cfdf 100644 --- a/lib/sqlalchemy/dialects/mysql/pyodbc.py +++ b/lib/sqlalchemy/dialects/mysql/pyodbc.py @@ -21,7 +21,6 @@ class MySQL_pyodbc(PyODBCConnector, MySQLDialect): MySQLDialect.__init__(self, **kw) PyODBCConnector.__init__(self, **kw) - @engine_base.connection_memoize(('mysql', 'charset')) def _detect_charset(self, connection): """Sniff out the character set in use for connection results.""" diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index f3acc28597..f0432e16db 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -67,6 +67,13 @@ class Dialect(object): a :class:`~Compiled` class used to compile DDL statements + server_version_info + a tuple containing a version number for the DB backend in use. + This value is only available for supporting dialects, and only for + a dialect that's been associated with a connection pool via + create_engine() or otherwise had its ``initialize()`` method called + with a conneciton. + execution_ctx_cls a :class:`ExecutionContext` class used to handle statement execution @@ -114,6 +121,7 @@ class Dialect(object): supports_default_values Indicates if the construct ``INSERT INTO tablename DEFAULT VALUES`` is supported + """ def create_connect_args(self, url): @@ -141,10 +149,14 @@ class Dialect(object): raise NotImplementedError() - def server_version_info(self, connection): - """Return a tuple of the database's version number.""" - - raise NotImplementedError() + def initialize(self, connection): + """Called during strategized creation of the dialect with a connection. + + Allows dialects to configure options based on server version info or + other properties. + + """ + pass def reflecttable(self, connection, table, include_columns=None): """Load table description from the database. diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 1f602eb6d3..beec145604 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -68,7 +68,7 @@ class DefaultDialect(base.Dialect): raise exc.ArgumentError("Label length of %d is greater than this dialect's maximum identifier length of %d" % (label_length, self.max_identifier_length)) self.label_length = label_length self.description_encoding = getattr(self, 'description_encoding', encoding) - + def type_descriptor(self, typeobj): """Provide a database-specific ``TypeEngine`` object, given the generic object which comes from the types module. diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py index b1261da0a8..b763997d42 100644 --- a/lib/sqlalchemy/engine/strategies.py +++ b/lib/sqlalchemy/engine/strategies.py @@ -119,7 +119,14 @@ class DefaultEngineStrategy(EngineStrategy): dialect.__class__.__name__, pool.__class__.__name__, engineclass.__name__)) - return engineclass(pool, dialect, u, **engine_args) + + engine = engineclass(pool, dialect, u, **engine_args) + conn = engine.connect() + try: + dialect.initialize(conn) + finally: + conn.close() + return engine def pool_threadlocal(self): raise NotImplementedError() diff --git a/test/sql/testtypes.py b/test/sql/testtypes.py index 29ed49d073..40ad8814ba 100644 --- a/test/sql/testtypes.py +++ b/test/sql/testtypes.py @@ -574,7 +574,6 @@ class DateTest(TestBase, AssertsExecutionResults): db = testing.db if testing.against('oracle'): - import sqlalchemy.databases.oracle as oracle insert_data = [ (7, 'jack', datetime.datetime(2005, 11, 10, 0, 0), @@ -666,14 +665,12 @@ class DateTest(TestBase, AssertsExecutionResults): "select user_datetime from query_users_with_date", typemap={'user_datetime':DateTime}).execute().fetchall() - print repr(x) self.assert_(isinstance(x[0][0], datetime.datetime)) x = testing.db.text( "select * from query_users_with_date where user_datetime=:somedate", bindparams=[bindparam('somedate', type_=types.DateTime)]).execute( somedate=datetime.datetime(2005, 11, 10, 11, 52, 35)).fetchall() - print repr(x) def testdate2(self): meta = MetaData(testing.db) diff --git a/test/testlib/testing.py b/test/testlib/testing.py index 959c246b84..d9df784524 100644 --- a/test/testlib/testing.py +++ b/test/testlib/testing.py @@ -288,7 +288,7 @@ def _server_version(bind=None): if bind is None: bind = config.db - return bind.dialect.server_version_info(bind.contextual_connect()) + return getattr(bind.dialect, 'server_version_info', ()) def skip_if(predicate, reason=None): """Skip a test if predicate is true.""" @@ -454,8 +454,7 @@ def against(*queries): if not db_spec(name)(config.db): continue - have = config.db.dialect.server_version_info( - config.db.contextual_connect()) + have = _server_version() oper = hasattr(op, '__call__') and op or _ops[op] if oper(have, spec): -- 2.47.3