From b9cc0ef3260ab5c93af3c011db7d062ba15252fe Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 8 Mar 2010 18:27:35 -0500 Subject: [PATCH] working on getting operators/left hand type awareness into the "bind" coercion. this system has to be figured out somehow --- lib/sqlalchemy/sql/expression.py | 55 +++++++++++----------- lib/sqlalchemy/types.py | 60 +++++++++++++++++++++++- test/sql/test_types.py | 78 +++++++++++++++++++++++++++----- 3 files changed, 154 insertions(+), 39 deletions(-) diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 1c3961f1f7..49ec34ab22 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -1443,7 +1443,7 @@ class _CompareMixin(ColumnOperators): else: raise exc.ArgumentError("Only '='/'!=' operators can be used with NULL") else: - obj = self._check_literal(obj) + obj = self._check_literal(op, obj) if reverse: return _BinaryExpression(obj, @@ -1459,7 +1459,7 @@ class _CompareMixin(ColumnOperators): negate=negate, modifiers=kwargs) def __operate(self, op, obj, reverse=False): - obj = self._check_literal(obj) + obj = self._check_literal(op, obj) if reverse: left, right = obj, self @@ -1532,7 +1532,7 @@ class _CompareMixin(ColumnOperators): "in() function accepts either a list of non-selectable values, " "or a selectable: %r" % o) else: - o = self._bind_param(o) + o = self._bind_param(op, o) args.append(o) if len(args) == 0: @@ -1558,7 +1558,7 @@ class _CompareMixin(ColumnOperators): # use __radd__ to force string concat behavior return self.__compare( operators.like_op, - literal_column("'%'", type_=sqltypes.String).__radd__(self._check_literal(other)), + literal_column("'%'", type_=sqltypes.String).__radd__(self._check_literal(operators.like_op, other)), escape=escape) def endswith(self, other, escape=None): @@ -1566,7 +1566,7 @@ class _CompareMixin(ColumnOperators): return self.__compare( operators.like_op, - literal_column("'%'", type_=sqltypes.String) + self._check_literal(other), + literal_column("'%'", type_=sqltypes.String) + self._check_literal(operators.like_op, other), escape=escape) def contains(self, other, escape=None): @@ -1575,7 +1575,7 @@ class _CompareMixin(ColumnOperators): return self.__compare( operators.like_op, literal_column("'%'", type_=sqltypes.String) + - self._check_literal(other) + + self._check_literal(operators.like_op, other) + literal_column("'%'", type_=sqltypes.String), escape=escape) @@ -1585,7 +1585,7 @@ class _CompareMixin(ColumnOperators): The allowed contents of ``other`` are database backend specific. """ - return self.__compare(operators.match_op, self._check_literal(other)) + return self.__compare(operators.match_op, self._check_literal(operators.match_op, other)) def label(self, name): """Produce a column label, i.e. `` AS ``. @@ -1615,8 +1615,8 @@ class _CompareMixin(ColumnOperators): return _BinaryExpression( self, ClauseList( - self._check_literal(cleft), - self._check_literal(cright), + self._check_literal(operators.and_, cleft), + self._check_literal(operators.and_, cright), operator=operators.and_, group=False), operators.between_op) @@ -1651,17 +1651,18 @@ class _CompareMixin(ColumnOperators): """ return lambda other: self.__operate(operator, other) - def _bind_param(self, obj): - return _BindParamClause(None, obj, _fallback_type=self.type, unique=True) + def _bind_param(self, operator, obj): + return _BindParamClause(None, obj, _compared_to_operator=operator, _compared_to_type=self.type, unique=True) - def _check_literal(self, other): - if isinstance(other, _BindParamClause) and isinstance(other.type, sqltypes.NullType): + def _check_literal(self, operator, other): + if isinstance(other, _BindParamClause) and \ + isinstance(other.type, sqltypes.NullType): other.type = self.type return other elif hasattr(other, '__clause_element__'): return other.__clause_element__() elif not isinstance(other, ClauseElement): - return self._bind_param(other) + return self._bind_param(operator, other) elif isinstance(other, (_SelectBaseMixin, Alias)): return other.as_scalar() else: @@ -2108,7 +2109,8 @@ class _BindParamClause(ColumnElement): def __init__(self, key, value, type_=None, unique=False, isoutparam=False, required=False, - _fallback_type=None): + _compared_to_operator=None, + _compared_to_type=None): """Construct a _BindParamClause. key @@ -2154,9 +2156,10 @@ class _BindParamClause(ColumnElement): self.required = required if type_ is None: - self.type = sqltypes.type_map.get(type(value), _fallback_type or sqltypes.NULLTYPE) - if _fallback_type and _fallback_type._type_affinity == self.type._type_affinity: - self.type = _fallback_type + if _compared_to_type is not None: + self.type = _compared_to_type._coerce_compared_value(_compared_to_operator, value) + else: + self.type = sqltypes.NULLTYPE elif isinstance(type_, type): self.type = type_() else: @@ -2434,9 +2437,9 @@ class _Tuple(ClauseList, ColumnElement): def _select_iterable(self): return (self, ) - def _bind_param(self, obj): + def _bind_param(self, operator, obj): return _Tuple(*[ - _BindParamClause(None, o, _fallback_type=self.type, unique=True) + _BindParamClause(None, o, _compared_to_operator=operator, _compared_to_type=self.type, unique=True) for o in obj ]).self_group() @@ -2538,8 +2541,8 @@ class FunctionElement(Executable, ColumnElement, FromClause): def execute(self): return self.select().execute() - def _bind_param(self, obj): - return _BindParamClause(None, obj, _fallback_type=self.type, unique=True) + def _bind_param(self, operator, obj): + return _BindParamClause(None, obj, _compared_to_operator=operator, _compared_to_type=self.type, unique=True) class Function(FunctionElement): @@ -2555,8 +2558,8 @@ class Function(FunctionElement): FunctionElement.__init__(self, *clauses, **kw) - def _bind_param(self, obj): - return _BindParamClause(self.name, obj, _fallback_type=self.type, unique=True) + def _bind_param(self, operator, obj): + return _BindParamClause(self.name, obj, _compared_to_operator=operator, _compared_to_type=self.type, unique=True) class _Cast(ColumnElement): @@ -3165,8 +3168,8 @@ class ColumnClause(_Immutable, ColumnElement): else: return [] - def _bind_param(self, obj): - return _BindParamClause(self.name, obj, _fallback_type=self.type, unique=True) + def _bind_param(self, operator, obj): + return _BindParamClause(self.name, obj, _compared_to_operator=operator, _compared_to_type=self.type, unique=True) def _make_proxy(self, selectable, name=None, attach=True): # propagate the "is_literal" flag only if we are keeping our name, diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index cdbf7927ef..4d6a28aadc 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -116,6 +116,13 @@ class AbstractType(Visitable): typ = t else: return self.__class__ + + def _coerce_compared_value(self, op, value): + _coerced_type = type_map.get(type(value), NULLTYPE) + if _coerced_type._type_affinity == self._type_affinity: + return self + else: + return _coerced_type def _compare_type_affinity(self, other): return self._type_affinity is other._type_affinity @@ -239,7 +246,7 @@ class TypeDecorator(AbstractType): # strips it off on the way out. impl = types.Unicode - + def process_bind_param(self, value, dialect): return "PREFIX:" + value @@ -255,6 +262,44 @@ class TypeDecorator(AbstractType): given; in this case, the "impl" variable can reference ``TypeEngine`` as a placeholder. + Types that receive a Python type that isn't similar to the + ultimate type used may want to define the :meth:`TypeDecorator.coerce_compared_value` + method=. This is used to give the expression system a hint + when coercing Python objects + into bind parameters within expressions. Consider this expression:: + + mytable.c.somecol + datetime.date(2009, 5, 15) + + Above, if "somecol" is an ``Integer`` variant, it makes sense that + we doing date arithmetic, where above is usually interpreted + by databases as adding a number of days to the given date. + The expression system does the right thing by not attempting to + coerce the "date()" value into an integer-oriented bind parameter. + + However, suppose "somecol" is a ``TypeDecorator`` that is wrapping + an ``Integer``, and our ``TypeDecorator`` is actually storing dates + as an "epoch", i.e. a total number of days from a fixed starting + date. So in this case, we *do* want the expression system to wrap + the date() into our ``TypeDecorator`` type's system of coercing + dates into integers. So we would want to define:: + + class MyEpochType(types.TypeDecorator): + impl = types.Integer + + epoch = datetime.date(1970, 1, 1) + + def process_bind_param(self, value, dialect): + return (value - self.epoch).days + + def process_result_value(self, value, dialect): + return self.epoch + timedelta(days=value) + + def coerce_compared_value(self, op, value): + if isinstance(value, datetime.date): + return Date + else: + raise ValueError("Python date expected.") + The reason that type behavior is modified using class decoration instead of subclassing is due to the way dialect specific types are used. Such as with the example above, when using the mysql @@ -365,7 +410,13 @@ class TypeDecorator(AbstractType): return process else: return self.impl.result_processor(dialect, coltype) + + def coerce_compared_value(self, op, value): + return self.impl._coerce_compared_value(op, value) + def _coerce_compared_value(self, op, value): + return self.coerce_compared_value(op, value) + def copy(self): instance = self.__class__.__new__(self.__class__) instance.__dict__.update(self.__dict__) @@ -384,6 +435,11 @@ class TypeDecorator(AbstractType): def is_mutable(self): return self.impl.is_mutable() + def _adapt_expression(self, op, othertype): + return self.impl._adapt_expression(op, othertype) + + + class MutableType(object): """A mixin that marks a Type as holding a mutable object. @@ -461,7 +517,7 @@ class Concatenable(object): """A mixin that marks a type as supporting 'concatenation', typically strings.""" def _adapt_expression(self, op, othertype): - if op is operators.add and isinstance(othertype, (Concatenable, NullType)): + if op is operators.add and issubclass(othertype._type_affinity, (Concatenable, NullType)): return operators.concat_op, self else: return op, self diff --git a/test/sql/test_types.py b/test/sql/test_types.py index 3ac8baf004..ad58f1867e 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -4,7 +4,7 @@ import decimal import datetime, os, re from sqlalchemy import * from sqlalchemy import exc, types, util, schema -from sqlalchemy.sql import operators, column +from sqlalchemy.sql import operators, column, table from sqlalchemy.test.testing import eq_ import sqlalchemy.engine.url as url from sqlalchemy.databases import * @@ -687,10 +687,10 @@ class BinaryTest(TestBase, AssertsExecutionResults): f = os.path.join(os.path.dirname(__file__), "..", name) return open(f, mode='rb').read() -class ExpressionTest(TestBase, AssertsExecutionResults): +class ExpressionTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL): @classmethod def setup_class(cls): - global test_table, meta, MyCustomType + global test_table, meta, MyCustomType, MyTypeDec class MyCustomType(types.UserDefinedType): def get_col_spec(self): @@ -705,13 +705,24 @@ class ExpressionTest(TestBase, AssertsExecutionResults): return process def adapt_operator(self, op): return {operators.add:operators.sub, operators.sub:operators.add}.get(op, op) + + class MyTypeDec(types.TypeDecorator): + impl = String + + def process_bind_param(self, value, dialect): + return "BIND_IN" + str(value) + def process_result_value(self, value, dialect): + return value + "BIND_OUT" + meta = MetaData(testing.db) test_table = Table('test', meta, Column('id', Integer, primary_key=True), Column('data', String(30)), Column('atimestamp', Date), - Column('avalue', MyCustomType)) + Column('avalue', MyCustomType), + Column('bvalue', MyTypeDec), + ) meta.create_all() @@ -719,7 +730,7 @@ class ExpressionTest(TestBase, AssertsExecutionResults): 'id':1, 'data':'somedata', 'atimestamp':datetime.date(2007, 10, 15), - 'avalue':25}) + 'avalue':25, 'bvalue':'foo'}) @classmethod def teardown_class(cls): @@ -730,7 +741,7 @@ class ExpressionTest(TestBase, AssertsExecutionResults): eq_( test_table.select().execute().fetchall(), - [(1, 'somedata', datetime.date(2007, 10, 15), 25)] + [(1, 'somedata', datetime.date(2007, 10, 15), 25, "BIND_INfooBIND_OUT")] ) def test_bind_adapt(self): @@ -740,17 +751,26 @@ class ExpressionTest(TestBase, AssertsExecutionResults): eq_( testing.db.execute( - test_table.select().where(expr), + select([test_table.c.id, test_table.c.data, test_table.c.atimestamp]) + .where(expr), {"thedate":datetime.date(2007, 10, 15)}).fetchall(), - [(1, 'somedata', datetime.date(2007, 10, 15), 25)] + [(1, 'somedata', datetime.date(2007, 10, 15))] ) expr = test_table.c.avalue == bindparam("somevalue") eq_(expr.right.type._type_affinity, MyCustomType) - + eq_( testing.db.execute(test_table.select().where(expr), {"somevalue":25}).fetchall(), - [(1, 'somedata', datetime.date(2007, 10, 15), 25)] + [(1, 'somedata', datetime.date(2007, 10, 15), 25, 'BIND_INfooBIND_OUT')] + ) + + expr = test_table.c.bvalue == bindparam("somevalue") + eq_(expr.right.type._type_affinity, String) + + eq_( + testing.db.execute(test_table.select().where(expr), {"somevalue":"foo"}).fetchall(), + [(1, 'somedata', datetime.date(2007, 10, 15), 25, 'BIND_INfooBIND_OUT')] ) def test_literal_adapt(self): @@ -799,7 +819,43 @@ class ExpressionTest(TestBase, AssertsExecutionResults): # this one relies upon anonymous labeling to assemble result # processing rules on the column. assert testing.db.execute(select([expr])).scalar() == -15 - + + def test_typedec_operator_adapt(self): + expr = test_table.c.bvalue + "hi" + + assert expr.type.__class__ is String + + eq_( + testing.db.execute(select([expr.label('foo')])).scalar(), + "BIND_INfooBIND_INhiBIND_OUT" + ) + + def test_typedec_righthand_coercion(self): + class MyTypeDec(types.TypeDecorator): + impl = String + + def process_bind_param(self, value, dialect): + return "BIND_IN" + str(value) + + def process_result_value(self, value, dialect): + return value + "BIND_OUT" + + tab = table('test', column('bvalue', MyTypeDec)) + expr = tab.c.bvalue + 6 + + self.assert_compile( + expr, + "test.bvalue || :bvalue_1", + use_default_dialect=True + ) + + assert expr.type.__class__ is String + eq_( + testing.db.execute(select([expr.label('foo')])).scalar(), + "BIND_INfooBIND_IN6BIND_OUT" + ) + + def test_bind_typing(self): from sqlalchemy.sql import column -- 2.47.3