From 9911443b9d0b3df0c1d2a0996a6858b1a4fa9ca0 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 10 Nov 2009 22:39:42 +0000 Subject: [PATCH] - new oursql dialect added. [ticket:1613] --- CHANGES | 7 + lib/sqlalchemy/dialects/mysql/__init__.py | 2 +- lib/sqlalchemy/dialects/mysql/oursql.py | 217 ++++++++++++++++++++++ test/dialect/test_mysql.py | 2 +- test/engine/test_execute.py | 2 +- test/orm/test_generative.py | 2 +- test/orm/test_query.py | 2 +- test/orm/test_relationships.py | 4 +- test/sql/test_defaults.py | 4 +- test/sql/test_types.py | 5 +- 10 files changed, 236 insertions(+), 11 deletions(-) create mode 100644 lib/sqlalchemy/dialects/mysql/oursql.py diff --git a/CHANGES b/CHANGES index 698c4760e2..41fcbb8a35 100644 --- a/CHANGES +++ b/CHANGES @@ -434,6 +434,9 @@ CHANGES object is passed in. - postgresql + - New dialects: pg8000, zxjdbc, and pypostgresql + on py3k. + - The "postgres" dialect is now named "postgresql" ! Connection strings look like: @@ -497,6 +500,10 @@ CHANGES used for such statements.) - mysql + - New dialects: oursql, a new native dialect, + MySQL Connector/Python, a native Python port of MySQLdb, + and of course zxjdbc on Jython. + - all the _detect_XXX() functions now run once underneath dialect.initialize() diff --git a/lib/sqlalchemy/dialects/mysql/__init__.py b/lib/sqlalchemy/dialects/mysql/__init__.py index e2a6fdc71d..1685295162 100644 --- a/lib/sqlalchemy/dialects/mysql/__init__.py +++ b/lib/sqlalchemy/dialects/mysql/__init__.py @@ -1,4 +1,4 @@ -from sqlalchemy.dialects.mysql import base, mysqldb, pyodbc, zxjdbc, myconnpy +from sqlalchemy.dialects.mysql import base, mysqldb, oursql, pyodbc, zxjdbc, myconnpy # default dialect base.dialect = mysqldb.dialect diff --git a/lib/sqlalchemy/dialects/mysql/oursql.py b/lib/sqlalchemy/dialects/mysql/oursql.py new file mode 100644 index 0000000000..4836e76a50 --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/oursql.py @@ -0,0 +1,217 @@ +"""Support for the MySQL database via the oursql adapter. + +Character Sets +-------------- + +oursql defaults to using ``utf8`` as the connection charset, but other +encodings may be used instead. Like the MySQL-Python driver, unicode support +can be completely disabled:: + + # oursql sets the connection charset to utf8 automatically; all strings come + # back as utf8 str + create_engine('mysql+oursql:///mydb?use_unicode=0') + +To not automatically use ``utf8`` and instead use whatever the connection +defaults to, there is a separate parameter:: + + # use the default connection charset; all strings come back as unicode + create_engine('mysql+oursql:///mydb?default_charset=1') + + # use latin1 as the connection charset; all strings come back as unicode + create_engine('mysql+oursql:///mydb?charset=latin1') +""" + +import decimal +import re + +from sqlalchemy.dialects.mysql.base import (BIT, MySQLDialect, MySQLExecutionContext, + MySQLCompiler, MySQLIdentifierPreparer, NUMERIC, _NumericType) +from sqlalchemy.engine import base as engine_base, default +from sqlalchemy.sql import operators as sql_operators +from sqlalchemy import exc, log, schema, sql, types as sqltypes, util + + +class _PlainQuery(unicode): + pass + + +class _oursqlNumeric(NUMERIC): + def result_processor(self, dialect): + if self.asdecimal: + return + def process(value): + if isinstance(value, decimal.Decimal): + return float(value) + else: + return value + return process + + +class _oursqlBIT(BIT): + def result_processor(self, dialect): + """oursql already converts mysql bits, so.""" + def process(value): + return value + return process + + +class MySQL_oursql(MySQLDialect): + driver = 'oursql' + supports_unicode_statements = True + supports_unicode_binds = True + supports_sane_rowcount = True + supports_sane_multi_rowcount = True + + colspecs = util.update_copy( + MySQLDialect.colspecs, + { + sqltypes.Time: sqltypes.Time, + sqltypes.Numeric: _oursqlNumeric, + BIT: _oursqlBIT, + } + ) + + @classmethod + def dbapi(cls): + return __import__('oursql') + + def do_execute(self, cursor, statement, parameters, context=None): + """Provide an implementation of *cursor.execute(statement, parameters)*.""" + if isinstance(statement, _PlainQuery): + cursor.execute(statement, plain_query=True) + else: + cursor.execute(statement, parameters) + + def do_begin(self, connection): + connection.cursor().execute('BEGIN', plain_query=True) + + def _xa_query(self, connection, query, xid): + connection.execute(_PlainQuery(query % connection.connection._escape_string(xid))) + + # Because mysql is bad, these methods have to be reimplemented to use _PlainQuery. Basically, some queries + # refuse to return any data if they're run through the parameterized query API, or refuse to be parameterized + # in the first place. + def do_begin_twophase(self, connection, xid): + self._xa_query(connection, 'XA BEGIN "%s"', xid) + + def do_prepare_twophase(self, connection, xid): + self._xa_query(connection, 'XA END "%s"', xid) + self._xa_query(connection, 'XA PREPARE "%s"', xid) + + def do_rollback_twophase(self, connection, xid, is_prepared=True, + recover=False): + if not is_prepared: + self._xa_query(connection, 'XA END "%s"', xid) + self._xa_query(connection, 'XA ROLLBACK "%s"', xid) + + def do_commit_twophase(self, connection, xid, is_prepared=True, + recover=False): + if not is_prepared: + self.do_prepare_twophase(connection, xid) + self._xa_query(connection, 'XA COMMIT "%s"', xid) + + def has_table(self, connection, table_name, schema=None): + full_name = '.'.join(self.identifier_preparer._quote_free_identifiers( + schema, table_name)) + + st = "DESCRIBE %s" % full_name + rs = None + try: + try: + rs = connection.execute(_PlainQuery(st)) + have = rs.rowcount > 0 + rs.close() + return have + except exc.SQLError, e: + if self._extract_error_code(e) == 1146: + return False + raise + finally: + if rs: + rs.close() + + def _show_create_table(self, connection, table, charset=None, + full_name=None): + """Run SHOW CREATE TABLE for a ``Table``.""" + + if full_name is None: + full_name = self.identifier_preparer.format_table(table) + st = "SHOW CREATE TABLE %s" % full_name + + rp = None + try: + try: + rp = connection.execute(_PlainQuery(st)) + except exc.SQLError, e: + if self._extract_error_code(e) == 1146: + raise exc.NoSuchTableError(full_name) + else: + raise + row = rp.fetchone() + if not row: + raise exc.NoSuchTableError(full_name) + return row[1].strip() + finally: + if rp: + rp.close() + + def is_disconnect(self, e): + if isinstance(e, self.dbapi.ProgrammingError): # if underlying connection is closed, this is the error you get + return e.errno is None and e[1].endswith('closed') + else: + return e.errno in (2006, 2013, 2014, 2045, 2055) + + def create_connect_args(self, url): + opts = url.translate_connect_args(database='db', username='user', + password='passwd') + opts.update(url.query) + + util.coerce_kw_type(opts, 'port', int) + util.coerce_kw_type(opts, 'compress', bool) + util.coerce_kw_type(opts, 'autoping', bool) + + util.coerce_kw_type(opts, 'default_charset', bool) + if opts.pop('default_charset', False): + opts['charset'] = None + else: + util.coerce_kw_type(opts, 'charset', str) + util.coerce_kw_type(opts, 'use_unicode', bool) + + # FOUND_ROWS must be set in CLIENT_FLAGS to enable + # supports_sane_rowcount. + opts['found_rows'] = True + # And sqlalchemy assumes that you get an exception when mysql reports a warning. + opts['raise_on_warnings'] = True + return [[], opts] + + def _get_server_version_info(self, connection): + dbapi_con = connection.connection + version = [] + r = re.compile('[.\-]') + for n in r.split(dbapi_con.server_info): + try: + version.append(int(n)) + except ValueError: + version.append(n) + return tuple(version) + + def _extract_error_code(self, exception): + try: + return exception.orig.errno + except AttributeError: + return None + + def _detect_charset(self, connection): + """Sniff out the character set in use for connection results.""" + return connection.connection.charset + + def _compat_fetchall(self, rp, charset=None): + """oursql isn't super-broken like MySQLdb, yaaay.""" + return rp.fetchall() + + def _compat_fetchone(self, rp, charset=None): + """oursql isn't super-broken like MySQLdb, yaaay.""" + return rp.fetchone() + + +dialect = MySQL_oursql diff --git a/test/dialect/test_mysql.py b/test/dialect/test_mysql.py index 49dde1520f..b65ab6312d 100644 --- a/test/dialect/test_mysql.py +++ b/test/dialect/test_mysql.py @@ -595,7 +595,7 @@ class TypesTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL): # This is known to fail with MySQLDB 1.2.2 beta versions # which return these as sets.Set(['a']), sets.Set(['b']) # (even on Pythons with __builtin__.set) - if (not testing.against('+zxjdbc') and + if (testing.against('mysql+mysqldb') and testing.db.dialect.dbapi.version_info < (1, 2, 2, 'beta', 3) and testing.db.dialect.dbapi.version_info >= (1, 2, 2)): # these mysqldb seem to always uses 'sets', even on later pythons diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index 7ec4124a98..4a1342bd5c 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -28,7 +28,7 @@ class ExecuteTest(TestBase): def teardown_class(cls): metadata.drop_all() - @testing.fails_on_everything_except('firebird', 'maxdb', 'sqlite', 'mysql+pyodbc', '+zxjdbc') + @testing.fails_on_everything_except('firebird', 'maxdb', 'sqlite', 'mysql+pyodbc', '+zxjdbc', 'mysql+oursql') def test_raw_qmark(self): for conn in (testing.db, testing.db.connect()): conn.execute("insert into users (user_id, user_name) values (?, ?)", (1,"jack")) diff --git a/test/orm/test_generative.py b/test/orm/test_generative.py index 4274b46861..f30d844320 100644 --- a/test/orm/test_generative.py +++ b/test/orm/test_generative.py @@ -80,7 +80,7 @@ class GenerativeQueryTest(_base.MappedTest): @testing.resolve_artifact_names def test_aggregate_1(self): - if (testing.against('mysql') and not testing.against('+zxjdbc') and + if (testing.against('mysql+mysqldb') and testing.db.dialect.dbapi.version_info[:4] == (1, 2, 1, 'gamma')): return diff --git a/test/orm/test_query.py b/test/orm/test_query.py index 83550b060b..be763f0096 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -213,7 +213,7 @@ class GetTest(QueryTest): assert u.addresses[0].email_address == 'jack@bean.com' assert u.orders[1].items[2].description == 'item 5' - @testing.fails_on_everything_except('sqlite', '+pyodbc', '+zxjdbc') + @testing.fails_on_everything_except('sqlite', '+pyodbc', '+zxjdbc', 'mysql+oursql') def test_query_str(self): s = create_session() q = s.query(User).filter(User.id==1) diff --git a/test/orm/test_relationships.py b/test/orm/test_relationships.py index e8a7f76b12..aa1565794f 100644 --- a/test/orm/test_relationships.py +++ b/test/orm/test_relationships.py @@ -603,7 +603,7 @@ class RelationTest5(_base.MappedTest): lineItems=relation(LineItem, lazy=True, cascade='all, delete-orphan', - order_by=sa.asc(items.c.type), + order_by=sa.asc(items.c.id), primaryjoin=sa.and_( container_select.c.policyNum==items.c.policyNum, container_select.c.policyEffDate==items.c.policyEffDate, @@ -630,7 +630,7 @@ class RelationTest5(_base.MappedTest): assert con.policyNum == newcon.policyNum assert len(newcon.lineItems) == 10 for old, new in zip(con.lineItems, newcon.lineItems): - assert old.id == new.id + eq_(old.id, new.id) class RelationTest6(_base.MappedTest): """test a relation with a non-column entity in the primary join, diff --git a/test/sql/test_defaults.py b/test/sql/test_defaults.py index 092c7640e3..f49e4d0d3e 100644 --- a/test/sql/test_defaults.py +++ b/test/sql/test_defaults.py @@ -285,7 +285,7 @@ class DefaultTest(testing.TestBase): @testing.fails_on('firebird', 'Data type unknown') def test_insertmany(self): # MySQL-Python 1.2.2 breaks functions in execute_many :( - if (testing.against('mysql') and not testing.against('+zxjdbc') and + if (testing.against('mysql+mysqldb') and testing.db.dialect.dbapi.version_info[:3] == (1, 2, 2)): return @@ -319,7 +319,7 @@ class DefaultTest(testing.TestBase): @testing.fails_on('firebird', 'Data type unknown') def test_updatemany(self): # MySQL-Python 1.2.2 breaks functions in execute_many :( - if (testing.against('mysql') and not testing.against('+zxjdbc') and + if (testing.against('mysql+mysqldb') and testing.db.dialect.dbapi.version_info[:3] == (1, 2, 2)): return diff --git a/test/sql/test_types.py b/test/sql/test_types.py index 15815a420c..c0b86c1e43 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -263,8 +263,9 @@ class UnicodeTest(TestBase, AssertsExecutionResults): ( ('postgresql','psycopg2'), ('postgresql','pg8000'), - ('postgresql','zxjdbc'), - ('mysql','zxjdbc'), + ('postgresql','zxjdbc'), + ('mysql','oursql'), + ('mysql','zxjdbc'), ('sqlite','pysqlite'), )), \ "name: %s driver %s returns_unicode_strings=%s" % \ -- 2.47.3