code structure.
- dialect refactor
-
+ - server_version_info becomes a static attribute.
+ - create_engine() now establishes an initial connection immediately upon
+ creation, which is passed to the dialect to determine connection properties.
+
+- mysql
+ - all the _detect_XXX() functions now run once underneath dialect.initialize()
+
- new dialects
- pg8000
- - pyodbc+mysql
\ No newline at end of file
+ - pyodbc+mysql
+
+- mssql
+ - the "has_window_funcs" flag is removed. LIMIT/OFFSET usage will use ROW NUMBER as always,
+ and if on an older version of SQL Server, the operation fails. The behavior is exactly
+ the same except the error is raised by SQL server instead of the dialect, and no
+ flag setting is required to enable it.
+ - using the new dialect.initialize() feature to set up version-dependent behavior (see the sketch following these notes).
\ No newline at end of file
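As a rough sketch of the new behavior described above (the URL and the
resulting version tuple are hypothetical and depend on the backend in use):

    from sqlalchemy import create_engine

    # create_engine() now opens an initial connection and hands it to
    # dialect.initialize(), so version-dependent setup happens up front.
    engine = create_engine('mssql://scott:tiger@localhost/test')

    # server_version_info is now a plain attribute on the dialect;
    # no connection argument is needed to read it.
    version = engine.dialect.server_version_info   # e.g. (10, 0, 1600)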
else:
return False
- def _server_version_info(self, dbapi_con):
- """Convert a pyodbc SQL_DBMS_VER string into a tuple."""
+ def _get_server_version_info(self, connection):
+ dbapi_con = connection.connection
version = []
r = re.compile('[.\-]')
for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)):
pyodbc this defaults to ``True`` if the version of pyodbc being
used supports it.
-* *has_window_funcs* - indicates whether or not window functions
- (LIMIT and OFFSET) are supported on the version of MSSQL being
- used. If you're running MSSQL 2005 or later turn this on to get
- OFFSET support. Defaults to ``False``.
-
* *max_identifier_length* - allows you to set the maximum length of
  identifiers supported by the database. Defaults to 128. For pymssql
the default is 30.
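
As a brief, illustrative sketch, such options can be passed as keyword
arguments to ``create_engine()`` (the DSN below is hypothetical)::

    from sqlalchemy import create_engine

    # dialect-level options are forwarded to the MSSQL dialect
    engine = create_engine('mssql://user:pass@mydsn', max_identifier_length=30)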
SELECT TOP n
-If the ``has_window_funcs`` flag is set then LIMIT with OFFSET
-support is available through the ``ROW_NUMBER OVER`` construct. This
-construct requires an ``ORDER BY`` to be specified as well and is
-only available on MSSQL 2005 and later.
+If using SQL Server 2005 or above, LIMIT with OFFSET
+support is available through the ``ROW_NUMBER OVER`` construct;
+this construct requires an ``ORDER BY`` to be specified as well.
+For versions below 2005, LIMIT with OFFSET usage will fail.
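
As an illustrative sketch (the table and column names are hypothetical), a
query such as the following is wrapped in a ``ROW_NUMBER() OVER`` subquery
on SQL Server 2005 and above::

    from sqlalchemy import MetaData, Table, Column, Integer, String, select

    meta = MetaData()
    users = Table('users', meta,
        Column('id', Integer, primary_key=True),
        Column('name', String(30)))

    # an ORDER BY is required for the ROW_NUMBER OVER wrapping
    stmt = select([users]).order_by(users.c.id).limit(10).offset(20)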
Nullability
-----------
Date / Time Handling
--------------------
-For MSSQL versions that support the ``DATE`` and ``TIME`` types
-(MSSQL 2008+) the data type is used. For versions that do not
-support the ``DATE`` and ``TIME`` types a ``DATETIME`` type is used
-instead and the MSSQL dialect handles converting the results
-properly. This means ``Date()`` and ``Time()`` are fully supported
-on all versions of MSSQL. If you do not desire this behavior then
-do not use the ``Date()`` or ``Time()`` types.
+DATE and TIME are supported. Bind parameters are converted
+to datetime.datetime() objects as required by most MSSQL drivers,
+and results are processed from strings if needed.
+The DATE and TIME types are not available for MSSQL 2005 and
+previous - if a server version below 2008 is detected, DDL
+for these types will be issued as DATETIME.
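
As a short sketch (the table is hypothetical), ``Date()`` and ``Time()``
columns are declared as usual and the dialect emits version-appropriate DDL::

    from sqlalchemy import MetaData, Table, Column, Integer, Date, Time

    meta = MetaData()
    events = Table('events', meta,
        Column('id', Integer, primary_key=True),
        Column('event_date', Date),   # DATE on 2008+, DATETIME below 2008
        Column('event_time', Time))   # TIME on 2008+, DATETIME below 2008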
Compatibility Levels
--------------------
does **not** work around
"""
-import datetime, decimal, inspect, operator, sys
+import datetime, decimal, inspect, operator, sys, re
from sqlalchemy import sql, schema, exc, util
from sqlalchemy.sql import compiler, expression, operators as sql_operators, functions as sql_functions
from sqlalchemy import types as sqltypes
from decimal import Decimal as _python_Decimal
+MS_2008_VERSION = (10,)
+#MS_2005_VERSION = ??
+#MS_2000_VERSION = ??
MSSQL_RESERVED_WORDS = set(['function'])
class MSTinyInteger(sqltypes.Integer):
__visit_name__ = 'TINYINT'
+# MSSQL DATE/TIME types have varied behavior, sometimes returning
+# strings. MSDate/MSTime check for everything, and always
+# filter bind parameters into datetime objects (required by pyodbc,
+# not sure about other dialects).
+
+class MSDate(sqltypes.Date):
+ def bind_processor(self, dialect):
+ def process(value):
+ if type(value) == datetime.date:
+ return datetime.datetime(value.year, value.month, value.day)
+ else:
+ return value
+ return process
+
+ _reg = re.compile(r"(\d+)-(\d+)-(\d+)")
+ def result_processor(self, dialect):
+ def process(value):
+ if isinstance(value, datetime.datetime):
+ return value.date()
+ elif isinstance(value, basestring):
+ return datetime.date(*[int(x or 0) for x in self._reg.match(value).groups()])
+ else:
+ return value
+ return process
+
class MSTime(sqltypes.Time):
def __init__(self, precision=None, **kwargs):
self.precision = precision
super(MSTime, self).__init__()
+ __zero_date = datetime.date(1900, 1, 1)
+
+ def bind_processor(self, dialect):
+ def process(value):
+ if isinstance(value, datetime.datetime):
+ value = datetime.datetime.combine(self.__zero_date, value.time())
+ elif isinstance(value, datetime.time):
+ value = datetime.datetime.combine(self.__zero_date, value)
+ return value
+ return process
+
+ _reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?")
+ def result_processor(self, dialect):
+ def process(value):
+ if isinstance(value, datetime.datetime):
+ return value.time()
+ elif isinstance(value, basestring):
+ return datetime.time(*[int(x or 0) for x in self._reg.match(value).groups()])
+ else:
+ return value
+ return process
class MSDateTime(sqltypes.DateTime):
def bind_processor(self, dialect):
- # most DBAPIs allow a datetime.date object
- # as a datetime.
def process(value):
- if type(value) is datetime.date:
+ if type(value) == datetime.date:
return datetime.datetime(value.year, value.month, value.day)
- return value
+ else:
+ return value
return process
class MSSmallDateTime(MSDateTime):
def __init__(self, precision=None, **kwargs):
self.precision = precision
-class MSDateTimeAsDate(sqltypes.TypeDecorator):
- """ This is an implementation of the Date type for versions of MSSQL that
- do not support that specific type. In order to make it work a ``DATETIME``
- column specification is used and the results get converted back to just
- the date portion.
-
- """
-
- impl = sqltypes.DateTime
-
- def process_bind_param(self, value, dialect):
- if type(value) is datetime.date:
- return datetime.datetime(value.year, value.month, value.day)
- return value
-
- def process_result_value(self, value, dialect):
- if type(value) is datetime.datetime:
- return value.date()
- return value
-
-class MSDateTimeAsTime(sqltypes.TypeDecorator):
- """ This is an implementation of the Time type for versions of MSSQL that
- do not support that specific type. In order to make it work a ``DATETIME``
- column specification is used and the results get converted back to just
- the time portion.
-
- """
-
- __zero_date = datetime.date(1900, 1, 1)
-
- impl = sqltypes.DateTime
-
- def process_bind_param(self, value, dialect):
- if type(value) is datetime.datetime:
- value = datetime.datetime.combine(self.__zero_date, value.time())
- elif type(value) is datetime.time:
- value = datetime.datetime.combine(self.__zero_date, value)
- return value
-
- def process_result_value(self, value, dialect):
- if type(value) is datetime.datetime:
- return value.time()
- elif type(value) is datetime.date:
- return datetime.time(0, 0, 0)
- return value
-
-
class _StringType(object):
"""Base for MSSQL string types."""
return self._extend("NVARCHAR", type_)
def visit_date(self, type_):
- # psudocode
- if self.dialect.version <= 10:
+ if self.dialect.server_version_info < MS_2008_VERSION:
return self.visit_DATETIME(type_)
else:
return self.visit_DATE(type_)
def visit_time(self, type_):
- # psudocode
- if self.dialect.version <= 10:
+ if self.dialect.server_version_info < MS_2008_VERSION:
return self.visit_DATETIME(type_)
else:
return self.visit_TIME(type_)
sqltypes.Unicode : MSNVarchar,
sqltypes.Numeric : MSNumeric,
sqltypes.DateTime : MSDateTime,
+ sqltypes.Date : MSDate,
sqltypes.Time : MSTime,
sqltypes.String : MSString,
sqltypes.Boolean : MSBoolean,
if select._limit:
if not select._offset:
s += "TOP %s " % (select._limit,)
- else:
- if not self.dialect.has_window_funcs:
- raise exc.InvalidRequestError('MSSQL does not support LIMIT with an offset')
return s
return compiler.SQLCompiler.get_select_precolumns(self, select)
so tries to wrap it in a subquery with ``row_number()`` criterion.
"""
- if self.dialect.has_window_funcs and not getattr(select, '_mssql_visit', None) and select._offset:
+ if not getattr(select, '_mssql_visit', None) and select._offset:
# to use ROW_NUMBER(), an ORDER BY is required.
orderby = self.process(select._order_by_clause)
if not orderby:
execution_ctx_cls = MSExecutionContext
text_as_varchar = False
use_scope_identity = False
- has_window_funcs = False
max_identifier_length = 128
schema_name = "dbo"
colspecs = colspecs
supports_unicode_binds = True
+ server_version_info = ()
+
statement_compiler = MSSQLCompiler
ddl_compiler = MSDDLCompiler
type_compiler = MSTypeCompiler
def __init__(self,
auto_identity_insert=True, query_timeout=None,
use_scope_identity=False,
- has_window_funcs=False, max_identifier_length=None,
+ max_identifier_length=None,
schema_name="dbo", **opts):
self.auto_identity_insert = bool(auto_identity_insert)
self.query_timeout = int(query_timeout or 0)
self.schema_name = schema_name
self.use_scope_identity = bool(use_scope_identity)
- self.has_window_funcs = bool(has_window_funcs)
self.max_identifier_length = int(max_identifier_length or 0) or 128
super(MSDialect, self).__init__(**opts)
-
- @base.connection_memoize(('mssql', 'server_version_info'))
- def server_version_info(self, connection):
- """A tuple of the database server version.
-
- Formats the remote server version as a tuple of version values,
- e.g. ``(9, 0, 1399)``. If there are strings in the version number
- they will be in the tuple too, so don't count on these all being
- ``int`` values.
-
- This is a fast check that does not require a round trip. It is also
- cached per-Connection.
- """
- return connection.dialect._server_version_info(connection.connection)
-
- def _server_version_info(self, dbapi_con):
- """Return a tuple of the database's version number."""
- raise NotImplementedError()
-
+
+ def initialize(self, connection):
+ self.server_version_info = self._get_server_version_info(connection)
+
def do_begin(self, connection):
cursor = connection.cursor()
cursor.execute("SET IMPLICIT_TRANSACTIONS OFF")
max_identifier_length = 30
driver = 'pymssql'
- # TODO: shouldnt this be based on server version <10 like pyodbc does ?
- colspecs = MSSQLDialect.colspecs.copy()
- colspecs[sqltypes.Date] = MSDateTimeAsDate
- colspecs[sqltypes.Time] = MSDateTimeAsTime
-
@classmethod
def import_dbapi(cls):
import pymssql as module
-from sqlalchemy.dialects.mssql.base import MSExecutionContext, MSDialect, MSDateTimeAsDate, MSDateTimeAsTime
+from sqlalchemy.dialects.mssql.base import MSExecutionContext, MSDialect
from sqlalchemy.connectors.pyodbc import PyODBCConnector
from sqlalchemy import types as sqltypes
def post_exec(self):
if self.HASIDENT and not self.IINSERT and self.dialect.use_scope_identity and not self.executemany:
- import pyodbc
# Fetch the last inserted id from the manipulated statement
# We may have to skip over a number of result sets with no data (due to triggers, etc.)
while True:
try:
row = self.cursor.fetchone()
break
- except pyodbc.Error, e:
+ except self.dialect.dbapi.Error, e:
self.cursor.nextset()
self._last_inserted_ids = [int(row[0])]
else:
self.description_encoding = description_encoding
self.use_scope_identity = self.dbapi and hasattr(self.dbapi.Cursor, 'nextset')
- if self.server_version_info < (10,):
- self.colspecs = MSDialect.colspecs.copy()
- self.colspecs[sqltypes.Date] = MSDateTimeAsDate
- self.colspecs[sqltypes.Time] = MSDateTimeAsTime
-
def is_disconnect(self, e):
if isinstance(e, self.dbapi.ProgrammingError):
return "The cursor's connection has been closed." in str(e) or 'Attempt to use a closed connection.' in str(e)
ischema_names = ischema_names
def __init__(self, use_ansiquotes=None, **kwargs):
- self.use_ansiquotes = use_ansiquotes
default.DefaultDialect.__init__(self, **kwargs)
def do_executemany(self, cursor, statement, parameters, context=None):
try:
connection.commit()
except:
- if self._server_version_info(connection) < (3, 23, 15):
+ if self.server_version_info < (3, 23, 15):
args = sys.exc_info()[1].args
if args and args[0] == 1064:
return
try:
connection.rollback()
except:
- if self._server_version_info(connection) < (3, 23, 15):
+ if self.server_version_info < (3, 23, 15):
args = sys.exc_info()[1].args
if args and args[0] == 1064:
return
def table_names(self, connection, schema):
"""Return a Unicode SHOW TABLES from a given schema."""
- charset = self._detect_charset(connection)
- self._autoset_identifier_style(connection)
+ charset = self._server_charset
rp = connection.execute("SHOW TABLES FROM %s" %
self.identifier_preparer.quote_identifier(schema))
return [row[0] for row in self._compat_fetchall(rp, charset=charset)]
# full_name = self.identifier_preparer.format_table(table,
# use_schema=True)
- self._autoset_identifier_style(connection)
full_name = '.'.join(self.identifier_preparer._quote_free_identifiers(
schema, table_name))
finally:
if rs:
rs.close()
-
- @engine_base.connection_memoize(('mysql', 'server_version_info'))
- def server_version_info(self, connection):
- """A tuple of the database server version.
-
- Formats the remote server version as a tuple of version values,
- e.g. ``(5, 0, 44)``. If there are strings in the version number
- they will be in the tuple too, so don't count on these all being
- ``int`` values.
-
- This is a fast check that does not require a round trip. It is also
- cached per-Connection.
- """
-
- # TODO: do we need to bypass ConnectionFairy here? other calls
- # to this seem to not do that.
- return self._server_version_info(connection.connection.connection)
-
+
+ def initialize(self, connection):
+ self.server_version_info = self._get_server_version_info(connection)
+ self._server_charset = self._detect_charset(connection)
+ self._server_casing = self._detect_casing(connection)
+ self._server_collations = self._detect_collations(connection)
+ self._server_ansiquotes = self._detect_ansiquotes(connection)
+ if self._server_ansiquotes:
+ self.preparer = MySQLANSIIdentifierPreparer
+ else:
+ self.preparer = MySQLIdentifierPreparer
+ self.identifier_preparer = self.preparer(self)
+
def reflecttable(self, connection, table, include_columns):
"""Load column definitions from the server."""
- charset = self._detect_charset(connection)
- self._autoset_identifier_style(connection)
+ charset = self._server_charset
try:
reflector = self.reflector
except AttributeError:
preparer = self.identifier_preparer
- if (self.server_version_info(connection) < (4, 1) and
- self.use_ansiquotes):
+ if (self.server_version_info < (4, 1) and
+ self._server_ansiquotes):
# ANSI_QUOTES doesn't affect SHOW CREATE TABLE on < 4.1
preparer = MySQLIdentifierPreparer(self)
columns = self._describe_table(connection, table, charset)
sql = reflector._describe_to_create(table, columns)
- self._adjust_casing(connection, table)
+ self._adjust_casing(table)
return reflector.reflect(connection, table, sql, charset,
only=include_columns)
- def _adjust_casing(self, connection, table, charset=None):
+ def _adjust_casing(self, table, charset=None):
"""Adjust Table name to the server case sensitivity, if needed."""
- casing = self._detect_casing(connection)
+ casing = self._server_casing
# For winxx database hosts. TODO: is this really needed?
if casing == 1 and table.name != table.name.lower():
"""
# http://dev.mysql.com/doc/refman/5.0/en/name-case-sensitivity.html
- charset = self._detect_charset(connection)
+ charset = self._server_charset
row = self._compat_fetchone(connection.execute(
"SHOW VARIABLES LIKE 'lower_case_table_names'"),
charset=charset)
cs = int(row[1])
row.close()
return cs
- _detect_casing = engine_base.connection_memoize(
- ('mysql', 'lower_case_table_names'))(_detect_casing)
def _detect_collations(self, connection):
"""Pull the active COLLATIONS list from the server.
"""
collations = {}
- if self.server_version_info(connection) < (4, 1, 0):
+ if self.server_version_info < (4, 1, 0):
pass
else:
- charset = self._detect_charset(connection)
+ charset = self._server_charset
rs = connection.execute('SHOW COLLATION')
for row in self._compat_fetchall(rs, charset):
collations[row[0]] = row[1]
return collations
- _detect_collations = engine_base.connection_memoize(
- ('mysql', 'collations'))(_detect_collations)
- def use_ansiquotes(self, useansi):
- self._use_ansiquotes = useansi
- if useansi:
- self.preparer = MySQLANSIIdentifierPreparer
- else:
- self.preparer = MySQLIdentifierPreparer
- # icky
- if hasattr(self, 'identifier_preparer'):
- self.identifier_preparer = self.preparer(self)
- if hasattr(self, 'reflector'):
- del self.reflector
-
- use_ansiquotes = property(lambda s: s._use_ansiquotes, use_ansiquotes,
- doc="True if ANSI_QUOTES is in effect.")
-
- def _autoset_identifier_style(self, connection, charset=None):
- """Detect and adjust for the ANSI_QUOTES sql mode.
-
- If the dialect's use_ansiquotes is unset, query the server's sql mode
- and reset the identifier style.
-
- Note that this currently *only* runs during reflection. Ideally this
- would run the first time a connection pool connects to the database,
- but the infrastructure for that is not yet in place.
- """
-
- if self.use_ansiquotes is not None:
- return
+ def _detect_ansiquotes(self, connection):
+ """Detect and adjust for the ANSI_QUOTES sql mode."""
row = self._compat_fetchone(
connection.execute("SHOW VARIABLES LIKE 'sql_mode'"),
- charset=charset)
+ charset=self._server_charset)
if not row:
mode = ''
else:
mode_no = int(mode)
mode = (mode_no | 4 == mode_no) and 'ANSI_QUOTES' or ''
- self.use_ansiquotes = 'ANSI_QUOTES' in mode
+ return 'ANSI_QUOTES' in mode
def _show_create_table(self, connection, table, charset=None,
full_name=None):
def do_ping(self, connection):
connection.ping()
- def _server_version_info(self, dbapi_con):
- """Convert a MySQL-python server_info string into a tuple."""
-
+ def _get_server_version_info(self, connection):
+ dbapi_con = connection.connection
version = []
r = re.compile('[.\-]')
for n in r.split(dbapi_con.get_server_info()):
except AttributeError:
return None
- @engine_base.connection_memoize(('mysql', 'charset'))
def _detect_charset(self, connection):
"""Sniff out the character set in use for connection results."""
# Note: MySQL-python 1.2.1c7 seems to ignore changes made
# on a connection via set_character_set()
- if self.server_version_info(connection) < (4, 1, 0):
+ if self.server_version_info < (4, 1, 0):
try:
return connection.connection.character_set_name()
except AttributeError:
MySQLDialect.__init__(self, **kw)
PyODBCConnector.__init__(self, **kw)
- @engine_base.connection_memoize(('mysql', 'charset'))
def _detect_charset(self, connection):
"""Sniff out the character set in use for connection results."""
a :class:`~Compiled` class used to compile DDL
statements
+ server_version_info
+ a tuple containing a version number for the DB backend in use.
+ This value is only available for supporting dialects, and only for
+ a dialect that's been associated with a connection pool via
+ create_engine() or otherwise had its ``initialize()`` method called
+ with a connection.
+
execution_ctx_cls
a :class:`ExecutionContext` class used to handle statement execution
supports_default_values
Indicates if the construct ``INSERT INTO tablename DEFAULT VALUES`` is supported
+
"""
def create_connect_args(self, url):
raise NotImplementedError()
- def server_version_info(self, connection):
- """Return a tuple of the database's version number."""
-
- raise NotImplementedError()
+ def initialize(self, connection):
+ """Called during strategized creation of the dialect with a connection.
+
+ Allows dialects to configure options based on server version info or
+ other properties.
+
+ """
+ pass
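# A minimal sketch (illustrative only; "MyDialect" and
# "_get_server_version_info()" are hypothetical names) of how a concrete
# dialect might use this hook:
#
#     class MyDialect(default.DefaultDialect):
#         def initialize(self, connection):
#             # capture version-dependent settings once, at engine creation
#             self.server_version_info = self._get_server_version_info(connection)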
def reflecttable(self, connection, table, include_columns=None):
"""Load table description from the database.
raise exc.ArgumentError("Label length of %d is greater than this dialect's maximum identifier length of %d" % (label_length, self.max_identifier_length))
self.label_length = label_length
self.description_encoding = getattr(self, 'description_encoding', encoding)
-
+
def type_descriptor(self, typeobj):
"""Provide a database-specific ``TypeEngine`` object, given
the generic object which comes from the types module.
dialect.__class__.__name__,
pool.__class__.__name__,
engineclass.__name__))
- return engineclass(pool, dialect, u, **engine_args)
+
+ engine = engineclass(pool, dialect, u, **engine_args)
+ conn = engine.connect()
+ try:
+ dialect.initialize(conn)
+ finally:
+ conn.close()
+ return engine
def pool_threadlocal(self):
raise NotImplementedError()
db = testing.db
if testing.against('oracle'):
- import sqlalchemy.databases.oracle as oracle
insert_data = [
(7, 'jack',
datetime.datetime(2005, 11, 10, 0, 0),
"select user_datetime from query_users_with_date",
typemap={'user_datetime':DateTime}).execute().fetchall()
- print repr(x)
self.assert_(isinstance(x[0][0], datetime.datetime))
x = testing.db.text(
"select * from query_users_with_date where user_datetime=:somedate",
bindparams=[bindparam('somedate', type_=types.DateTime)]).execute(
somedate=datetime.datetime(2005, 11, 10, 11, 52, 35)).fetchall()
- print repr(x)
def testdate2(self):
meta = MetaData(testing.db)
if bind is None:
bind = config.db
- return bind.dialect.server_version_info(bind.contextual_connect())
+ return getattr(bind.dialect, 'server_version_info', ())
def skip_if(predicate, reason=None):
"""Skip a test if predicate is true."""
if not db_spec(name)(config.db):
continue
- have = config.db.dialect.server_version_info(
- config.db.contextual_connect())
+ have = _server_version()
oper = hasattr(op, '__call__') and op or _ops[op]
if oper(have, spec):