]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- mysql+pyodbc working for regular usage, ORM, etc. types and unicode still flaky.
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 14 Jan 2009 17:52:10 +0000 (17:52 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 14 Jan 2009 17:52:10 +0000 (17:52 +0000)
- updated testing decorators to receive  "name+driver"-style specifications

lib/sqlalchemy/connectors/__init__.py
lib/sqlalchemy/connectors/pyodbc.py
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/mysql/mysqldb.py
lib/sqlalchemy/dialects/mysql/pyodbc.py
lib/sqlalchemy/engine/base.py
test/orm/query.py
test/testlib/engines.py
test/testlib/testing.py

index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..f1383ad829d2995a3b4c7f2b029e7da50ba3d9f5 100644 (file)
@@ -0,0 +1,6 @@
+
+
+class Connector(object):
+    pass
+    
+    
\ No newline at end of file
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..27220b2c5cee1bc3b3278b3e7f76d440b2bb3a35 100644 (file)
@@ -0,0 +1,75 @@
+from sqlalchemy.connectors import Connector
+import sys
+import re
+
+class PyODBCConnector(Connector):
+    driver='pyodbc'
+    supports_sane_rowcount = False
+    supports_sane_multi_rowcount = False
+    # PyODBC unicode is broken on UCS-4 builds
+    supports_unicode = sys.maxunicode == 65535
+    supports_unicode_statements = supports_unicode
+    default_paramstyle = 'named'
+    
+    @classmethod
+    def dbapi(cls):
+        return __import__('pyodbc')
+
+    def create_connect_args(self, url):
+        opts = url.translate_connect_args(username='user')
+        opts.update(url.query)
+        
+        keys = opts
+        query = url.query
+
+        if 'odbc_connect' in keys:
+            connectors = [urllib.unquote_plus(keys.pop('odbc_connect'))]
+        else:
+            dsn_connection = 'dsn' in keys or ('host' in keys and 'database' not in keys)
+            if dsn_connection:
+                connectors= ['dsn=%s' % (keys.pop('host', '') or keys.pop('dsn', ''))]
+            else:
+                port = ''
+                if 'port' in keys and not 'port' in query:
+                    port = ',%d' % int(keys.pop('port'))
+
+                connectors = ["DRIVER={%s}" % keys.pop('driver'),
+                              'Server=%s%s' % (keys.pop('host', ''), port),
+                              'Database=%s' % keys.pop('database', '') ]
+
+            user = keys.pop("user", None)
+            if user:
+                connectors.append("UID=%s" % user)
+                connectors.append("PWD=%s" % keys.pop('password', ''))
+            else:
+                connectors.append("TrustedConnection=Yes")
+
+            # if set to 'Yes', the ODBC layer will try to automagically convert 
+            # textual data from your database encoding to your client encoding 
+            # This should obviously be set to 'No' if you query a cp1253 encoded 
+            # database from a latin1 client... 
+            if 'odbc_autotranslate' in keys:
+                connectors.append("AutoTranslate=%s" % keys.pop("odbc_autotranslate"))
+
+            connectors.extend(['%s=%s' % (k,v) for k,v in keys.iteritems()])
+
+        return [[";".join (connectors)], {}]
+
+    def is_disconnect(self, e):
+        if isinstance(e, self.dbapi.ProgrammingError):
+            return "The cursor's connection has been closed." in str(e) or 'Attempt to use a closed connection.' in str(e)
+        elif isinstance(e, self.dbapi.Error):
+            return '[08S01]' in str(e)
+        else:
+            return False
+
+    def _server_version_info(self, dbapi_con):
+        """Convert a pyodbc SQL_DBMS_VER string into a tuple."""
+        version = []
+        r = re.compile('[.\-]')
+        for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)):
+            try:
+                version.append(int(n))
+            except ValueError:
+                version.append(n)
+        return tuple(version)
index 67c73efb7185a99cadb91539391946af5ca89fcf..e4345a181a8b2854d42367a91f6dbb6b181d4ae2 100644 (file)
@@ -1763,7 +1763,7 @@ class MySQLDialect(default.DefaultDialect):
 
     def is_disconnect(self, e):
         if isinstance(e, self.dbapi.OperationalError):
-            return e.args[0] in (2006, 2013, 2014, 2045, 2055)
+            return self._extract_error_code(e) in (2006, 2013, 2014, 2045, 2055)
         elif isinstance(e, self.dbapi.InterfaceError):  # if underlying connection is closed, this is the error you get
             return "(0, '')" in str(e)
         else:
@@ -1775,6 +1775,9 @@ class MySQLDialect(default.DefaultDialect):
     def _compat_fetchone(self, rp, charset=None):
         return rp.fetchone()
 
+    def _extract_error_code(self, exception):
+        raise NotImplementedError()
+        
     def get_default_schema_name(self, connection):
         return connection.execute('SELECT DATABASE()').scalar()
     get_default_schema_name = engine_base.connection_memoize(
@@ -1814,7 +1817,7 @@ class MySQLDialect(default.DefaultDialect):
                 rs.close()
                 return have
             except exc.SQLError, e:
-                if e.orig.args[0] == 1146:
+                if self._extract_error_code(e) == 1146:
                     return False
                 raise
         finally:
index b644374a8a7f1fbebab10905eb82e3d03ef95d0d..ef9cf6b3e4422ebc0bf03279e7797724fe7870da 100644 (file)
@@ -95,6 +95,9 @@ class MySQL_mysqldb(MySQLDialect):
                 version.append(n)
         return tuple(version)
 
+    def _extract_error_code(self, exception):
+        return exception.orig.args[0]
+
     def _detect_charset(self, connection):
         """Sniff out the character set in use for connection results."""
 
index b2698b16d37315015764109e6f315bc85a62fa65..06c6551a874bd3dd68c9d2c6e83b940d50976dcd 100644 (file)
@@ -1,12 +1,26 @@
 from sqlalchemy.dialects.mysql.base import MySQLDialect, MySQLExecutionContext
+from sqlalchemy.connectors.pyodbc import PyODBCConnector
+import re
 
 class MySQL_pyodbcExecutionContext(MySQLExecutionContext):
     def _lastrowid(self, cursor):
         cursor.execute("SELECT LAST_INSERT_ID()")
         return cursor.fetchone()[0]
 
-class MySQL_pyodbc(MySQLDialect):
-    pass
-
+class MySQL_pyodbc(PyODBCConnector, MySQLDialect):
+    supports_unicode_statements = False
+    execution_ctx_cls = MySQL_pyodbcExecutionContext
+    
+    def __init__(self, **kw):
+        MySQLDialect.__init__(self, **kw)
+        PyODBCConnector.__init__(self, **kw)
+    
+    def _extract_error_code(self, exception):
+        m = re.compile(r"\((\d+)\)").search(str(exception.orig.args))
+        c = m.group(1)
+        if c:
+            return int(c)
+        else:
+            return None
 
 dialect = MySQL_pyodbc
\ No newline at end of file
index f95da22731c172a1b00889e27d9e08c95690b091..535c5fc1c8b54bf70121e378377608db1ee37b68 100644 (file)
@@ -1122,6 +1122,12 @@ class Engine(Connectable):
         
         return self.dialect.name
 
+    @property
+    def driver(self):
+        "Driver name of the :class:`~sqlalchemy.engine.Dialect` in use by this ``Engine``."
+
+        return self.dialect.driver
+
     echo = log.echo_property()
 
     def __repr__(self):
@@ -1456,6 +1462,7 @@ class ResultProxy(object):
 
         for i, item in enumerate(metadata):
             colname = item[0]
+
             if self.dialect.description_encoding:
                 colname = colname.decode(self.dialect.description_encoding)
 
index cba57914d1a4ac597bb3739a70b316d41f8dd144..f0633f16d025e698a47570581f1f318dd46809cc 100644 (file)
@@ -190,7 +190,7 @@ class GetTest(QueryTest):
         assert u.addresses[0].email_address == 'jack@bean.com'
         assert u.orders[1].items[2].description == 'item 5'
 
-    @testing.fails_on_everything_except('sqlite', 'mssql')
+    @testing.fails_on_everything_except('sqlite', '+pyodbc')
     def test_query_str(self):
         s = create_session()
         q = s.query(User).filter(User.id==1)
index 85e1efa3a4e8420a9c6782522c483a53c7dcd552..4f8811e45a82443792320aed776207ef33ce8ea3 100644 (file)
@@ -126,7 +126,7 @@ def utf8_engine(url=None, options=None):
 
     from sqlalchemy.engine import url as engine_url
 
-    if config.db.name == 'mysql':
+    if config.db.driver == 'mysqldb':
         dbapi_ver = config.db.dialect.dbapi.version_info
         if (dbapi_ver < (1, 2, 1) or
             dbapi_ver in ((1, 2, 1, 'gamma', 1), (1, 2, 1, 'gamma', 2),
index fb77b07bb15413d16a8ae6160fa702afe7bfc472..af0877beb436ef94d2117b90b5863aab5a5c0a1c 100644 (file)
@@ -91,6 +91,19 @@ def future(fn):
                 "Unexpected success for future test '%s'" % fn_name)
     return _function_named(decorated, fn_name)
 
+def db_spec(*dbs):
+    dialects = set([x for x in dbs if '+' not in x])
+    drivers = set([x[1:] for x in dbs if x.startswith('+')])
+    specs = set([tuple(x.split('+')) for x in dbs if '+' in x and x not in drivers])
+
+    def check(engine):
+        return engine.name in dialects or \
+            engine.driver in drivers or \
+            (engine.name, engine.driver) in specs
+    
+    return check
+        
+
 def fails_on(dbs, reason):
     """Mark a test as expected to fail on the specified database 
     implementation.
@@ -101,23 +114,25 @@ def fails_on(dbs, reason):
     succeeds, a failure is reported.
     """
 
+    spec = db_spec(dbs)
+    
     def decorate(fn):
         fn_name = fn.__name__
         def maybe(*args, **kw):
-            if config.db.name != dbs:
+            if not spec(config.db):
                 return fn(*args, **kw)
             else:
                 try:
                     fn(*args, **kw)
                 except Exception, ex:
                     print ("'%s' failed as expected on DB implementation "
-                           "'%s': %s" % (
-                        fn_name, config.db.name, reason))
+                           "'%s+%s': %s" % (
+                        fn_name, config.db.name, config.db.driver, reason))
                     return True
                 else:
                     raise AssertionError(
-                        "Unexpected success for '%s' on DB implementation '%s'" %
-                        (fn_name, config.db.name))
+                        "Unexpected success for '%s' on DB implementation '%s+%s'" %
+                        (fn_name, config.db.name, config.db.driver))
         return _function_named(maybe, fn_name)
     return decorate
 
@@ -128,23 +143,25 @@ def fails_on_everything_except(*dbs):
     databases except those listed.
     """
 
+    spec = db_spec(*dbs)
+    
     def decorate(fn):
         fn_name = fn.__name__
         def maybe(*args, **kw):
-            if config.db.name in dbs:
+            if spec(config.db):
                 return fn(*args, **kw)
             else:
                 try:
                     fn(*args, **kw)
                 except Exception, ex:
                     print ("'%s' failed as expected on DB implementation "
-                           "'%s': %s" % (
-                        fn_name, config.db.name, str(ex)))
+                           "'%s+%s': %s" % (
+                        fn_name, config.db.name, config.db.driver, str(ex)))
                     return True
                 else:
                     raise AssertionError(
-                        "Unexpected success for '%s' on DB implementation '%s'" %
-                        (fn_name, config.db.name))
+                        "Unexpected success for '%s' on DB implementation '%s+%s'" %
+                        (fn_name, config.db.name, config.db.driver))
         return _function_named(maybe, fn_name)
     return decorate
 
@@ -156,12 +173,13 @@ def crashes(db, reason):
 
     """
     carp = _should_carp_about_exclusion(reason)
+    spec = db_spec(db)
     def decorate(fn):
         fn_name = fn.__name__
         def maybe(*args, **kw):
-            if config.db.name == db:
-                msg = "'%s' unsupported on DB implementation '%s': %s" % (
-                    fn_name, config.db.name, reason)
+            if spec(config.db):
+                msg = "'%s' unsupported on DB implementation '%s+%s': %s" % (
+                    fn_name, config.db.name, config.db.driver, reason)
                 print msg
                 if carp:
                     print >> sys.stderr, msg
@@ -180,12 +198,13 @@ def _block_unconditionally(db, reason):
 
     """
     carp = _should_carp_about_exclusion(reason)
+    spec = db_spec(db)
     def decorate(fn):
         fn_name = fn.__name__
         def maybe(*args, **kw):
-            if config.db.name == db:
-                msg = "'%s' unsupported on DB implementation '%s': %s" % (
-                    fn_name, config.db.name, reason)
+            if spec(db):
+                msg = "'%s' unsupported on DB implementation '%s+%s': %s" % (
+                    fn_name, config.db.name, config.db.driver, reason)
                 print msg
                 if carp:
                     print >> sys.stderr, msg
@@ -209,6 +228,7 @@ def exclude(db, op, spec, reason):
 
     """
     carp = _should_carp_about_exclusion(reason)
+    
     def decorate(fn):
         fn_name = fn.__name__
         def maybe(*args, **kw):
@@ -253,7 +273,9 @@ def _is_excluded(db, op, spec):
       _is_excluded('yikesdb', 'in', ((0, 3, 'alpha2'), (0, 3, 'alpha3')))
     """
 
-    if config.db.name != db:
+    spec = db_spec(db)
+
+    if not spec(config.db):
         return False
 
     version = _server_version()
@@ -330,10 +352,12 @@ def emits_warning_on(db, *warnings):
     strings; these will be matched to the root of the warning description by
     warnings.filterwarnings().
     """
+    spec = db_spec(db)
+    
     def decorate(fn):
         def maybe(*args, **kw):
             if isinstance(db, basestring):
-                if config.db.name != db:
+                if not spec(config.db):
                     return fn(*args, **kw)
                 else:
                     wrapped = emits_warning(*warnings)(fn)