From de7eab027ad01ec8e47712bc43e049c884eba648 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 14 Jan 2009 17:52:10 +0000 Subject: [PATCH] - mysql+pyodbc working for regular usage, ORM, etc. types and unicode still flaky. - updated testing decorators to receive "name+driver"-style specifications --- lib/sqlalchemy/connectors/__init__.py | 6 ++ lib/sqlalchemy/connectors/pyodbc.py | 75 ++++++++++++++++++++++++ lib/sqlalchemy/dialects/mysql/base.py | 7 ++- lib/sqlalchemy/dialects/mysql/mysqldb.py | 3 + lib/sqlalchemy/dialects/mysql/pyodbc.py | 20 ++++++- lib/sqlalchemy/engine/base.py | 7 +++ test/orm/query.py | 2 +- test/testlib/engines.py | 2 +- test/testlib/testing.py | 60 +++++++++++++------ 9 files changed, 157 insertions(+), 25 deletions(-) diff --git a/lib/sqlalchemy/connectors/__init__.py b/lib/sqlalchemy/connectors/__init__.py index e69de29bb2..f1383ad829 100644 --- a/lib/sqlalchemy/connectors/__init__.py +++ b/lib/sqlalchemy/connectors/__init__.py @@ -0,0 +1,6 @@ + + +class Connector(object): + pass + + \ No newline at end of file diff --git a/lib/sqlalchemy/connectors/pyodbc.py b/lib/sqlalchemy/connectors/pyodbc.py index e69de29bb2..27220b2c5c 100644 --- a/lib/sqlalchemy/connectors/pyodbc.py +++ b/lib/sqlalchemy/connectors/pyodbc.py @@ -0,0 +1,75 @@ +from sqlalchemy.connectors import Connector +import sys +import re + +class PyODBCConnector(Connector): + driver='pyodbc' + supports_sane_rowcount = False + supports_sane_multi_rowcount = False + # PyODBC unicode is broken on UCS-4 builds + supports_unicode = sys.maxunicode == 65535 + supports_unicode_statements = supports_unicode + default_paramstyle = 'named' + + @classmethod + def dbapi(cls): + return __import__('pyodbc') + + def create_connect_args(self, url): + opts = url.translate_connect_args(username='user') + opts.update(url.query) + + keys = opts + query = url.query + + if 'odbc_connect' in keys: + connectors = [urllib.unquote_plus(keys.pop('odbc_connect'))] + else: + dsn_connection = 'dsn' in keys or ('host' in keys and 'database' not in keys) + if dsn_connection: + connectors= ['dsn=%s' % (keys.pop('host', '') or keys.pop('dsn', ''))] + else: + port = '' + if 'port' in keys and not 'port' in query: + port = ',%d' % int(keys.pop('port')) + + connectors = ["DRIVER={%s}" % keys.pop('driver'), + 'Server=%s%s' % (keys.pop('host', ''), port), + 'Database=%s' % keys.pop('database', '') ] + + user = keys.pop("user", None) + if user: + connectors.append("UID=%s" % user) + connectors.append("PWD=%s" % keys.pop('password', '')) + else: + connectors.append("TrustedConnection=Yes") + + # if set to 'Yes', the ODBC layer will try to automagically convert + # textual data from your database encoding to your client encoding + # This should obviously be set to 'No' if you query a cp1253 encoded + # database from a latin1 client... + if 'odbc_autotranslate' in keys: + connectors.append("AutoTranslate=%s" % keys.pop("odbc_autotranslate")) + + connectors.extend(['%s=%s' % (k,v) for k,v in keys.iteritems()]) + + return [[";".join (connectors)], {}] + + def is_disconnect(self, e): + if isinstance(e, self.dbapi.ProgrammingError): + return "The cursor's connection has been closed." in str(e) or 'Attempt to use a closed connection.' in str(e) + elif isinstance(e, self.dbapi.Error): + return '[08S01]' in str(e) + else: + return False + + def _server_version_info(self, dbapi_con): + """Convert a pyodbc SQL_DBMS_VER string into a tuple.""" + version = [] + r = re.compile('[.\-]') + for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)): + try: + version.append(int(n)) + except ValueError: + version.append(n) + return tuple(version) diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 67c73efb71..e4345a181a 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1763,7 +1763,7 @@ class MySQLDialect(default.DefaultDialect): def is_disconnect(self, e): if isinstance(e, self.dbapi.OperationalError): - return e.args[0] in (2006, 2013, 2014, 2045, 2055) + return self._extract_error_code(e) in (2006, 2013, 2014, 2045, 2055) elif isinstance(e, self.dbapi.InterfaceError): # if underlying connection is closed, this is the error you get return "(0, '')" in str(e) else: @@ -1775,6 +1775,9 @@ class MySQLDialect(default.DefaultDialect): def _compat_fetchone(self, rp, charset=None): return rp.fetchone() + def _extract_error_code(self, exception): + raise NotImplementedError() + def get_default_schema_name(self, connection): return connection.execute('SELECT DATABASE()').scalar() get_default_schema_name = engine_base.connection_memoize( @@ -1814,7 +1817,7 @@ class MySQLDialect(default.DefaultDialect): rs.close() return have except exc.SQLError, e: - if e.orig.args[0] == 1146: + if self._extract_error_code(e) == 1146: return False raise finally: diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py index b644374a8a..ef9cf6b3e4 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqldb.py +++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py @@ -95,6 +95,9 @@ class MySQL_mysqldb(MySQLDialect): version.append(n) return tuple(version) + def _extract_error_code(self, exception): + return exception.orig.args[0] + def _detect_charset(self, connection): """Sniff out the character set in use for connection results.""" diff --git a/lib/sqlalchemy/dialects/mysql/pyodbc.py b/lib/sqlalchemy/dialects/mysql/pyodbc.py index b2698b16d3..06c6551a87 100644 --- a/lib/sqlalchemy/dialects/mysql/pyodbc.py +++ b/lib/sqlalchemy/dialects/mysql/pyodbc.py @@ -1,12 +1,26 @@ from sqlalchemy.dialects.mysql.base import MySQLDialect, MySQLExecutionContext +from sqlalchemy.connectors.pyodbc import PyODBCConnector +import re class MySQL_pyodbcExecutionContext(MySQLExecutionContext): def _lastrowid(self, cursor): cursor.execute("SELECT LAST_INSERT_ID()") return cursor.fetchone()[0] -class MySQL_pyodbc(MySQLDialect): - pass - +class MySQL_pyodbc(PyODBCConnector, MySQLDialect): + supports_unicode_statements = False + execution_ctx_cls = MySQL_pyodbcExecutionContext + + def __init__(self, **kw): + MySQLDialect.__init__(self, **kw) + PyODBCConnector.__init__(self, **kw) + + def _extract_error_code(self, exception): + m = re.compile(r"\((\d+)\)").search(str(exception.orig.args)) + c = m.group(1) + if c: + return int(c) + else: + return None dialect = MySQL_pyodbc \ No newline at end of file diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index f95da22731..535c5fc1c8 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -1122,6 +1122,12 @@ class Engine(Connectable): return self.dialect.name + @property + def driver(self): + "Driver name of the :class:`~sqlalchemy.engine.Dialect` in use by this ``Engine``." + + return self.dialect.driver + echo = log.echo_property() def __repr__(self): @@ -1456,6 +1462,7 @@ class ResultProxy(object): for i, item in enumerate(metadata): colname = item[0] + if self.dialect.description_encoding: colname = colname.decode(self.dialect.description_encoding) diff --git a/test/orm/query.py b/test/orm/query.py index cba57914d1..f0633f16d0 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -190,7 +190,7 @@ class GetTest(QueryTest): assert u.addresses[0].email_address == 'jack@bean.com' assert u.orders[1].items[2].description == 'item 5' - @testing.fails_on_everything_except('sqlite', 'mssql') + @testing.fails_on_everything_except('sqlite', '+pyodbc') def test_query_str(self): s = create_session() q = s.query(User).filter(User.id==1) diff --git a/test/testlib/engines.py b/test/testlib/engines.py index 85e1efa3a4..4f8811e45a 100644 --- a/test/testlib/engines.py +++ b/test/testlib/engines.py @@ -126,7 +126,7 @@ def utf8_engine(url=None, options=None): from sqlalchemy.engine import url as engine_url - if config.db.name == 'mysql': + if config.db.driver == 'mysqldb': dbapi_ver = config.db.dialect.dbapi.version_info if (dbapi_ver < (1, 2, 1) or dbapi_ver in ((1, 2, 1, 'gamma', 1), (1, 2, 1, 'gamma', 2), diff --git a/test/testlib/testing.py b/test/testlib/testing.py index fb77b07bb1..af0877beb4 100644 --- a/test/testlib/testing.py +++ b/test/testlib/testing.py @@ -91,6 +91,19 @@ def future(fn): "Unexpected success for future test '%s'" % fn_name) return _function_named(decorated, fn_name) +def db_spec(*dbs): + dialects = set([x for x in dbs if '+' not in x]) + drivers = set([x[1:] for x in dbs if x.startswith('+')]) + specs = set([tuple(x.split('+')) for x in dbs if '+' in x and x not in drivers]) + + def check(engine): + return engine.name in dialects or \ + engine.driver in drivers or \ + (engine.name, engine.driver) in specs + + return check + + def fails_on(dbs, reason): """Mark a test as expected to fail on the specified database implementation. @@ -101,23 +114,25 @@ def fails_on(dbs, reason): succeeds, a failure is reported. """ + spec = db_spec(dbs) + def decorate(fn): fn_name = fn.__name__ def maybe(*args, **kw): - if config.db.name != dbs: + if not spec(config.db): return fn(*args, **kw) else: try: fn(*args, **kw) except Exception, ex: print ("'%s' failed as expected on DB implementation " - "'%s': %s" % ( - fn_name, config.db.name, reason)) + "'%s+%s': %s" % ( + fn_name, config.db.name, config.db.driver, reason)) return True else: raise AssertionError( - "Unexpected success for '%s' on DB implementation '%s'" % - (fn_name, config.db.name)) + "Unexpected success for '%s' on DB implementation '%s+%s'" % + (fn_name, config.db.name, config.db.driver)) return _function_named(maybe, fn_name) return decorate @@ -128,23 +143,25 @@ def fails_on_everything_except(*dbs): databases except those listed. """ + spec = db_spec(*dbs) + def decorate(fn): fn_name = fn.__name__ def maybe(*args, **kw): - if config.db.name in dbs: + if spec(config.db): return fn(*args, **kw) else: try: fn(*args, **kw) except Exception, ex: print ("'%s' failed as expected on DB implementation " - "'%s': %s" % ( - fn_name, config.db.name, str(ex))) + "'%s+%s': %s" % ( + fn_name, config.db.name, config.db.driver, str(ex))) return True else: raise AssertionError( - "Unexpected success for '%s' on DB implementation '%s'" % - (fn_name, config.db.name)) + "Unexpected success for '%s' on DB implementation '%s+%s'" % + (fn_name, config.db.name, config.db.driver)) return _function_named(maybe, fn_name) return decorate @@ -156,12 +173,13 @@ def crashes(db, reason): """ carp = _should_carp_about_exclusion(reason) + spec = db_spec(db) def decorate(fn): fn_name = fn.__name__ def maybe(*args, **kw): - if config.db.name == db: - msg = "'%s' unsupported on DB implementation '%s': %s" % ( - fn_name, config.db.name, reason) + if spec(config.db): + msg = "'%s' unsupported on DB implementation '%s+%s': %s" % ( + fn_name, config.db.name, config.db.driver, reason) print msg if carp: print >> sys.stderr, msg @@ -180,12 +198,13 @@ def _block_unconditionally(db, reason): """ carp = _should_carp_about_exclusion(reason) + spec = db_spec(db) def decorate(fn): fn_name = fn.__name__ def maybe(*args, **kw): - if config.db.name == db: - msg = "'%s' unsupported on DB implementation '%s': %s" % ( - fn_name, config.db.name, reason) + if spec(db): + msg = "'%s' unsupported on DB implementation '%s+%s': %s" % ( + fn_name, config.db.name, config.db.driver, reason) print msg if carp: print >> sys.stderr, msg @@ -209,6 +228,7 @@ def exclude(db, op, spec, reason): """ carp = _should_carp_about_exclusion(reason) + def decorate(fn): fn_name = fn.__name__ def maybe(*args, **kw): @@ -253,7 +273,9 @@ def _is_excluded(db, op, spec): _is_excluded('yikesdb', 'in', ((0, 3, 'alpha2'), (0, 3, 'alpha3'))) """ - if config.db.name != db: + spec = db_spec(db) + + if not spec(config.db): return False version = _server_version() @@ -330,10 +352,12 @@ def emits_warning_on(db, *warnings): strings; these will be matched to the root of the warning description by warnings.filterwarnings(). """ + spec = db_spec(db) + def decorate(fn): def maybe(*args, **kw): if isinstance(db, basestring): - if config.db.name != db: + if not spec(config.db): return fn(*args, **kw) else: wrapped = emits_warning(*warnings)(fn) -- 2.47.3