]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
o move mysqldb's DecimalType behavior out of base
authorPhilip Jenvey <pjenvey@underboss.org>
Sat, 25 Jul 2009 20:20:26 +0000 (20:20 +0000)
committerPhilip Jenvey <pjenvey@underboss.org>
Sat, 25 Jul 2009 20:20:26 +0000 (20:20 +0000)
o add a BIT result processor for zxjdbc

lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/mysql/mysqldb.py
lib/sqlalchemy/dialects/mysql/zxjdbc.py

index d3e11c1fcc6e214b0d6bb2fc894c4a4b070e0dc0..b325f5ef58377d3e2a67ba858436e1254b4637a6 100644 (file)
@@ -177,7 +177,7 @@ timely information affecting MySQL in SQLAlchemy.
 
 """
 
-import datetime, decimal, inspect, re, sys
+import datetime, inspect, re, sys
 
 from sqlalchemy import schema as sa_schema
 from sqlalchemy import exc, log, sql, util
@@ -264,23 +264,6 @@ class _FloatType(_NumericType, sqltypes.Float):
         super(_FloatType, self).__init__(precision=precision, asdecimal=asdecimal, **kw)
         self.scale = scale
 
-class _DecimalType(_NumericType):
-    def bind_processor(self, dialect):
-        return None
-
-    def result_processor(self, dialect):
-        # TODO: this behavior might by MySQLdb specific,
-        # i.e. that Decimals are returned by the DBAPI
-        if not self.asdecimal:
-            def process(value):
-                if isinstance(value, decimal.Decimal):
-                    return float(value)
-                else:
-                    return value
-            return process
-        else:
-            return None
-
 class _IntegerType(_NumericType, sqltypes.Integer):
     def __init__(self, display_width=None, **kw):
         self.display_width = display_width
@@ -326,7 +309,7 @@ class _BinaryType(sqltypes.Binary):
                 return util.buffer(value)
         return process
 
-class NUMERIC(_DecimalType, sqltypes.NUMERIC):
+class NUMERIC(_NumericType, sqltypes.NUMERIC):
     """MySQL NUMERIC type."""
     
     __visit_name__ = 'NUMERIC'
@@ -350,7 +333,7 @@ class NUMERIC(_DecimalType, sqltypes.NUMERIC):
         super(NUMERIC, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal, **kw)
 
 
-class DECIMAL(_DecimalType, sqltypes.DECIMAL):
+class DECIMAL(_NumericType, sqltypes.DECIMAL):
     """MySQL DECIMAL type."""
     
     __visit_name__ = 'DECIMAL'
index adaea792ac83f22f75dcf7d88dcc97e1eaacf77a..b5f7779843816581e37e22e245ad28d5b11123a3 100644 (file)
@@ -20,12 +20,14 @@ strings, also pass ``use_unicode=0`` in the connection arguments::
   create_engine('mysql:///mydb?charset=utf8&use_unicode=0')
 """
 
-from sqlalchemy.dialects.mysql.base import MySQLDialect, MySQLExecutionContext, MySQLCompiler
+import decimal
+import re
+
+from sqlalchemy.dialects.mysql.base import (DECIMAL, MySQLDialect, MySQLExecutionContext,
+                                            MySQLCompiler, NUMERIC, _NumericType)
 from sqlalchemy.engine import base as engine_base, default
 from sqlalchemy.sql import operators as sql_operators
-
-from sqlalchemy import exc, log, schema, sql, util
-import re
+from sqlalchemy import exc, log, schema, sql, types as sqltypes, util
 
 class MySQL_mysqldbExecutionContext(MySQLExecutionContext):
     def _lastrowid(self, cursor):
@@ -45,7 +47,28 @@ class MySQL_mysqldbCompiler(MySQLCompiler):
     
     def post_process_text(self, text):
         return text.replace('%', '%%')
-    
+
+
+class _DecimalType(_NumericType):
+    def result_processor(self, dialect):
+        if self.asdecimal:
+            return
+        def process(value):
+            if isinstance(value, decimal.Decimal):
+                return float(value)
+            else:
+                return value
+        return process
+
+
+class _MySQLdbNumeric(_DecimalType, NUMERIC):
+    pass
+
+
+class _MySQLdbDecimal(_DecimalType, DECIMAL):
+    pass
+
+
 class MySQL_mysqldb(MySQLDialect):
     driver = 'mysqldb'
     supports_unicode_statements = False
@@ -55,6 +78,14 @@ class MySQL_mysqldb(MySQLDialect):
     default_paramstyle = 'format'
     execution_ctx_cls = MySQL_mysqldbExecutionContext
     statement_compiler = MySQL_mysqldbCompiler
+
+    colspecs = util.update_copy(
+        MySQLDialect.colspecs,
+        {
+            sqltypes.Numeric: _MySQLdbNumeric,
+            DECIMAL: _MySQLdbDecimal
+        }
+    )
     
     @classmethod
     def dbapi(cls):
index 3ffe85f831bd0b254a81f187d5ed1998d60dc56e..31662484407c3bec8bf74c200d43ed4cdaf1cd8c 100644 (file)
@@ -1,6 +1,16 @@
+"""Support for the MySQL database via Jython's zxjdbc JDBC connector.
+
+JDBC Driver
+-----------
+
+The official MySQL JDBC driver is at
+http://dev.mysql.com/downloads/connector/j/.
+
+"""
+import decimal
 import re
 
-from sqlalchemy.dialects.mysql.base import MySQLDialect, MySQLExecutionContext
+from sqlalchemy.dialects.mysql.base import BIT, MySQLDialect, MySQLExecutionContext
 from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector
 from sqlalchemy import types as sqltypes, util
 
@@ -15,17 +25,36 @@ class MySQL_jdbcExecutionContext(MySQLExecutionContext):
         cursor.close()
         return lastrowid
 
-jdbc_colspecs = MySQLDialect.colspecs.copy()
-# Time's conversion not applicable
-jdbc_colspecs.pop(sqltypes.Time)
+
+class _JDBCBit(BIT):
+    def result_processor(self, dialect):
+        """Converts boolean or byte arrays from MySQL Connector/J to longs."""
+        def process(value):
+            if value is None:
+                return value
+            if isinstance(value, bool):
+                return int(value)
+            v = 0L
+            for i in value:
+                v = v << 8 | (i & 0xff)
+            value = v
+            return value
+        return process
+
 
 class MySQL_jdbc(ZxJDBCConnector, MySQLDialect):
     execution_ctx_cls = MySQL_jdbcExecutionContext
 
     jdbc_db_name = 'mysql'
-    jdbc_driver_name = "org.gjt.mm.mysql.Driver"
+    jdbc_driver_name = "com.mysql.jdbc.Driver"
 
-    colspecs = jdbc_colspecs
+    colspecs = util.update_copy(
+        MySQLDialect.colspecs,
+        {
+            sqltypes.Time: sqltypes.Time,
+            BIT: _JDBCBit
+        }
+    )
     
     def _detect_charset(self, connection):
         """Sniff out the character set in use for connection results."""