From f327eaea478670198fbaa5b16047be73e9dd6aba Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Thu, 16 Aug 2012 16:11:42 -0400 Subject: [PATCH] _adapt_expression() moves fully to _DefaultColumnComparator which resumes its original role as stateful, forms the basis of TypeEngine.Comparator. lots of code goes back mostly as it was just with cleaner typing behavior, such as simple flow in _binary_operate now. --- lib/sqlalchemy/dialects/postgresql/base.py | 7 +- lib/sqlalchemy/sql/expression.py | 81 ++++++++---- lib/sqlalchemy/types.py | 143 +++++++-------------- test/sql/test_operators.py | 58 ++------- test/sql/test_types.py | 1 + 5 files changed, 117 insertions(+), 173 deletions(-) diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index a9ff988e81..36da14d333 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -708,9 +708,10 @@ class PGCompiler(compiler.SQLCompiler): affinity = None casts = { - sqltypes.Date:'date', - sqltypes.DateTime:'timestamp', - sqltypes.Interval:'interval', sqltypes.Time:'time' + sqltypes.Date: 'date', + sqltypes.DateTime: 'timestamp', + sqltypes.Interval: 'interval', + sqltypes.Time: 'time' } cast = casts.get(affinity, None) if isinstance(extract.expr, sql.ColumnElement) and cast is not None: diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 844293c736..63fa23c15d 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -1875,7 +1875,7 @@ class Immutable(object): return self -class _DefaultColumnComparator(object): +class _DefaultColumnComparator(operators.ColumnOperators): """Defines comparison and math operations. See :class:`.ColumnOperators` and :class:`.Operators` for descriptions @@ -1883,6 +1883,45 @@ class _DefaultColumnComparator(object): """ + @util.memoized_property + def type(self): + return self.expr.type + + def operate(self, op, *other, **kwargs): + o = self.operators[op.__name__] + return o[0](self, self.expr, op, *(other + o[1:]), **kwargs) + + def reverse_operate(self, op, other, **kwargs): + o = self.operators[op.__name__] + return o[0](self, self.expr, op, other, reverse=True, *o[1:], **kwargs) + + def _adapt_expression(self, op, other_comparator): + """evaluate the return type of , + and apply any adaptations to the given operator. + + This method determines the type of a resulting binary expression + given two source types and an operator. For example, two + :class:`.Column` objects, both of the type :class:`.Integer`, will + produce a :class:`.BinaryExpression` that also has the type + :class:`.Integer` when compared via the addition (``+``) operator. + However, using the addition operator with an :class:`.Integer` + and a :class:`.Date` object will produce a :class:`.Date`, assuming + "days delta" behavior by the database (in reality, most databases + other than Postgresql don't accept this particular operation). + + The method returns a tuple of the form , . + The resulting operator and type will be those applied to the + resulting :class:`.BinaryExpression` as the final operator and the + right-hand side of the expression. + + Note that only a subset of operators make usage of + :meth:`._adapt_expression`, + including math operators and user-defined operators, but not + boolean comparison or special SQL keywords like MATCH or BETWEEN. + + """ + return op, other_comparator.type + def _boolean_compare(self, expr, op, obj, negate=None, reverse=False, **kwargs ): @@ -1912,7 +1951,7 @@ class _DefaultColumnComparator(object): type_=sqltypes.BOOLEANTYPE, negate=negate, modifiers=kwargs) - def _binary_operate(self, expr, op, obj, result_type, reverse=False): + def _binary_operate(self, expr, op, obj, reverse=False): obj = self._check_literal(expr, op, obj) if reverse: @@ -1920,6 +1959,8 @@ class _DefaultColumnComparator(object): else: left, right = expr, obj + op, result_type = left.comparator._adapt_expression(op, right.comparator) + return BinaryExpression(left, right, op, type_=result_type) def _scalar(self, expr, op, fn, **kw): @@ -1986,7 +2027,8 @@ class _DefaultColumnComparator(object): expr, operators.like_op, literal_column("'%'", type_=sqltypes.String).__radd__( - self._check_literal(expr, operators.like_op, other) + self._check_literal(expr, + operators.like_op, other) ), escape=escape) @@ -2068,21 +2110,16 @@ class _DefaultColumnComparator(object): "neg": (_neg_impl,), } - def operate(self, expr, op, *other, **kwargs): - o = self.operators[op.__name__] - return o[0](self, expr, op, *(other + o[1:]), **kwargs) - - def reverse_operate(self, expr, op, other, **kwargs): - o = self.operators[op.__name__] - return o[0](self, expr, op, other, reverse=True, *o[1:], **kwargs) def _check_literal(self, expr, operator, other): - if isinstance(other, BindParameter) and \ - isinstance(other.type, sqltypes.NullType): - # TODO: perhaps we should not mutate the incoming bindparam() - # here and instead make a copy of it. this might - # be the only place that we're mutating an incoming construct. - other.type = expr.type + if isinstance(other, (ColumnElement, TextClause)): + if isinstance(other, BindParameter) and \ + isinstance(other.type, sqltypes.NullType): + # TODO: perhaps we should not mutate the incoming + # bindparam() here and instead make a copy of it. + # this might be the only place that we're mutating + # an incoming construct. + other.type = expr.type return other elif hasattr(other, '__clause_element__'): other = other.__clause_element__() @@ -2096,8 +2133,6 @@ class _DefaultColumnComparator(object): else: return other -_DEFAULT_COMPARATOR = _DefaultColumnComparator() - class ColumnElement(ClauseElement, ColumnOperators): """Represent an element that is usable within the "column clause" portion @@ -2155,11 +2190,7 @@ class ColumnElement(ClauseElement, ColumnOperators): def comparator(self): return self.type.comparator_factory(self) - #def _assert_comparator(self): - # assert self.comparator.expr is self - def __getattr__(self, key): - #self._assert_comparator() try: return getattr(self.comparator, key) except AttributeError: @@ -2171,11 +2202,9 @@ class ColumnElement(ClauseElement, ColumnOperators): ) def operate(self, op, *other, **kwargs): - #self._assert_comparator() return op(self.comparator, *other, **kwargs) def reverse_operate(self, op, other, **kwargs): - #self._assert_comparator() return op(other, self.comparator, **kwargs) def _bind_param(self, operator, obj): @@ -3090,6 +3119,10 @@ class TextClause(Executable, ClauseElement): else: return sqltypes.NULLTYPE + @property + def comparator(self): + return self.type.comparator_factory(self) + def self_group(self, against=None): if against is operators.in_op: return Grouping(self) diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index d4dbd648c2..bbeebf5d36 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -11,21 +11,21 @@ types. For more information see the SQLAlchemy documentation on types. """ -__all__ = [ 'TypeEngine', 'TypeDecorator', 'AbstractType', 'UserDefinedType', - 'INT', 'CHAR', 'VARCHAR', 'NCHAR', 'NVARCHAR','TEXT', 'Text', +__all__ = ['TypeEngine', 'TypeDecorator', 'AbstractType', 'UserDefinedType', + 'INT', 'CHAR', 'VARCHAR', 'NCHAR', 'NVARCHAR', 'TEXT', 'Text', 'FLOAT', 'NUMERIC', 'REAL', 'DECIMAL', 'TIMESTAMP', 'DATETIME', 'CLOB', 'BLOB', 'BINARY', 'VARBINARY', 'BOOLEAN', 'BIGINT', 'SMALLINT', 'INTEGER', 'DATE', 'TIME', 'String', 'Integer', 'SmallInteger', 'BigInteger', 'Numeric', 'Float', 'DateTime', 'Date', 'Time', 'LargeBinary', 'Binary', 'Boolean', 'Unicode', 'Concatenable', - 'UnicodeText','PickleType', 'Interval', 'Enum' ] + 'UnicodeText', 'PickleType', 'Interval', 'Enum'] import datetime as dt import codecs from . import exc, schema, util, processors, events, event from .sql import operators -from .sql.expression import _DEFAULT_COMPARATOR +from .sql.expression import _DefaultColumnComparator from .util import pickle from .util.compat import decimal from .sql.visitors import Visitable @@ -42,7 +42,7 @@ class AbstractType(Visitable): class TypeEngine(AbstractType): """Base for built-in types.""" - class Comparator(operators.ColumnOperators): + class Comparator(_DefaultColumnComparator): """Base class for custom comparison operations defined at the type level. See :attr:`.TypeEngine.comparator_factory`. @@ -54,24 +54,6 @@ class TypeEngine(AbstractType): def __reduce__(self): return _reconstitute_comparator, (self.expr, ) - def operate(self, op, *other, **kwargs): - if len(other) == 1: - obj = other[0] - obj = _DEFAULT_COMPARATOR._check_literal(self.expr, op, obj) - op, adapt_type = self.expr.type._adapt_expression(op, - obj.type) - kwargs['result_type'] = adapt_type - - return _DEFAULT_COMPARATOR.operate(self.expr, op, *other, **kwargs) - - def reverse_operate(self, op, other, **kwargs): - - obj = _DEFAULT_COMPARATOR._check_literal(self.expr, op, other) - op, adapt_type = obj.type._adapt_expression(op, self.expr.type) - kwargs['result_type'] = adapt_type - - return _DEFAULT_COMPARATOR.reverse_operate(self.expr, op, obj, - **kwargs) comparator_factory = Comparator """A :class:`.TypeEngine.Comparator` class which will apply @@ -143,11 +125,6 @@ class TypeEngine(AbstractType): >>> (c1 == c2).type Boolean() - The propagation of :class:`.TypeEngine.Comparator` throughout an expression - will follow with how the :class:`.TypeEngine` itself is propagated. To - customize the behavior of most operators in this regard, see the - :meth:`._adapt_expression` method. - .. versionadded:: 0.8 The expression system was reworked to support user-defined comparator objects specified at the type level. @@ -247,34 +224,7 @@ class TypeEngine(AbstractType): .. versionadded:: 0.7.2 """ - return Variant(self, {dialect_name:type_}) - - def _adapt_expression(self, op, othertype): - """evaluate the return type of , - and apply any adaptations to the given operator. - - This method determines the type of a resulting binary expression - given two source types and an operator. For example, two - :class:`.Column` objects, both of the type :class:`.Integer`, will - produce a :class:`.BinaryExpression` that also has the type - :class:`.Integer` when compared via the addition (``+``) operator. - However, using the addition operator with an :class:`.Integer` - and a :class:`.Date` object will produce a :class:`.Date`, assuming - "days delta" behavior by the database (in reality, most databases - other than Postgresql don't accept this particular operation). - - The method returns a tuple of the form , . - The resulting operator and type will be those applied to the - resulting :class:`.BinaryExpression` as the final operator and the - right-hand side of the expression. - - Note that only a subset of operators make usage of - :meth:`._adapt_expression`, - including math operators and user-defined operators, but not - boolean comparison or special SQL keywords like MATCH or BETWEEN. - - """ - return op, self + return Variant(self, {dialect_name: type_}) @util.memoized_property def _type_affinity(self): @@ -334,7 +284,7 @@ class TypeEngine(AbstractType): impl = self.adapt(type(self)) # this can't be self, else we create a cycle assert impl is not self - dialect._type_memos[self] = d = {'impl':impl} + dialect._type_memos[self] = d = {'impl': impl} return d def _gen_dialect_impl(self, dialect): @@ -461,22 +411,21 @@ class UserDefinedType(TypeEngine): """ __visit_name__ = "user_defined" - def _adapt_expression(self, op, othertype): - """evaluate the return type of , - and apply any adaptations to the given operator. - - """ - return self.adapt_operator(op), self - - def adapt_operator(self, op): - """A hook which allows the given operator to be adapted - to something new. + class Comparator(TypeEngine.Comparator): + def _adapt_expression(self, op, other_comparator): + if hasattr(self.type, 'adapt_operator'): + util.warn_deprecated( + "UserDefinedType.adapt_operator is deprecated. Create " + "a UserDefinedType.Comparator subclass instead which " + "generates the desired expression constructs, given a " + "particular operator." + ) + return self.type.adapt_operator(op), self.type + else: + return op, self.type - See also UserDefinedType._adapt_expression(), an as-yet- - semi-public method with greater capability in this regard. + comparator_factory = Comparator - """ - return op class TypeDecorator(TypeEngine): """Allows the creation of types which add additional functionality @@ -837,13 +786,6 @@ class TypeDecorator(TypeEngine): """ return self.impl.compare_values(x, y) - def _adapt_expression(self, op, othertype): - op, typ = self.impl._adapt_expression(op, othertype) - typ = to_instance(typ) - if typ._compare_type_affinity(self.impl): - return op, self - else: - return op, typ class Variant(TypeDecorator): """A wrapping type that selects among a variety of @@ -926,8 +868,6 @@ def adapt_type(typeobj, colspecs): return typeobj.adapt(impltype) - - class NullType(TypeEngine): """An unknown type. @@ -943,11 +883,14 @@ class NullType(TypeEngine): """ __visit_name__ = 'null' - def _adapt_expression(self, op, othertype): - if isinstance(othertype, NullType) or not operators.is_commutative(op): - return op, self - else: - return othertype._adapt_expression(op, self) + class Comparator(TypeEngine.Comparator): + def _adapt_expression(self, op, other_comparator): + if isinstance(other_comparator, NullType.Comparator) or \ + not operators.is_commutative(op): + return op, self.expr.type + else: + return other_comparator._adapt_expression(op, self) + comparator_factory = Comparator NullTypeEngine = NullType @@ -955,12 +898,16 @@ class Concatenable(object): """A mixin that marks a type as supporting 'concatenation', typically strings.""" - def _adapt_expression(self, op, othertype): - if op is operators.add and issubclass(othertype._type_affinity, - (Concatenable, NullType)): - return operators.concat_op, self - else: - return op, self + class Comparator(TypeEngine.Comparator): + def _adapt_expression(self, op, other_comparator): + if op is operators.add and isinstance(other_comparator, + (Concatenable.Comparator, NullType.Comparator)): + return operators.concat_op, self.expr.type + else: + return op, self.expr.type + + comparator_factory = Comparator + class _DateAffinity(object): """Mixin date/time specific expression adaptations. @@ -975,12 +922,14 @@ class _DateAffinity(object): def _expression_adaptations(self): raise NotImplementedError() - _blank_dict = util.immutabledict() - def _adapt_expression(self, op, othertype): - othertype = othertype._type_affinity - return op, \ - self._expression_adaptations.get(op, self._blank_dict).\ - get(othertype, NULLTYPE) + class Comparator(TypeEngine.Comparator): + _blank_dict = util.immutabledict() + def _adapt_expression(self, op, other_comparator): + othertype = other_comparator.type._type_affinity + return op, \ + self.type._expression_adaptations.get(op, self._blank_dict).\ + get(othertype, NULLTYPE) + comparator_factory = Comparator class String(Concatenable, TypeEngine): """The base for all string and character types. diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py index c38f95a015..05de8c9ef4 100644 --- a/test/sql/test_operators.py +++ b/test/sql/test_operators.py @@ -12,18 +12,16 @@ from sqlalchemy.types import Integer, TypeEngine, TypeDecorator class DefaultColumnComparatorTest(fixtures.TestBase): def _do_scalar_test(self, operator, compare_to): - cc = _DefaultColumnComparator() left = column('left') - assert cc.operate(left, operator).compare( + assert left.comparator.operate(operator).compare( compare_to(left) ) def _do_operate_test(self, operator): - cc = _DefaultColumnComparator() left = column('left') right = column('right') - assert cc.operate(left, operator, right, result_type=Integer).compare( + assert left.comparator.operate(operator, right).compare( BinaryExpression(left, right, operator) ) @@ -37,9 +35,8 @@ class DefaultColumnComparatorTest(fixtures.TestBase): self._do_operate_test(operators.add) def test_in(self): - cc = _DefaultColumnComparator() left = column('left') - assert cc.operate(left, operators.in_op, [1, 2, 3]).compare( + assert left.comparator.operate(operators.in_op, [1, 2, 3]).compare( BinaryExpression( left, Grouping(ClauseList( @@ -50,10 +47,9 @@ class DefaultColumnComparatorTest(fixtures.TestBase): ) def test_collate(self): - cc = _DefaultColumnComparator() left = column('left') right = "some collation" - cc.operate(left, operators.collate, right).compare( + left.comparator.operate(operators.collate, right).compare( collate(left, right) ) @@ -144,12 +140,8 @@ class _CustomComparatorTests(object): self._assert_add_override(6 - c1) def test_binary_multi_propagate(self): - c1 = Column('foo', self._add_override_factory(True)) - self._assert_add_override((c1 - 6) + 5) - - def test_no_binary_multi_propagate_wo_adapt(self): c1 = Column('foo', self._add_override_factory()) - self._assert_not_add_override((c1 - 6) + 5) + self._assert_add_override((c1 - 6) + 5) def test_no_boolean_propagate(self): c1 = Column('foo', self._add_override_factory()) @@ -166,7 +158,7 @@ class _CustomComparatorTests(object): ) class CustomComparatorTest(_CustomComparatorTests, fixtures.TestBase): - def _add_override_factory(self, include_adapt=False): + def _add_override_factory(self): class MyInteger(Integer): class comparator_factory(TypeEngine.Comparator): @@ -176,19 +168,12 @@ class CustomComparatorTest(_CustomComparatorTests, fixtures.TestBase): def __add__(self, other): return self.expr.op("goofy")(other) - if include_adapt: - def _adapt_expression(self, op, othertype): - if op.__name__ == 'custom_op': - return op, self - else: - return super(MyInteger, self)._adapt_expression( - op, othertype) return MyInteger class TypeDecoratorComparatorTest(_CustomComparatorTests, fixtures.TestBase): - def _add_override_factory(self, include_adapt=False): + def _add_override_factory(self): class MyInteger(TypeDecorator): impl = Integer @@ -200,19 +185,12 @@ class TypeDecoratorComparatorTest(_CustomComparatorTests, fixtures.TestBase): def __add__(self, other): return self.expr.op("goofy")(other) - if include_adapt: - def _adapt_expression(self, op, othertype): - if op.__name__ == 'custom_op': - return op, self - else: - return super(MyInteger, self)._adapt_expression( - op, othertype) return MyInteger class CustomEmbeddedinTypeDecoratorTest(_CustomComparatorTests, fixtures.TestBase): - def _add_override_factory(self, include_adapt=False): + def _add_override_factory(self): class MyInteger(Integer): class comparator_factory(TypeEngine.Comparator): def __init__(self, expr): @@ -221,13 +199,6 @@ class CustomEmbeddedinTypeDecoratorTest(_CustomComparatorTests, fixtures.TestBas def __add__(self, other): return self.expr.op("goofy")(other) - if include_adapt: - def _adapt_expression(self, op, othertype): - if op.__name__ == 'custom_op': - return op, self - else: - return super(MyInteger, self)._adapt_expression( - op, othertype) class MyDecInteger(TypeDecorator): impl = MyInteger @@ -235,7 +206,7 @@ class CustomEmbeddedinTypeDecoratorTest(_CustomComparatorTests, fixtures.TestBas return MyDecInteger class NewOperatorTest(_CustomComparatorTests, fixtures.TestBase): - def _add_override_factory(self, include_adapt=False): + def _add_override_factory(self): class MyInteger(Integer): class comparator_factory(TypeEngine.Comparator): def __init__(self, expr): @@ -243,15 +214,6 @@ class NewOperatorTest(_CustomComparatorTests, fixtures.TestBase): def foob(self, other): return self.expr.op("foob")(other) - - if include_adapt: - def _adapt_expression(self, op, othertype): - if op.__name__ == 'custom_op': - return op, self - else: - return super(MyInteger, self)._adapt_expression( - op, othertype) - return MyInteger def _assert_add_override(self, expr): @@ -262,5 +224,3 @@ class NewOperatorTest(_CustomComparatorTests, fixtures.TestBase): def _assert_not_add_override(self, expr): assert not hasattr(expr, "foob") - def test_no_binary_multi_propagate_wo_adapt(self): - pass \ No newline at end of file diff --git a/test/sql/test_types.py b/test/sql/test_types.py index 91bf17175f..279ae36a0a 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -1222,6 +1222,7 @@ class ExpressionTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiled eq_(expr.right.type.__class__, CHAR) + @testing.uses_deprecated @testing.fails_on('firebird', 'Data type unknown on the parameter') @testing.fails_on('mssql', 'int is unsigned ? not clear') def test_operator_adapt(self): -- 2.47.3