From 3a7eb0b1b8299ee43376fbada30b2be4319cb695 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 16 Jan 2009 19:31:28 +0000 Subject: [PATCH] - more or less pg8000 support. has a rough time with non-ascii data. - removed "send unicode straight through" logic from sqlite, this becomes base dialect configurable - simplfied Interval type to not have awareness of PG dialect. dialects can name TypeDecorator classes in their colspecs dict. --- lib/sqlalchemy/dialects/mysql/base.py | 23 ++-- lib/sqlalchemy/dialects/mysql/mysqldb.py | 16 ++- lib/sqlalchemy/dialects/postgres/base.py | 56 ++++++++-- lib/sqlalchemy/dialects/postgres/pg8000.py | 77 +++++++++++++ lib/sqlalchemy/dialects/postgres/psycopg2.py | 50 ++++----- lib/sqlalchemy/dialects/sqlite/base.py | 35 ++++++ lib/sqlalchemy/dialects/sqlite/pysqlite.py | 78 +------------- lib/sqlalchemy/engine/base.py | 3 + lib/sqlalchemy/engine/default.py | 2 + lib/sqlalchemy/sql/compiler.py | 2 +- lib/sqlalchemy/types.py | 108 ++++++++++--------- lib/sqlalchemy/util.py | 9 ++ test/engine/reconnect.py | 5 +- test/sql/query.py | 1 + test/sql/testtypes.py | 6 +- test/testlib/requires.py | 1 + 16 files changed, 289 insertions(+), 183 deletions(-) create mode 100644 lib/sqlalchemy/dialects/postgres/pg8000.py diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index e7e250762c..3c66945e80 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1283,16 +1283,19 @@ class MySQLExecutionContext(default.DefaultExecutionContext): return AUTOCOMMIT_RE.match(statement) class MySQLCompiler(compiler.SQLCompiler): - operators = compiler.SQLCompiler.operators.copy() - operators.update({ - sql_operators.concat_op: lambda x, y: "concat(%s, %s)" % (x, y), - sql_operators.mod: '%%', - sql_operators.match_op: lambda x, y: "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % (x, y) - }) - functions = compiler.SQLCompiler.functions.copy() - functions.update ({ - sql_functions.random: 'rand%(expr)s', - "utc_timestamp":"UTC_TIMESTAMP" + operators = util.update_copy( + compiler.SQLCompiler.operators, + { + sql_operators.concat_op: lambda x, y: "concat(%s, %s)" % (x, y), + sql_operators.match_op: lambda x, y: "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % (x, y) + } + ) + + functions = util.update_copy( + compiler.SQLCompiler.functions, + { + sql_functions.random: 'rand%(expr)s', + "utc_timestamp":"UTC_TIMESTAMP" }) def visit_typeclause(self, typeclause): diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py index 6ad8d04473..61f9d3f671 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqldb.py +++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py @@ -22,6 +22,8 @@ strings, also pass ``use_unicode=0`` in the connection arguments:: from sqlalchemy.dialects.mysql.base import MySQLDialect, MySQLExecutionContext, MySQLCompiler from sqlalchemy.engine import base as engine_base, default +from sqlalchemy.sql import operators as sql_operators + from sqlalchemy import exc, log, schema, sql, util import re @@ -30,6 +32,13 @@ class MySQL_mysqldbExecutionContext(MySQLExecutionContext): return cursor.lastrowid class MySQL_mysqldbCompiler(MySQLCompiler): + operators = util.update_copy( + MySQLCompiler.operators, + { + sql_operators.mod: '%%', + } + ) + def post_process_text(self, text): if '%%' in text: util.warn("The SQLAlchemy mysql+mysqldb dialect now automatically escapes '%' in text() expressions to '%%'.") @@ -40,7 +49,7 @@ class MySQL_mysqldb(MySQLDialect): supports_unicode_statements = False default_paramstyle = 'format' execution_ctx_cls = MySQL_mysqldbExecutionContext - sql_compiler = MySQL_mysqldbCompiler + statement_compiler = MySQL_mysqldbCompiler @classmethod def dbapi(cls): @@ -102,7 +111,10 @@ class MySQL_mysqldb(MySQLDialect): return tuple(version) def _extract_error_code(self, exception): - return exception.orig.args[0] + try: + return exception.orig.args[0] + except AttributeError: + return None @engine_base.connection_memoize(('mysql', 'charset')) def _detect_charset(self, connection): diff --git a/lib/sqlalchemy/dialects/postgres/base.py b/lib/sqlalchemy/dialects/postgres/base.py index 15ed21c77d..8fd4ef5ef2 100644 --- a/lib/sqlalchemy/dialects/postgres/base.py +++ b/lib/sqlalchemy/dialects/postgres/base.py @@ -150,38 +150,69 @@ class PGArray(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine): return process +colspecs = { + sqltypes.Interval:PGInterval +} + +ischema_names = { + 'integer' : sqltypes.Integer, + 'bigint' : PGBigInteger, + 'smallint' : sqltypes.SmallInteger, + 'character varying' : sqltypes.String, + 'character' : sqltypes.CHAR, + 'text' : sqltypes.Text, + 'numeric' : sqltypes.Numeric, + 'float' : sqltypes.Float, + 'real' : sqltypes.Float, + 'inet': PGInet, + 'cidr': PGCidr, + 'macaddr': PGMacAddr, + 'double precision' : sqltypes.Float, + 'timestamp' : sqltypes.DateTime, + 'timestamp with time zone' : sqltypes.DateTime, + 'timestamp without time zone' : sqltypes.DateTime, + 'time with time zone' : sqltypes.Time, + 'time without time zone' : sqltypes.Time, + 'date' : sqltypes.Date, + 'time': sqltypes.Time, + 'bytea' : sqltypes.Binary, + 'boolean' : sqltypes.Boolean, + 'interval':PGInterval, +} class PGCompiler(compiler.SQLCompiler): - operators = compiler.SQLCompiler.operators.copy() - operators.update( + + operators = util.update_copy( + compiler.SQLCompiler.operators, { sql_operators.mod : '%%', + sql_operators.ilike_op: lambda x, y, escape=None: '%s ILIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''), sql_operators.notilike_op: lambda x, y, escape=None: '%s NOT ILIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''), sql_operators.match_op: lambda x, y: '%s @@ to_tsquery(%s)' % (x, y), } ) - functions = compiler.SQLCompiler.functions.copy() - functions.update ( + functions = util.update_copy( + compiler.SQLCompiler.functions, { 'TIMESTAMP':lambda x:'TIMESTAMP %s' % x, } ) + def post_process_text(self, text): + if '%%' in text: + util.warn("The SQLAlchemy postgres dialect now automatically escapes '%' in text() expressions to '%%'.") + return text.replace('%', '%%') + def visit_sequence(self, seq): if seq.optional: return None else: return "nextval('%s')" % self.preparer.format_sequence(seq) - def post_process_text(self, text): - if '%%' in text: - util.warn("The SQLAlchemy postgres dialect now automatically escapes '%' in text() expressions to '%%'.") - return text.replace('%', '%%') - def limit_clause(self, select): text = "" if select._limit is not None: @@ -369,7 +400,9 @@ class PGDialect(default.DefaultDialect): supports_default_values = True supports_empty_insert = False default_paramstyle = 'pyformat' - + ischema_names = ischema_names + colspecs = colspecs + statement_compiler = PGCompiler ddl_compiler = PGDDLCompiler type_compiler = PGTypeCompiler @@ -457,6 +490,9 @@ class PGDialect(default.DefaultDialect): raise AssertionError("Could not determine version from string '%s'" % v) return tuple([int(x) for x in m.group(1, 2, 3)]) + def type_descriptor(self, typeobj): + return sqltypes.adapt_type(typeobj, self.colspecs) + def reflecttable(self, connection, table, include_columns): preparer = self.identifier_preparer if table.schema is not None: diff --git a/lib/sqlalchemy/dialects/postgres/pg8000.py b/lib/sqlalchemy/dialects/postgres/pg8000.py new file mode 100644 index 0000000000..43ed3ee1d4 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgres/pg8000.py @@ -0,0 +1,77 @@ +"""Support for the PostgreSQL database via the pg8000. + +Connecting +---------- + +URLs are of the form `postgres+pg8000://user@password@host:port/dbname[?key=value&key=value...]`. + +Unicode +------- + +Unicode data which contains non-ascii characters don't seem to be supported yet. non-ascii +schema identifiers though *are* supported, if you set the client_encoding=utf8 in the postgresql.conf +file. + +Interval +-------- + +Passing data from/to the Interval type is not supported as of yet. + +""" + +import decimal, random, re, string + +from sqlalchemy import sql, schema, exc, util +from sqlalchemy.engine import base, default +from sqlalchemy.sql import compiler, expression +from sqlalchemy.sql import operators as sql_operators +from sqlalchemy import types as sqltypes +from sqlalchemy.dialects.postgres.base import PGDialect, PGInet, PGCidr, PGMacAddr, PGArray, \ + PGBigInteger, PGInterval + +class PGNumeric(sqltypes.Numeric): + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect): + if self.asdecimal: + return None + else: + def process(value): + if isinstance(value, decimal.Decimal): + return float(value) + else: + return value + return process + +class Postgres_pg8000ExecutionContext(default.DefaultExecutionContext): + pass + +class Postgres_pg8000(PGDialect): + driver = 'pg8000' + + supports_unicode_statements = False #True + + # this one doesn't matter, cant pass non-ascii through + # pending further investigation + supports_unicode_binds = False #True + + default_paramstyle = 'format' + supports_sane_multi_rowcount = False + execution_ctx_cls = Postgres_pg8000ExecutionContext + + @classmethod + def dbapi(cls): + return __import__('pg8000').dbapi + + def create_connect_args(self, url): + opts = url.translate_connect_args(username='user') + if 'port' in opts: + opts['port'] = int(opts['port']) + opts.update(url.query) + return ([], opts) + + def is_disconnect(self, e): + return "connection is closed" in e + +dialect = Postgres_pg8000 diff --git a/lib/sqlalchemy/dialects/postgres/psycopg2.py b/lib/sqlalchemy/dialects/postgres/psycopg2.py index bd0815a3f3..f46da21827 100644 --- a/lib/sqlalchemy/dialects/postgres/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgres/psycopg2.py @@ -39,8 +39,8 @@ from sqlalchemy.engine import base, default from sqlalchemy.sql import compiler, expression from sqlalchemy.sql import operators as sql_operators from sqlalchemy import types as sqltypes -from sqlalchemy.dialects.postgres.base import PGDialect, PGInet, PGCidr, PGMacAddr, PGArray, \ - PGBigInteger, PGInterval +from sqlalchemy.dialects.postgres.base import PGDialect, PGCompiler, PGInet, PGCidr, PGMacAddr, PGArray, \ + PGBigInteger, PGInterval, colspecs class PGNumeric(sqltypes.Numeric): def bind_processor(self, dialect): @@ -58,36 +58,11 @@ class PGNumeric(sqltypes.Numeric): return process -colspecs = { +colspecs = PGDialect.colspecs.copy() +colspecs.update({ sqltypes.Numeric : PGNumeric, sqltypes.Float: sqltypes.Float, # prevents PGNumeric from being used -} - -ischema_names = { - 'integer' : sqltypes.Integer, - 'bigint' : PGBigInteger, - 'smallint' : sqltypes.SmallInteger, - 'character varying' : sqltypes.String, - 'character' : sqltypes.CHAR, - 'text' : sqltypes.Text, - 'numeric' : PGNumeric, - 'float' : sqltypes.Float, - 'real' : sqltypes.Float, - 'inet': PGInet, - 'cidr': PGCidr, - 'macaddr': PGMacAddr, - 'double precision' : sqltypes.Float, - 'timestamp' : sqltypes.DateTime, - 'timestamp with time zone' : sqltypes.DateTime, - 'timestamp without time zone' : sqltypes.DateTime, - 'time with time zone' : sqltypes.Time, - 'time without time zone' : sqltypes.Time, - 'date' : sqltypes.Date, - 'time': sqltypes.Time, - 'bytea' : sqltypes.Binary, - 'boolean' : sqltypes.Boolean, - 'interval':PGInterval, -} +}) # TODO: filter out 'FOR UPDATE' statements SERVER_SIDE_CURSOR_RE = re.compile( @@ -122,13 +97,26 @@ class Postgres_psycopg2ExecutionContext(default.DefaultExecutionContext): else: return base.ResultProxy(self) +class Postgres_psycopg2Compiler(PGCompiler): + operators = util.update_copy( + PGCompiler.operators, + { + sql_operators.mod : '%%', + } + ) + + def post_process_text(self, text): + if '%%' in text: + util.warn("The SQLAlchemy postgres dialect now automatically escapes '%' in text() expressions to '%%'.") + return text.replace('%', '%%') + class Postgres_psycopg2(PGDialect): driver = 'psycopg2' supports_unicode_statements = False default_paramstyle = 'pyformat' supports_sane_multi_rowcount = False execution_ctx_cls = Postgres_psycopg2ExecutionContext - ischema_names = ischema_names + statement_compiler = Postgres_psycopg2Compiler def __init__(self, server_side_cursors=False, **kwargs): PGDialect.__init__(self, **kwargs) diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index ba08ccbb90..773501d64c 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -136,6 +136,36 @@ class SLBoolean(sqltypes.Boolean): return value and True or False return process +colspecs = { + sqltypes.Boolean: SLBoolean, + sqltypes.Date: SLDate, + sqltypes.DateTime: SLDateTime, + sqltypes.Float: SLFloat, + sqltypes.Numeric: SLNumeric, + sqltypes.Time: SLTime, +} + +ischema_names = { + 'BLOB': sqltypes.Binary, + 'BOOL': sqltypes.Boolean, + 'BOOLEAN': sqltypes.Boolean, + 'CHAR': sqltypes.CHAR, + 'DATE': sqltypes.Date, + 'DATETIME': sqltypes.DateTime, + 'DECIMAL': sqltypes.Numeric, + 'FLOAT': sqltypes.Numeric, + 'INT': sqltypes.Integer, + 'INTEGER': sqltypes.Integer, + 'NUMERIC': sqltypes.Numeric, + 'REAL': sqltypes.Numeric, + 'SMALLINT': sqltypes.SmallInteger, + 'TEXT': sqltypes.Text, + 'TIME': sqltypes.Time, + 'TIMESTAMP': sqltypes.DateTime, + 'VARCHAR': sqltypes.String, +} + + class SQLiteCompiler(compiler.SQLCompiler): functions = compiler.SQLCompiler.functions.copy() @@ -216,6 +246,7 @@ class SQLiteDialect(default.DefaultDialect): name = 'sqlite' supports_alter = False supports_unicode_statements = True + supports_unicode_binds = True supports_default_values = True supports_empty_insert = False supports_cast = True @@ -224,6 +255,10 @@ class SQLiteDialect(default.DefaultDialect): ddl_compiler = SQLiteDDLCompiler type_compiler = SQLiteTypeCompiler preparer = SQLiteIdentifierPreparer + ischema_names = ischema_names + + def type_descriptor(self, typeobj): + return sqltypes.adapt_type(typeobj, colspecs) def table_names(self, connection, schema): if schema is not None: diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlite.py b/lib/sqlalchemy/dialects/sqlite/pysqlite.py index b00f9e7a00..b4b9ca33d0 100644 --- a/lib/sqlalchemy/dialects/sqlite/pysqlite.py +++ b/lib/sqlalchemy/dialects/sqlite/pysqlite.py @@ -104,85 +104,13 @@ always represented by an actual database result string. """ -from sqlalchemy.dialects.sqlite.base import SLNumeric, SLFloat, SQLiteDialect, SLBoolean, SLDate, SLDateTime, SLTime +from sqlalchemy.dialects.sqlite.base import SQLiteDialect from sqlalchemy import schema, exc, pool from sqlalchemy.engine import default from sqlalchemy import types as sqltypes from sqlalchemy import util from types import NoneType -class SLUnicodeMixin(object): - def bind_processor(self, dialect): - if self.convert_unicode or dialect.convert_unicode: - if self.assert_unicode is None: - assert_unicode = dialect.assert_unicode - else: - assert_unicode = self.assert_unicode - - if not assert_unicode: - return None - - def process(value): - if not isinstance(value, (unicode, NoneType)): - if assert_unicode == 'warn': - util.warn("Unicode type received non-unicode bind " - "param value %r" % value) - return value - else: - raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value) - else: - return value - return process - else: - return None - - def result_processor(self, dialect): - return None - -class SLText(SLUnicodeMixin, sqltypes.Text): - pass - -class SLString(SLUnicodeMixin, sqltypes.String): - pass - -class SLChar(SLUnicodeMixin, sqltypes.CHAR): - pass - - -colspecs = { - sqltypes.Boolean: SLBoolean, - sqltypes.CHAR: SLChar, - sqltypes.Date: SLDate, - sqltypes.DateTime: SLDateTime, - sqltypes.Float: SLFloat, - sqltypes.NCHAR: SLChar, - sqltypes.Numeric: SLNumeric, - sqltypes.String: SLString, - sqltypes.Text: SLText, - sqltypes.Time: SLTime, -} - -ischema_names = { - 'BLOB': sqltypes.Binary, - 'BOOL': SLBoolean, - 'BOOLEAN': SLBoolean, - 'CHAR': SLChar, - 'DATE': SLDate, - 'DATETIME': SLDateTime, - 'DECIMAL': SLNumeric, - 'FLOAT': SLNumeric, - 'INT': sqltypes.Integer, - 'INTEGER': sqltypes.Integer, - 'NUMERIC': SLNumeric, - 'REAL': SLNumeric, - 'SMALLINT': sqltypes.SmallInteger, - 'TEXT': SLText, - 'TIME': SLTime, - 'TIMESTAMP': SLDateTime, - 'VARCHAR': SLString, -} - - class SQLite_pysqliteExecutionContext(default.DefaultExecutionContext): def post_exec(self): if self.isinsert and not self.executemany: @@ -195,7 +123,6 @@ class SQLite_pysqlite(SQLiteDialect): poolclass = pool.SingletonThreadPool execution_ctx_cls = SQLite_pysqliteExecutionContext driver = 'pysqlite' - ischema_names = ischema_names def __init__(self, **kwargs): SQLiteDialect.__init__(self, **kwargs) @@ -246,9 +173,6 @@ class SQLite_pysqlite(SQLiteDialect): return ([filename], opts) - def type_descriptor(self, typeobj): - return sqltypes.adapt_type(typeobj, colspecs) - def is_disconnect(self, e): return isinstance(e, self.dbapi.ProgrammingError) and "Cannot operate on a closed database." in str(e) diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 535c5fc1c8..f3acc28597 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -83,6 +83,9 @@ class Dialect(object): supports_unicode_statements Indicate whether the DB-API can receive SQL statements as Python unicode strings + supports_unicode_binds + Indicate whether the DB-API can receive string bind parameters as Python unicode strings + supports_sane_rowcount Indicate whether the dialect properly implements rowcount for ``UPDATE`` and ``DELETE`` statements. diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 12b1661925..1dc3d720ef 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -32,6 +32,8 @@ class DefaultDialect(base.Dialect): supports_sequences = False sequences_optional = False supports_unicode_statements = False + supports_unicode_binds = False + max_identifier_length = 9999 supports_sane_rowcount = True supports_sane_multi_rowcount = True diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index d00a05436a..6838319987 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1021,7 +1021,7 @@ class GenericTypeCompiler(engine.TypeCompiler): raise NotImplementedError("Can't generate DDL for the null type") def visit_type_decorator(self, type_): - return self.process(type_.dialect_impl(self.dialect).impl) + return self.process(type_.type_engine(self.dialect)) def visit_user_defined(self, type_): return type_.get_col_spec() diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index 986d3d1332..92ee125b63 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -86,6 +86,14 @@ class AbstractType(Visitable): """ return op + def get_search_list(self): + """return a list of classes to test for a match + when adapting this type to a dialect-specific type. + + """ + + return self.__class__.__mro__[0:-1] + def __repr__(self): return "%s(%s)" % ( self.__class__.__name__, @@ -136,14 +144,6 @@ class TypeEngine(AbstractType): def adapt(self, cls): return cls() - def get_search_list(self): - """return a list of classes to test for a match - when adapting this type to a dialect-specific type. - - """ - - return self.__class__.__mro__[0:-1] - class UserDefinedType(TypeEngine): """Base for user defined types. @@ -227,7 +227,7 @@ class TypeDecorator(AbstractType): raise AssertionError("TypeDecorator implementations require a class-level variable 'impl' which refers to the class of type being decorated") self.impl = self.__class__.impl(*args, **kwargs) - def dialect_impl(self, dialect, **kwargs): + def dialect_impl(self, dialect): try: return self._impl_dict[dialect] except AttributeError: @@ -235,6 +235,17 @@ class TypeDecorator(AbstractType): except KeyError: pass + # adapt the TypeDecorator first, in + # the case that the dialect maps the TD + # to one of its native types (i.e. PGInterval) + adapted = dialect.type_descriptor(self) + if adapted is not self: + self._impl_dict[dialect] = adapted + return adapted + + # otherwise adapt the impl type, link + # to a copy of this TypeDecorator and return + # that. typedesc = self.load_dialect_impl(dialect) tt = self.copy() if not isinstance(tt, self.__class__): @@ -244,13 +255,20 @@ class TypeDecorator(AbstractType): self._impl_dict[dialect] = tt return tt + def type_engine(self, dialect): + impl = self.dialect_impl(dialect) + if not isinstance(impl, TypeDecorator): + return impl + else: + return impl.impl + def load_dialect_impl(self, dialect): """Loads the dialect-specific implementation of this type. by default calls dialect.type_descriptor(self.impl), but can be overridden to provide different behavior. + """ - if isinstance(self.impl, TypeDecorator): return self.impl.dialect_impl(dialect) else: @@ -452,18 +470,33 @@ class String(Concatenable, TypeEngine): assert_unicode = dialect.assert_unicode else: assert_unicode = self.assert_unicode - def process(value): - if isinstance(value, unicode): - return value.encode(dialect.encoding) - elif assert_unicode and not isinstance(value, (unicode, NoneType)): - if assert_unicode == 'warn': - util.warn("Unicode type received non-unicode bind " - "param value %r" % value) + + if dialect.supports_unicode_binds and assert_unicode: + def process(value): + if not isinstance(value, (unicode, NoneType)): + if assert_unicode == 'warn': + util.warn("Unicode type received non-unicode bind " + "param value %r" % value) + return value + else: + raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value) + else: return value + elif dialect.supports_unicode_binds: + return None + else: + def process(value): + if isinstance(value, unicode): + return value.encode(dialect.encoding) + elif assert_unicode and not isinstance(value, (unicode, NoneType)): + if assert_unicode == 'warn': + util.warn("Unicode type received non-unicode bind " + "param value %r" % value) + return value + else: + raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value) else: - raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value) - else: - return value + return value return process else: return None @@ -492,9 +525,6 @@ class Text(String): """ __visit_name__ = 'text' - - def dialect_impl(self, dialect, **kwargs): - return TypeEngine.dialect_impl(self, dialect, **kwargs) class Unicode(String): """A variable length Unicode string. @@ -840,35 +870,17 @@ class Interval(TypeDecorator): """ - impl = TypeEngine - - def __init__(self): - super(Interval, self).__init__() - import sqlalchemy.dialects.postgres.base as pg - self.__supported = {pg.PGDialect:pg.PGInterval} - del pg - - def load_dialect_impl(self, dialect): - if dialect.__class__ in self.__supported: - return self.__supported[dialect.__class__]() - else: - return dialect.type_descriptor(DateTime) + impl = DateTime def process_bind_param(self, value, dialect): - if dialect.__class__ in self.__supported: - return value - else: - if value is None: - return None - return dt.datetime.utcfromtimestamp(0) + value + if value is None: + return None + return dt.datetime.utcfromtimestamp(0) + value def process_result_value(self, value, dialect): - if dialect.__class__ in self.__supported: - return value - else: - if value is None: - return None - return value - dt.datetime.utcfromtimestamp(0) + if value is None: + return None + return value - dt.datetime.utcfromtimestamp(0) class FLOAT(Float): """The SQL FLOAT type.""" diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 12f155d606..8e810ffa6a 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -265,6 +265,15 @@ else: def decode_slice(slc): return (slc.start, slc.stop, slc.step) +def update_copy(d, _new=None, **kw): + """Copy the given dict and update with the given values.""" + + d = d.copy() + if _new: + d.update(_new) + d.update(**kw) + return d + def flatten_iterator(x): """Given an iterator of which further sub-elements may also be iterators, flatten the sub-elements into a single iterator. diff --git a/test/engine/reconnect.py b/test/engine/reconnect.py index 4f383d2dde..10c80e1352 100644 --- a/test/engine/reconnect.py +++ b/test/engine/reconnect.py @@ -332,7 +332,8 @@ class InvalidateDuringResultTest(TestBase): meta.drop_all() engine.dispose() - @testing.fails_on('mysql', 'FIXME: unknown') + @testing.fails_on('+mysqldb', "Buffers the result set and doesn't check for connection close") + @testing.fails_on('+pg8000', "Buffers the result set and doesn't check for connection close") def test_invalidate_on_results(self): conn = engine.connect() @@ -342,7 +343,7 @@ class InvalidateDuringResultTest(TestBase): engine.test_shutdown() try: - result.fetchone() + print "ghost result: %r" % result.fetchone() assert False except tsa.exc.DBAPIError, e: if not e.connection_invalidated: diff --git a/test/sql/query.py b/test/sql/query.py index bf178ae8f5..660529c25c 100644 --- a/test/sql/query.py +++ b/test/sql/query.py @@ -252,6 +252,7 @@ class QueryTest(TestBase): eq_(expr.execute().fetchall(), result) + @testing.fails_on("+pg8000", "can't interpret result column from '%%'") @testing.emits_warning('.*now automatically escapes.*') def test_percents_in_text(self): for expr, result in ( diff --git a/test/sql/testtypes.py b/test/sql/testtypes.py index ca22fcb270..39f79e540c 100644 --- a/test/sql/testtypes.py +++ b/test/sql/testtypes.py @@ -291,7 +291,8 @@ class UnicodeTest(TestBase, AssertsExecutionResults): assert unicode_table.c.unicode_varchar.type.length == 250 rawdata = 'Alors vous imaginez ma surprise, au lever du jour, quand une dr\xc3\xb4le de petit voix m\xe2\x80\x99a r\xc3\xa9veill\xc3\xa9. Elle disait: \xc2\xab S\xe2\x80\x99il vous pla\xc3\xaet\xe2\x80\xa6 dessine-moi un mouton! \xc2\xbb\n' unicodedata = rawdata.decode('utf-8') - if testing.against('sqlite'): + + if testing.against('sqlite', '>' '2.4'): rawdata = "something" unicode_table.insert().execute(unicode_varchar=unicodedata, @@ -300,12 +301,12 @@ class UnicodeTest(TestBase, AssertsExecutionResults): x = unicode_table.select().execute().fetchone() self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata) self.assert_(isinstance(x['unicode_text'], unicode) and x['unicode_text'] == unicodedata) + if isinstance(x['plain_varchar'], unicode): # SQLLite and MSSQL return non-unicode data as unicode self.assert_(testing.against('sqlite', '+pyodbc')) if not testing.against('sqlite'): self.assert_(x['plain_varchar'] == unicodedata) - print "it's %s!" % testing.db.name else: self.assert_(not isinstance(x['plain_varchar'], unicode) and x['plain_varchar'] == rawdata) @@ -778,6 +779,7 @@ class IntervalTest(TestBase, AssertsExecutionResults): def tearDownAll(self): metadata.drop_all() + @testing.fails_on("+pg8000", "Not yet known how to pass values of the INTERVAL type") def test_roundtrip(self): delta = datetime.datetime(2006, 10, 5) - datetime.datetime(2005, 8, 17) interval_table.insert().execute(interval=delta) diff --git a/test/testlib/requires.py b/test/testlib/requires.py index 200fb01b11..4ccce96205 100644 --- a/test/testlib/requires.py +++ b/test/testlib/requires.py @@ -98,6 +98,7 @@ def two_phase_transactions(fn): fn, no_support('access', 'not supported by database'), no_support('firebird', 'no SA implementation'), + no_support('+pg8000', 'FIXME: not sure how to accomplish'), no_support('maxdb', 'not supported by database'), no_support('mssql', 'FIXME: guessing, needs confirmation'), no_support('oracle', 'no SA implementation'), -- 2.47.3