]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
more mysql+pyodbc fixes
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 14 Jan 2009 23:52:14 +0000 (23:52 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 14 Jan 2009 23:52:14 +0000 (23:52 +0000)
lib/sqlalchemy/connectors/pyodbc.py
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/mysql/mysqldb.py
lib/sqlalchemy/dialects/mysql/pyodbc.py
test/engine/execute.py
test/engine/reflection.py
test/ext/declarative.py
test/sql/testtypes.py
test/testlib/testing.py

index 27220b2c5cee1bc3b3278b3e7f76d440b2bb3a35..590971bd6fa10171a5b837bf60bc294945e67280 100644 (file)
@@ -1,4 +1,5 @@
 from sqlalchemy.connectors import Connector
+
 import sys
 import re
 
index e4345a181a8b2854d42367a91f6dbb6b181d4ae2..e7e250762cebbf143fbe5aa48f5faa549573e083 100644 (file)
@@ -183,6 +183,7 @@ from sqlalchemy import exc, log, schema, sql, util
 from sqlalchemy.sql import operators as sql_operators
 from sqlalchemy.sql import functions as sql_functions
 from sqlalchemy.sql import compiler
+from array import array as _array
 
 from sqlalchemy.engine import base as engine_base, default
 from sqlalchemy import types as sqltypes
@@ -1330,12 +1331,6 @@ class MySQLCompiler(compiler.SQLCompiler):
 
         return 'CAST(%s AS %s)' % (self.process(cast.clause), type_)
 
-
-    def post_process_text(self, text):
-        if '%%' in text:
-            util.warn("The SQLAlchemy MySQLDB dialect now automatically escapes '%' in text() expressions to '%%'.")
-        return text.replace('%', '%%')
-
     def get_select_precolumns(self, select):
         if isinstance(select._distinct, basestring):
             return select._distinct.upper() + " "
@@ -1739,23 +1734,23 @@ class MySQLDialect(default.DefaultDialect):
             raise
 
     def do_begin_twophase(self, connection, xid):
-        connection.execute("XA BEGIN %s", xid)
+        connection.execute(sql.text("XA BEGIN :xid"), xid=xid)
 
     def do_prepare_twophase(self, connection, xid):
-        connection.execute("XA END %s", xid)
-        connection.execute("XA PREPARE %s", xid)
+        connection.execute(sql.text("XA END :xid"), xid=xid)
+        connection.execute(sql.text("XA PREPARE :xid"), xid=xid)
 
     def do_rollback_twophase(self, connection, xid, is_prepared=True,
                              recover=False):
         if not is_prepared:
-            connection.execute("XA END %s", xid)
-        connection.execute("XA ROLLBACK %s", xid)
+            connection.execute(sql.text("XA END :xid"), xid=xid)
+        connection.execute(sql.text("XA ROLLBACK :xid"), xid=xid)
 
     def do_commit_twophase(self, connection, xid, is_prepared=True,
                            recover=False):
         if not is_prepared:
             self.do_prepare_twophase(connection, xid)
-        connection.execute("XA COMMIT %s", xid)
+        connection.execute(sql.text("XA COMMIT :xid"), xid=xid)
 
     def do_recover_twophase(self, connection):
         resultset = connection.execute("XA RECOVER")
@@ -1770,10 +1765,14 @@ class MySQLDialect(default.DefaultDialect):
             return False
 
     def _compat_fetchall(self, rp, charset=None):
-        return rp.fetchall()
+        """Proxy result rows to smooth over MySQL-Python driver inconsistencies."""
+
+        return [_DecodingRowProxy(row, charset) for row in rp.fetchall()]
 
     def _compat_fetchone(self, rp, charset=None):
-        return rp.fetchone()
+        """Proxy a result row to smooth over MySQL-Python driver inconsistencies."""
+
+        return _DecodingRowProxy(rp.fetchone(), charset)
 
     def _extract_error_code(self, exception):
         raise NotImplementedError()
@@ -1881,33 +1880,7 @@ class MySQLDialect(default.DefaultDialect):
             table.metadata.tables[lc_alias] = table
 
     def _detect_charset(self, connection):
-        """Sniff out the character set in use for connection results."""
-
-        # Allow user override, won't sniff if force_charset is set.
-        if ('mysql', 'force_charset') in connection.info:
-            return connection.info[('mysql', 'force_charset')]
-
-        # Prefer 'character_set_results' for the current connection over the
-        # value in the driver.  SET NAMES or individual variable SETs will
-        # change the charset without updating the driver's view of the world.
-        #
-        # If it's decided that issuing that sort of SQL leaves you SOL, then
-        # this can prefer the driver value.
-        rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'")
-        opts = dict([(row[0], row[1]) for row in self._compat_fetchall(rs)])
-
-        if 'character_set_results' in opts:
-            return opts['character_set_results']
-        # Still no charset on < 1.2.1 final...
-        if 'character_set' in opts:
-            return opts['character_set']
-        else:
-            util.warn(
-                "Could not detect the connection character set.  Assuming latin1.")
-            return 'latin1'
-    _detect_charset = engine_base.connection_memoize(
-        ('mysql', 'charset'))(_detect_charset)
-
+        raise NotImplementedError()
 
     def _detect_casing(self, connection):
         """Sniff out identifier case sensitivity.
@@ -2631,6 +2604,40 @@ class MySQLSchemaReflector(object):
 log.class_logger(MySQLSchemaReflector)
 
 
+class _DecodingRowProxy(object):
+    """Return unicode-decoded values based on type inspection.
+
+    Smooth over data type issues (esp. with alpha driver versions) and
+    normalize strings as Unicode regardless of user-configured driver
+    encoding settings.
+
+    """
+
+    # Some MySQL-python versions can return some columns as
+    # sets.Set(['value']) (seriously) but thankfully that doesn't
+    # seem to come up in DDL queries.
+
+    def __init__(self, rowproxy, charset):
+        self.rowproxy = rowproxy
+        self.charset = charset
+    def __getitem__(self, index):
+        item = self.rowproxy[index]
+        if isinstance(item, _array):
+            item = item.tostring()
+        if self.charset and isinstance(item, str):
+            return item.decode(self.charset)
+        else:
+            return item
+    def __getattr__(self, attr):
+        item = getattr(self.rowproxy, attr)
+        if isinstance(item, _array):
+            item = item.tostring()
+        if self.charset and isinstance(item, str):
+            return item.decode(self.charset)
+        else:
+            return item
+
+
 class _MySQLIdentifierPreparer(compiler.IdentifierPreparer):
     """MySQL-specific schema identifier configuration."""
 
index ef9cf6b3e4422ebc0bf03279e7797724fe7870da..6ad8d044735ffd519d0fe85e4f98abdb6a57eda0 100644 (file)
@@ -20,21 +20,27 @@ strings, also pass ``use_unicode=0`` in the connection arguments::
   create_engine('mysql:///mydb?charset=utf8&use_unicode=0')
 """
 
-from sqlalchemy.dialects.mysql.base import MySQLDialect, MySQLExecutionContext
+from sqlalchemy.dialects.mysql.base import MySQLDialect, MySQLExecutionContext, MySQLCompiler
 from sqlalchemy.engine import base as engine_base, default
 from sqlalchemy import exc, log, schema, sql, util
 import re
-from array import array as _array
 
 class MySQL_mysqldbExecutionContext(MySQLExecutionContext):
     def _lastrowid(self, cursor):
         return cursor.lastrowid
 
+class MySQL_mysqldbCompiler(MySQLCompiler):
+    def post_process_text(self, text):
+        if '%%' in text:
+            util.warn("The SQLAlchemy mysql+mysqldb dialect now automatically escapes '%' in text() expressions to '%%'.")
+        return text.replace('%', '%%')
+    
 class MySQL_mysqldb(MySQLDialect):
     driver = 'mysqldb'
     supports_unicode_statements = False
     default_paramstyle = 'format'
     execution_ctx_cls = MySQL_mysqldbExecutionContext
+    sql_compiler = MySQL_mysqldbCompiler
     
     @classmethod
     def dbapi(cls):
@@ -98,6 +104,7 @@ class MySQL_mysqldb(MySQLDialect):
     def _extract_error_code(self, exception):
         return exception.orig.args[0]
 
+    @engine_base.connection_memoize(('mysql', 'charset'))
     def _detect_charset(self, connection):
         """Sniff out the character set in use for connection results."""
 
@@ -138,51 +145,6 @@ class MySQL_mysqldb(MySQLDialect):
                     "combination of MySQL server and MySQL-python. "
                     "MySQL-python >= 1.2.2 is recommended.  Assuming latin1.")
                 return 'latin1'
-    _detect_charset = engine_base.connection_memoize(
-        ('mysql', 'charset'))(_detect_charset)
-
-
-    def _compat_fetchall(self, rp, charset=None):
-        """Proxy result rows to smooth over MySQL-Python driver inconsistencies."""
-
-        return [_MySQLPythonRowProxy(row, charset) for row in rp.fetchall()]
-
-    def _compat_fetchone(self, rp, charset=None):
-        """Proxy a result row to smooth over MySQL-Python driver inconsistencies."""
-
-        return _MySQLPythonRowProxy(rp.fetchone(), charset)
-
-class _MySQLPythonRowProxy(object):
-    """Return consistent column values for all versions of MySQL-python.
-
-    Smooth over data type issues (esp. with alpha driver versions) and
-    normalize strings as Unicode regardless of user-configured driver
-    encoding settings.
-    """
-
-    # Some MySQL-python versions can return some columns as
-    # sets.Set(['value']) (seriously) but thankfully that doesn't
-    # seem to come up in DDL queries.
-
-    def __init__(self, rowproxy, charset):
-        self.rowproxy = rowproxy
-        self.charset = charset
-    def __getitem__(self, index):
-        item = self.rowproxy[index]
-        if isinstance(item, _array):
-            item = item.tostring()
-        if self.charset and isinstance(item, str):
-            return item.decode(self.charset)
-        else:
-            return item
-    def __getattr__(self, attr):
-        item = getattr(self.rowproxy, attr)
-        if isinstance(item, _array):
-            item = item.tostring()
-        if self.charset and isinstance(item, str):
-            return item.decode(self.charset)
-        else:
-            return item
 
 
 dialect = MySQL_mysqldb
\ No newline at end of file
index 06c6551a874bd3dd68c9d2c6e83b940d50976dcd..b6f428ed266b1d0c5f63e050512211513fb3a913 100644 (file)
@@ -1,5 +1,7 @@
 from sqlalchemy.dialects.mysql.base import MySQLDialect, MySQLExecutionContext
 from sqlalchemy.connectors.pyodbc import PyODBCConnector
+from sqlalchemy.engine import base as engine_base
+from sqlalchemy import util
 import re
 
 class MySQL_pyodbcExecutionContext(MySQLExecutionContext):
@@ -14,6 +16,29 @@ class MySQL_pyodbc(PyODBCConnector, MySQLDialect):
     def __init__(self, **kw):
         MySQLDialect.__init__(self, **kw)
         PyODBCConnector.__init__(self, **kw)
+
+    @engine_base.connection_memoize(('mysql', 'charset'))
+    def _detect_charset(self, connection):
+        """Sniff out the character set in use for connection results."""
+
+        # Allow user override, won't sniff if force_charset is set.
+        if ('mysql', 'force_charset') in connection.info:
+            return connection.info[('mysql', 'force_charset')]
+
+        # Prefer 'character_set_results' for the current connection over the
+        # value in the driver.  SET NAMES or individual variable SETs will
+        # change the charset without updating the driver's view of the world.
+        #
+        # If it's decided that issuing that sort of SQL leaves you SOL, then
+        # this can prefer the driver value.
+        rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'")
+        opts = dict([(row[0], row[1]) for row in self._compat_fetchall(rs)])
+        for key in ('character_set_connection', 'character_set'):
+            if opts.get(key, None):
+                return opts[key]
+
+        util.warn("Could not detect the connection character set.  Assuming latin1.")
+        return 'latin1'
     
     def _extract_error_code(self, exception):
         m = re.compile(r"\((\d+)\)").search(str(exception.orig.args))
index 515c99d309bd9d44d55ec44459fb0dcac7cd8d1d..9d9a57e7c068cd04bf18fb272bddb588023ce8d7 100644 (file)
@@ -23,7 +23,7 @@ class ExecuteTest(TestBase):
     def tearDownAll(self):
         metadata.drop_all()
 
-    @testing.fails_on_everything_except('firebird', 'maxdb', 'sqlite')
+    @testing.fails_on_everything_except('firebird', 'maxdb', 'sqlite', 'mysql+pyodbc')
     def test_raw_qmark(self):
         for conn in (testing.db, testing.db.connect()):
             conn.execute("insert into users (user_id, user_name) values (?, ?)", (1,"jack"))
@@ -35,7 +35,7 @@ class ExecuteTest(TestBase):
             assert res.fetchall() == [(1, "jack"), (2, "fred"), (3, "ed"), (4, "horse"), (5, "barney"), (6, "donkey"), (7, 'sally')]
             conn.execute("delete from users")
 
-    @testing.fails_on_everything_except('mysql', 'postgres')
+    @testing.fails_on_everything_except('mysql+mysqldb', 'postgres')
     # some psycopg2 versions bomb this.
     def test_raw_sprintf(self):
         for conn in (testing.db, testing.db.connect()):
@@ -49,7 +49,7 @@ class ExecuteTest(TestBase):
 
     # pyformat is supported for mysql, but skipping because a few driver
     # versions have a bug that bombs out on this test. (1.2.2b3, 1.2.2c1, 1.2.2)
-    @testing.skip_if(lambda: testing.against('mysql'), 'db-api flaky')
+    @testing.skip_if(lambda: testing.against('mysql+mysqldb'), 'db-api flaky')
     @testing.fails_on_everything_except('postgres')
     def test_raw_python(self):
         for conn in (testing.db, testing.db.connect()):
index a448540825bcf2721a9f4889042c7d1991c8e8e4..4e6601951f2ccdb343dc5023564bd4b3a15aae6c 100644 (file)
@@ -744,7 +744,7 @@ class SchemaTest(TestBase):
     def test_explicit_default_schema(self):
         engine = testing.db
 
-        if testing.against('mysql'):
+        if testing.against('mysql+mysqldb'):
             schema = testing.db.url.database
         elif testing.against('postgres'):
             schema = 'public'
index c9477b5d85c1e53f51fa6dbb7b1c1dd6b90f2984..c305fa1d03af0c6693a7e43b1f020a3796ef7651 100644 (file)
@@ -1065,7 +1065,7 @@ class DeclarativeInheritanceTest(DeclarativeTestBase):
 
         class Engineer(Person):
             __mapper_args__ = {'polymorphic_identity':'engineer'}
-            primary_language_id = Column(String(50), ForeignKey('languages.id'))
+            primary_language_id = Column(Integer, ForeignKey('languages.id'))
             primary_language = relation("Language")
             
         class Language(Base, ComparableEntity):
index 9ce7b7662ab1715aedfa5a0978e89721268e83d0..ca22fcb270ce57c3651d91157d55230163dd4b21 100644 (file)
@@ -302,7 +302,7 @@ class UnicodeTest(TestBase, AssertsExecutionResults):
         self.assert_(isinstance(x['unicode_text'], unicode) and x['unicode_text'] == unicodedata)
         if isinstance(x['plain_varchar'], unicode):
             # SQLLite and MSSQL return non-unicode data as unicode
-            self.assert_(testing.against('sqlite', 'mssql'))
+            self.assert_(testing.against('sqlite', '+pyodbc'))
             if not testing.against('sqlite'):
                 self.assert_(x['plain_varchar'] == unicodedata)
             print "it's %s!" % testing.db.name
index af0877beb436ef94d2117b90b5863aab5a5c0a1c..d2bccf94db7178ac18c9b1b262e428ec54c6f53b 100644 (file)
@@ -202,7 +202,7 @@ def _block_unconditionally(db, reason):
     def decorate(fn):
         fn_name = fn.__name__
         def maybe(*args, **kw):
-            if spec(db):
+            if spec(config.db):
                 msg = "'%s' unsupported on DB implementation '%s+%s': %s" % (
                     fn_name, config.db.name, config.db.driver, reason)
                 print msg
@@ -447,11 +447,11 @@ def against(*queries):
 
     for query in queries:
         if isinstance(query, basestring):
-            if config.db.name == query:
+            if db_spec(query)(config.db):
                 return True
         else:
             name, op, spec = query
-            if config.db.name != name:
+            if not db_spec(name)(config.db):
                 continue
 
             have = config.db.dialect.server_version_info(