]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- more or less pg8000 support. has a rough time with non-ascii data.
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 16 Jan 2009 19:31:28 +0000 (19:31 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 16 Jan 2009 19:31:28 +0000 (19:31 +0000)
- removed "send unicode straight through" logic from sqlite, this becomes
base dialect configurable
- simplfied Interval type to not have awareness of PG dialect.  dialects
can name TypeDecorator classes in their colspecs dict.

16 files changed:
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/mysql/mysqldb.py
lib/sqlalchemy/dialects/postgres/base.py
lib/sqlalchemy/dialects/postgres/pg8000.py [new file with mode: 0644]
lib/sqlalchemy/dialects/postgres/psycopg2.py
lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/dialects/sqlite/pysqlite.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/types.py
lib/sqlalchemy/util.py
test/engine/reconnect.py
test/sql/query.py
test/sql/testtypes.py
test/testlib/requires.py

index e7e250762cebbf143fbe5aa48f5faa549573e083..3c66945e80afa37c68219c144096f06c270f09dd 100644 (file)
@@ -1283,16 +1283,19 @@ class MySQLExecutionContext(default.DefaultExecutionContext):
         return AUTOCOMMIT_RE.match(statement)
 
 class MySQLCompiler(compiler.SQLCompiler):
-    operators = compiler.SQLCompiler.operators.copy()
-    operators.update({
-        sql_operators.concat_op: lambda x, y: "concat(%s, %s)" % (x, y),
-        sql_operators.mod: '%%',
-        sql_operators.match_op: lambda x, y: "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % (x, y)
-    })
-    functions = compiler.SQLCompiler.functions.copy()
-    functions.update ({
-        sql_functions.random: 'rand%(expr)s',
-        "utc_timestamp":"UTC_TIMESTAMP"
+    operators = util.update_copy(
+        compiler.SQLCompiler.operators,
+        {
+            sql_operators.concat_op: lambda x, y: "concat(%s, %s)" % (x, y),
+            sql_operators.match_op: lambda x, y: "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % (x, y)
+        }
+    )
+
+    functions = util.update_copy(
+        compiler.SQLCompiler.functions,
+        {
+            sql_functions.random: 'rand%(expr)s',
+            "utc_timestamp":"UTC_TIMESTAMP"
         })
 
     def visit_typeclause(self, typeclause):
index 6ad8d044735ffd519d0fe85e4f98abdb6a57eda0..61f9d3f6719d3015eea897ae1ce362fbf20168d4 100644 (file)
@@ -22,6 +22,8 @@ strings, also pass ``use_unicode=0`` in the connection arguments::
 
 from sqlalchemy.dialects.mysql.base import MySQLDialect, MySQLExecutionContext, MySQLCompiler
 from sqlalchemy.engine import base as engine_base, default
+from sqlalchemy.sql import operators as sql_operators
+
 from sqlalchemy import exc, log, schema, sql, util
 import re
 
@@ -30,6 +32,13 @@ class MySQL_mysqldbExecutionContext(MySQLExecutionContext):
         return cursor.lastrowid
 
 class MySQL_mysqldbCompiler(MySQLCompiler):
+    operators = util.update_copy(
+        MySQLCompiler.operators,
+        {
+            sql_operators.mod: '%%',
+        }
+    )
+    
     def post_process_text(self, text):
         if '%%' in text:
             util.warn("The SQLAlchemy mysql+mysqldb dialect now automatically escapes '%' in text() expressions to '%%'.")
@@ -40,7 +49,7 @@ class MySQL_mysqldb(MySQLDialect):
     supports_unicode_statements = False
     default_paramstyle = 'format'
     execution_ctx_cls = MySQL_mysqldbExecutionContext
-    sql_compiler = MySQL_mysqldbCompiler
+    statement_compiler = MySQL_mysqldbCompiler
     
     @classmethod
     def dbapi(cls):
@@ -102,7 +111,10 @@ class MySQL_mysqldb(MySQLDialect):
         return tuple(version)
 
     def _extract_error_code(self, exception):
-        return exception.orig.args[0]
+        try:
+            return exception.orig.args[0]
+        except AttributeError:
+            return None
 
     @engine_base.connection_memoize(('mysql', 'charset'))
     def _detect_charset(self, connection):
index 15ed21c77d58f3e00c9c355779bd4d3e5311e9e7..8fd4ef5ef284fd39ff0962fdd7a72d7e591ecf4d 100644 (file)
@@ -150,38 +150,69 @@ class PGArray(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine):
         return process
 
 
+colspecs = {
+    sqltypes.Interval:PGInterval
+}
+
+ischema_names = {
+    'integer' : sqltypes.Integer,
+    'bigint' : PGBigInteger,
+    'smallint' : sqltypes.SmallInteger,
+    'character varying' : sqltypes.String,
+    'character' : sqltypes.CHAR,
+    'text' : sqltypes.Text,
+    'numeric' : sqltypes.Numeric,
+    'float' : sqltypes.Float,
+    'real' : sqltypes.Float,
+    'inet': PGInet,
+    'cidr': PGCidr,
+    'macaddr': PGMacAddr,
+    'double precision' : sqltypes.Float,
+    'timestamp' : sqltypes.DateTime,
+    'timestamp with time zone' : sqltypes.DateTime,
+    'timestamp without time zone' : sqltypes.DateTime,
+    'time with time zone' : sqltypes.Time,
+    'time without time zone' : sqltypes.Time,
+    'date' : sqltypes.Date,
+    'time': sqltypes.Time,
+    'bytea' : sqltypes.Binary,
+    'boolean' : sqltypes.Boolean,
+    'interval':PGInterval,
+}
 
 
 
 class PGCompiler(compiler.SQLCompiler):
-    operators = compiler.SQLCompiler.operators.copy()
-    operators.update(
+    
+    operators = util.update_copy(
+        compiler.SQLCompiler.operators,
         {
             sql_operators.mod : '%%',
+        
             sql_operators.ilike_op: lambda x, y, escape=None: '%s ILIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
             sql_operators.notilike_op: lambda x, y, escape=None: '%s NOT ILIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
             sql_operators.match_op: lambda x, y: '%s @@ to_tsquery(%s)' % (x, y),
         }
     )
 
-    functions = compiler.SQLCompiler.functions.copy()
-    functions.update (
+    functions = util.update_copy(
+        compiler.SQLCompiler.functions,
         {
             'TIMESTAMP':lambda x:'TIMESTAMP %s' % x,
         }
     )
 
+    def post_process_text(self, text):
+        if '%%' in text:
+            util.warn("The SQLAlchemy postgres dialect now automatically escapes '%' in text() expressions to '%%'.")
+        return text.replace('%', '%%')
+
     def visit_sequence(self, seq):
         if seq.optional:
             return None
         else:
             return "nextval('%s')" % self.preparer.format_sequence(seq)
 
-    def post_process_text(self, text):
-        if '%%' in text:
-            util.warn("The SQLAlchemy postgres dialect now automatically escapes '%' in text() expressions to '%%'.")
-        return text.replace('%', '%%')
-
     def limit_clause(self, select):
         text = ""
         if select._limit is not None:
@@ -369,7 +400,9 @@ class PGDialect(default.DefaultDialect):
     supports_default_values = True
     supports_empty_insert = False
     default_paramstyle = 'pyformat'
-
+    ischema_names = ischema_names
+    colspecs = colspecs
+    
     statement_compiler = PGCompiler
     ddl_compiler = PGDDLCompiler
     type_compiler = PGTypeCompiler
@@ -457,6 +490,9 @@ class PGDialect(default.DefaultDialect):
             raise AssertionError("Could not determine version from string '%s'" % v)
         return tuple([int(x) for x in m.group(1, 2, 3)])
 
+    def type_descriptor(self, typeobj):
+        return sqltypes.adapt_type(typeobj, self.colspecs)
+
     def reflecttable(self, connection, table, include_columns):
         preparer = self.identifier_preparer
         if table.schema is not None:
diff --git a/lib/sqlalchemy/dialects/postgres/pg8000.py b/lib/sqlalchemy/dialects/postgres/pg8000.py
new file mode 100644 (file)
index 0000000..43ed3ee
--- /dev/null
@@ -0,0 +1,77 @@
+"""Support for the PostgreSQL database via the pg8000.
+
+Connecting
+----------
+
+URLs are of the form `postgres+pg8000://user@password@host:port/dbname[?key=value&key=value...]`.
+
+Unicode
+-------
+
+Unicode data which contains non-ascii characters don't seem to be supported yet.  non-ascii
+schema identifiers though *are* supported, if you set the client_encoding=utf8 in the postgresql.conf 
+file.
+
+Interval
+--------
+
+Passing data from/to the Interval type is not supported as of yet.
+
+"""
+
+import decimal, random, re, string
+
+from sqlalchemy import sql, schema, exc, util
+from sqlalchemy.engine import base, default
+from sqlalchemy.sql import compiler, expression
+from sqlalchemy.sql import operators as sql_operators
+from sqlalchemy import types as sqltypes
+from sqlalchemy.dialects.postgres.base import PGDialect, PGInet, PGCidr, PGMacAddr, PGArray, \
+ PGBigInteger, PGInterval
+
+class PGNumeric(sqltypes.Numeric):
+    def bind_processor(self, dialect):
+        return None
+
+    def result_processor(self, dialect):
+        if self.asdecimal:
+            return None
+        else:
+            def process(value):
+                if isinstance(value, decimal.Decimal):
+                    return float(value)
+                else:
+                    return value
+            return process
+
+class Postgres_pg8000ExecutionContext(default.DefaultExecutionContext):
+    pass
+
+class Postgres_pg8000(PGDialect):
+    driver = 'pg8000'
+
+    supports_unicode_statements = False #True
+    
+    # this one doesn't matter, cant pass non-ascii through
+    # pending further investigation
+    supports_unicode_binds = False #True
+    
+    default_paramstyle = 'format'
+    supports_sane_multi_rowcount = False
+    execution_ctx_cls = Postgres_pg8000ExecutionContext
+    
+    @classmethod
+    def dbapi(cls):
+        return __import__('pg8000').dbapi
+
+    def create_connect_args(self, url):
+        opts = url.translate_connect_args(username='user')
+        if 'port' in opts:
+            opts['port'] = int(opts['port'])
+        opts.update(url.query)
+        return ([], opts)
+
+    def is_disconnect(self, e):
+        return "connection is closed" in e
+
+dialect = Postgres_pg8000
index bd0815a3f319343c8570f516ad96579b4900d244..f46da2182706769ae306addfff5ae6865f3d121f 100644 (file)
@@ -39,8 +39,8 @@ from sqlalchemy.engine import base, default
 from sqlalchemy.sql import compiler, expression
 from sqlalchemy.sql import operators as sql_operators
 from sqlalchemy import types as sqltypes
-from sqlalchemy.dialects.postgres.base import PGDialect, PGInet, PGCidr, PGMacAddr, PGArray, \
- PGBigInteger, PGInterval
+from sqlalchemy.dialects.postgres.base import PGDialect, PGCompiler, PGInet, PGCidr, PGMacAddr, PGArray, \
+ PGBigInteger, PGInterval, colspecs
 
 class PGNumeric(sqltypes.Numeric):
     def bind_processor(self, dialect):
@@ -58,36 +58,11 @@ class PGNumeric(sqltypes.Numeric):
             return process
 
 
-colspecs = {
+colspecs = PGDialect.colspecs.copy()
+colspecs.update({
     sqltypes.Numeric : PGNumeric,
     sqltypes.Float: sqltypes.Float,  # prevents PGNumeric from being used
-}
-
-ischema_names = {
-    'integer' : sqltypes.Integer,
-    'bigint' : PGBigInteger,
-    'smallint' : sqltypes.SmallInteger,
-    'character varying' : sqltypes.String,
-    'character' : sqltypes.CHAR,
-    'text' : sqltypes.Text,
-    'numeric' : PGNumeric,
-    'float' : sqltypes.Float,
-    'real' : sqltypes.Float,
-    'inet': PGInet,
-    'cidr': PGCidr,
-    'macaddr': PGMacAddr,
-    'double precision' : sqltypes.Float,
-    'timestamp' : sqltypes.DateTime,
-    'timestamp with time zone' : sqltypes.DateTime,
-    'timestamp without time zone' : sqltypes.DateTime,
-    'time with time zone' : sqltypes.Time,
-    'time without time zone' : sqltypes.Time,
-    'date' : sqltypes.Date,
-    'time': sqltypes.Time,
-    'bytea' : sqltypes.Binary,
-    'boolean' : sqltypes.Boolean,
-    'interval':PGInterval,
-}
+})
 
 # TODO: filter out 'FOR UPDATE' statements
 SERVER_SIDE_CURSOR_RE = re.compile(
@@ -122,13 +97,26 @@ class Postgres_psycopg2ExecutionContext(default.DefaultExecutionContext):
         else:
             return base.ResultProxy(self)
 
+class Postgres_psycopg2Compiler(PGCompiler):
+    operators = util.update_copy(
+        PGCompiler.operators, 
+        {
+            sql_operators.mod : '%%',
+        }
+    )
+    
+    def post_process_text(self, text):
+        if '%%' in text:
+            util.warn("The SQLAlchemy postgres dialect now automatically escapes '%' in text() expressions to '%%'.")
+        return text.replace('%', '%%')
+
 class Postgres_psycopg2(PGDialect):
     driver = 'psycopg2'
     supports_unicode_statements = False
     default_paramstyle = 'pyformat'
     supports_sane_multi_rowcount = False
     execution_ctx_cls = Postgres_psycopg2ExecutionContext
-    ischema_names = ischema_names
+    statement_compiler = Postgres_psycopg2Compiler
     
     def __init__(self, server_side_cursors=False, **kwargs):
         PGDialect.__init__(self, **kwargs)
index ba08ccbb90fb73729cb493f74133b183e43a3900..773501d64c5164e8a00e91793acbbf82aa9c347a 100644 (file)
@@ -136,6 +136,36 @@ class SLBoolean(sqltypes.Boolean):
             return value and True or False
         return process
 
+colspecs = {
+    sqltypes.Boolean: SLBoolean,
+    sqltypes.Date: SLDate,
+    sqltypes.DateTime: SLDateTime,
+    sqltypes.Float: SLFloat,
+    sqltypes.Numeric: SLNumeric,
+    sqltypes.Time: SLTime,
+}
+
+ischema_names = {
+    'BLOB': sqltypes.Binary,
+    'BOOL': sqltypes.Boolean,
+    'BOOLEAN': sqltypes.Boolean,
+    'CHAR': sqltypes.CHAR,
+    'DATE': sqltypes.Date,
+    'DATETIME': sqltypes.DateTime,
+    'DECIMAL': sqltypes.Numeric,
+    'FLOAT': sqltypes.Numeric,
+    'INT': sqltypes.Integer,
+    'INTEGER': sqltypes.Integer,
+    'NUMERIC': sqltypes.Numeric,
+    'REAL': sqltypes.Numeric,
+    'SMALLINT': sqltypes.SmallInteger,
+    'TEXT': sqltypes.Text,
+    'TIME': sqltypes.Time,
+    'TIMESTAMP': sqltypes.DateTime,
+    'VARCHAR': sqltypes.String,
+}
+
+
 
 class SQLiteCompiler(compiler.SQLCompiler):
     functions = compiler.SQLCompiler.functions.copy()
@@ -216,6 +246,7 @@ class SQLiteDialect(default.DefaultDialect):
     name = 'sqlite'
     supports_alter = False
     supports_unicode_statements = True
+    supports_unicode_binds = True
     supports_default_values = True
     supports_empty_insert = False
     supports_cast = True
@@ -224,6 +255,10 @@ class SQLiteDialect(default.DefaultDialect):
     ddl_compiler = SQLiteDDLCompiler
     type_compiler = SQLiteTypeCompiler
     preparer = SQLiteIdentifierPreparer
+    ischema_names = ischema_names
+
+    def type_descriptor(self, typeobj):
+        return sqltypes.adapt_type(typeobj, colspecs)
 
     def table_names(self, connection, schema):
         if schema is not None:
index b00f9e7a00961b7bd8255685df6e542418a56a9f..b4b9ca33d0a4da47858f0a7cc7d84ad371b51b29 100644 (file)
@@ -104,85 +104,13 @@ always represented by an actual database result string.
 
 """
 
-from sqlalchemy.dialects.sqlite.base import SLNumeric, SLFloat, SQLiteDialect, SLBoolean, SLDate, SLDateTime, SLTime
+from sqlalchemy.dialects.sqlite.base import SQLiteDialect
 from sqlalchemy import schema, exc, pool
 from sqlalchemy.engine import default
 from sqlalchemy import types as sqltypes
 from sqlalchemy import util
 from types import NoneType
 
-class SLUnicodeMixin(object):
-    def bind_processor(self, dialect):
-        if self.convert_unicode or dialect.convert_unicode:
-            if self.assert_unicode is None:
-                assert_unicode = dialect.assert_unicode
-            else:
-                assert_unicode = self.assert_unicode
-                
-            if not assert_unicode:
-                return None
-                
-            def process(value):
-                if not isinstance(value, (unicode, NoneType)):
-                    if assert_unicode == 'warn':
-                        util.warn("Unicode type received non-unicode bind "
-                                  "param value %r" % value)
-                        return value
-                    else:
-                        raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value)
-                else:
-                    return value
-            return process
-        else:
-            return None
-
-    def result_processor(self, dialect):
-        return None
-    
-class SLText(SLUnicodeMixin, sqltypes.Text):
-    pass
-
-class SLString(SLUnicodeMixin, sqltypes.String):
-    pass
-
-class SLChar(SLUnicodeMixin, sqltypes.CHAR):
-    pass
-
-
-colspecs = {
-    sqltypes.Boolean: SLBoolean,
-    sqltypes.CHAR: SLChar,
-    sqltypes.Date: SLDate,
-    sqltypes.DateTime: SLDateTime,
-    sqltypes.Float: SLFloat,
-    sqltypes.NCHAR: SLChar,
-    sqltypes.Numeric: SLNumeric,
-    sqltypes.String: SLString,
-    sqltypes.Text: SLText,
-    sqltypes.Time: SLTime,
-}
-
-ischema_names = {
-    'BLOB': sqltypes.Binary,
-    'BOOL': SLBoolean,
-    'BOOLEAN': SLBoolean,
-    'CHAR': SLChar,
-    'DATE': SLDate,
-    'DATETIME': SLDateTime,
-    'DECIMAL': SLNumeric,
-    'FLOAT': SLNumeric,
-    'INT': sqltypes.Integer,
-    'INTEGER': sqltypes.Integer,
-    'NUMERIC': SLNumeric,
-    'REAL': SLNumeric,
-    'SMALLINT': sqltypes.SmallInteger,
-    'TEXT': SLText,
-    'TIME': SLTime,
-    'TIMESTAMP': SLDateTime,
-    'VARCHAR': SLString,
-}
-
-
 class SQLite_pysqliteExecutionContext(default.DefaultExecutionContext):
     def post_exec(self):
         if self.isinsert and not self.executemany:
@@ -195,7 +123,6 @@ class SQLite_pysqlite(SQLiteDialect):
     poolclass = pool.SingletonThreadPool
     execution_ctx_cls = SQLite_pysqliteExecutionContext
     driver = 'pysqlite'
-    ischema_names = ischema_names
     
     def __init__(self, **kwargs):
         SQLiteDialect.__init__(self, **kwargs)
@@ -246,9 +173,6 @@ class SQLite_pysqlite(SQLiteDialect):
 
         return ([filename], opts)
 
-    def type_descriptor(self, typeobj):
-        return sqltypes.adapt_type(typeobj, colspecs)
-
     def is_disconnect(self, e):
         return isinstance(e, self.dbapi.ProgrammingError) and "Cannot operate on a closed database." in str(e)
 
index 535c5fc1c8b54bf70121e378377608db1ee37b68..f3acc28597358a4c9a62c3740f27a963a06298f7 100644 (file)
@@ -83,6 +83,9 @@ class Dialect(object):
     supports_unicode_statements
       Indicate whether the DB-API can receive SQL statements as Python unicode strings
 
+    supports_unicode_binds
+      Indicate whether the DB-API can receive string bind parameters as Python unicode strings
+
     supports_sane_rowcount
       Indicate whether the dialect properly implements rowcount for ``UPDATE`` and ``DELETE`` statements.
 
index 12b1661925da9cb5a919c0dab4193a8b98add426..1dc3d720ef8aae6c5ba2d140034d0d78079c14c8 100644 (file)
@@ -32,6 +32,8 @@ class DefaultDialect(base.Dialect):
     supports_sequences = False
     sequences_optional = False
     supports_unicode_statements = False
+    supports_unicode_binds = False
+    
     max_identifier_length = 9999
     supports_sane_rowcount = True
     supports_sane_multi_rowcount = True
index d00a05436a9e5eeed378808c9cef2c63818d660b..683831998737177105e427d56fcf2ae25196ca4b 100644 (file)
@@ -1021,7 +1021,7 @@ class GenericTypeCompiler(engine.TypeCompiler):
         raise NotImplementedError("Can't generate DDL for the null type")
         
     def visit_type_decorator(self, type_):
-        return self.process(type_.dialect_impl(self.dialect).impl)
+        return self.process(type_.type_engine(self.dialect))
         
     def visit_user_defined(self, type_):
         return type_.get_col_spec()
index 986d3d1332d811ca026c0d4317a0758e6c1c0023..92ee125b631915df91a154537f976bde871f5ea0 100644 (file)
@@ -86,6 +86,14 @@ class AbstractType(Visitable):
         """
         return op
 
+    def get_search_list(self):
+        """return a list of classes to test for a match
+        when adapting this type to a dialect-specific type.
+
+        """
+
+        return self.__class__.__mro__[0:-1]
+
     def __repr__(self):
         return "%s(%s)" % (
             self.__class__.__name__,
@@ -136,14 +144,6 @@ class TypeEngine(AbstractType):
     def adapt(self, cls):
         return cls()
 
-    def get_search_list(self):
-        """return a list of classes to test for a match
-        when adapting this type to a dialect-specific type.
-
-        """
-
-        return self.__class__.__mro__[0:-1]
-
 class UserDefinedType(TypeEngine):
     """Base for user defined types.
     
@@ -227,7 +227,7 @@ class TypeDecorator(AbstractType):
             raise AssertionError("TypeDecorator implementations require a class-level variable 'impl' which refers to the class of type being decorated")
         self.impl = self.__class__.impl(*args, **kwargs)
 
-    def dialect_impl(self, dialect, **kwargs):
+    def dialect_impl(self, dialect):
         try:
             return self._impl_dict[dialect]
         except AttributeError:
@@ -235,6 +235,17 @@ class TypeDecorator(AbstractType):
         except KeyError:
             pass
 
+        # adapt the TypeDecorator first, in 
+        # the case that the dialect maps the TD
+        # to one of its native types (i.e. PGInterval)
+        adapted = dialect.type_descriptor(self)
+        if adapted is not self:
+            self._impl_dict[dialect] = adapted
+            return adapted
+        
+        # otherwise adapt the impl type, link
+        # to a copy of this TypeDecorator and return
+        # that.
         typedesc = self.load_dialect_impl(dialect)
         tt = self.copy()
         if not isinstance(tt, self.__class__):
@@ -244,13 +255,20 @@ class TypeDecorator(AbstractType):
         self._impl_dict[dialect] = tt
         return tt
 
+    def type_engine(self, dialect):
+        impl = self.dialect_impl(dialect)
+        if not isinstance(impl, TypeDecorator):
+            return impl
+        else:
+            return impl.impl
+
     def load_dialect_impl(self, dialect):
         """Loads the dialect-specific implementation of this type.
 
         by default calls dialect.type_descriptor(self.impl), but
         can be overridden to provide different behavior.
+        
         """
-
         if isinstance(self.impl, TypeDecorator):
             return self.impl.dialect_impl(dialect)
         else:
@@ -452,18 +470,33 @@ class String(Concatenable, TypeEngine):
                 assert_unicode = dialect.assert_unicode
             else:
                 assert_unicode = self.assert_unicode
-            def process(value):
-                if isinstance(value, unicode):
-                    return value.encode(dialect.encoding)
-                elif assert_unicode and not isinstance(value, (unicode, NoneType)):
-                    if assert_unicode == 'warn':
-                        util.warn("Unicode type received non-unicode bind "
-                                  "param value %r" % value)
+            
+            if dialect.supports_unicode_binds and assert_unicode:
+                def process(value):
+                    if not isinstance(value, (unicode, NoneType)):
+                        if assert_unicode == 'warn':
+                            util.warn("Unicode type received non-unicode bind "
+                                      "param value %r" % value)
+                            return value
+                        else:
+                            raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value)
+                    else:
                         return value
+            elif dialect.supports_unicode_binds:
+                return None
+            else:
+                def process(value):
+                    if isinstance(value, unicode):
+                        return value.encode(dialect.encoding)
+                    elif assert_unicode and not isinstance(value, (unicode, NoneType)):
+                        if assert_unicode == 'warn':
+                            util.warn("Unicode type received non-unicode bind "
+                                      "param value %r" % value)
+                            return value
+                        else:
+                            raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value)
                     else:
-                        raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value)
-                else:
-                    return value
+                        return value
             return process
         else:
             return None
@@ -492,9 +525,6 @@ class Text(String):
     """
     
     __visit_name__ = 'text'
-    
-    def dialect_impl(self, dialect, **kwargs):
-        return TypeEngine.dialect_impl(self, dialect, **kwargs)
 
 class Unicode(String):
     """A variable length Unicode string.
@@ -840,35 +870,17 @@ class Interval(TypeDecorator):
 
     """
 
-    impl = TypeEngine
-
-    def __init__(self):
-        super(Interval, self).__init__()
-        import sqlalchemy.dialects.postgres.base as pg
-        self.__supported = {pg.PGDialect:pg.PGInterval}
-        del pg
-
-    def load_dialect_impl(self, dialect):
-        if dialect.__class__ in self.__supported:
-            return self.__supported[dialect.__class__]()
-        else:
-            return dialect.type_descriptor(DateTime)
+    impl = DateTime
 
     def process_bind_param(self, value, dialect):
-        if dialect.__class__ in self.__supported:
-            return value
-        else:
-            if value is None:
-                return None
-            return dt.datetime.utcfromtimestamp(0) + value
+        if value is None:
+            return None
+        return dt.datetime.utcfromtimestamp(0) + value
 
     def process_result_value(self, value, dialect):
-        if dialect.__class__ in self.__supported:
-            return value
-        else:
-            if value is None:
-                return None
-            return value - dt.datetime.utcfromtimestamp(0)
+        if value is None:
+            return None
+        return value - dt.datetime.utcfromtimestamp(0)
 
 class FLOAT(Float):
     """The SQL FLOAT type."""
index 12f155d606420157320f88616aa992bc64675621..8e810ffa6afa4bd48040eb5f2ded9173af21f027 100644 (file)
@@ -265,6 +265,15 @@ else:
     def decode_slice(slc):
         return (slc.start, slc.stop, slc.step)
 
+def update_copy(d, _new=None, **kw):
+    """Copy the given dict and update with the given values."""
+    
+    d = d.copy()
+    if _new:
+        d.update(_new)
+    d.update(**kw)
+    return d
+    
 def flatten_iterator(x):
     """Given an iterator of which further sub-elements may also be
     iterators, flatten the sub-elements into a single iterator.
index 4f383d2dde6e5f88c7e51c522e82ce194978082e..10c80e13526083eda8ffbb4a685bb4dd0572123b 100644 (file)
@@ -332,7 +332,8 @@ class InvalidateDuringResultTest(TestBase):
         meta.drop_all()
         engine.dispose()
 
-    @testing.fails_on('mysql', 'FIXME: unknown')
+    @testing.fails_on('+mysqldb', "Buffers the result set and doesn't check for connection close")
+    @testing.fails_on('+pg8000', "Buffers the result set and doesn't check for connection close")
     def test_invalidate_on_results(self):
         conn = engine.connect()
 
@@ -342,7 +343,7 @@ class InvalidateDuringResultTest(TestBase):
 
         engine.test_shutdown()
         try:
-            result.fetchone()
+            print "ghost result: %r" % result.fetchone()
             assert False
         except tsa.exc.DBAPIError, e:
             if not e.connection_invalidated:
index bf178ae8f541e02553279a60f3327dfc9d4e0faf..660529c25c194ea9bf0e0d736d625cdb312cea6a 100644 (file)
@@ -252,6 +252,7 @@ class QueryTest(TestBase):
             eq_(expr.execute().fetchall(), result)
     
 
+    @testing.fails_on("+pg8000", "can't interpret result column from '%%'")
     @testing.emits_warning('.*now automatically escapes.*')
     def test_percents_in_text(self):
         for expr, result in (
index ca22fcb270ce57c3651d91157d55230163dd4b21..39f79e540c6024c124904eecfb80449ed5acbcbd 100644 (file)
@@ -291,7 +291,8 @@ class UnicodeTest(TestBase, AssertsExecutionResults):
         assert unicode_table.c.unicode_varchar.type.length == 250
         rawdata = 'Alors vous imaginez ma surprise, au lever du jour, quand une dr\xc3\xb4le de petit voix m\xe2\x80\x99a r\xc3\xa9veill\xc3\xa9. Elle disait: \xc2\xab S\xe2\x80\x99il vous pla\xc3\xaet\xe2\x80\xa6 dessine-moi un mouton! \xc2\xbb\n'
         unicodedata = rawdata.decode('utf-8')
-        if testing.against('sqlite'):
+        
+        if testing.against('sqlite', '>' '2.4'):
             rawdata = "something"
             
         unicode_table.insert().execute(unicode_varchar=unicodedata,
@@ -300,12 +301,12 @@ class UnicodeTest(TestBase, AssertsExecutionResults):
         x = unicode_table.select().execute().fetchone()
         self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata)
         self.assert_(isinstance(x['unicode_text'], unicode) and x['unicode_text'] == unicodedata)
+
         if isinstance(x['plain_varchar'], unicode):
             # SQLLite and MSSQL return non-unicode data as unicode
             self.assert_(testing.against('sqlite', '+pyodbc'))
             if not testing.against('sqlite'):
                 self.assert_(x['plain_varchar'] == unicodedata)
-            print "it's %s!" % testing.db.name
         else:
             self.assert_(not isinstance(x['plain_varchar'], unicode) and x['plain_varchar'] == rawdata)
 
@@ -778,6 +779,7 @@ class IntervalTest(TestBase, AssertsExecutionResults):
     def tearDownAll(self):
         metadata.drop_all()
 
+    @testing.fails_on("+pg8000", "Not yet known how to pass values of the INTERVAL type")
     def test_roundtrip(self):
         delta = datetime.datetime(2006, 10, 5) - datetime.datetime(2005, 8, 17)
         interval_table.insert().execute(interval=delta)
index 200fb01b112ca7be3b2775bc63d5873da8e415f2..4ccce962050af2cff9cc2948755c644ccb454414 100644 (file)
@@ -98,6 +98,7 @@ def two_phase_transactions(fn):
         fn,
         no_support('access', 'not supported by database'),
         no_support('firebird', 'no SA implementation'),
+        no_support('+pg8000', 'FIXME: not sure how to accomplish'),
         no_support('maxdb', 'not supported by database'),
         no_support('mssql', 'FIXME: guessing, needs confirmation'),
         no_support('oracle', 'no SA implementation'),