elif isinstance(key, basestring) and key.lower() in props:
rec = props[key.lower()]
elif isinstance(key, sql.ColumnElement):
- print "LABEL ON COLUMN", repr(key.key), "IS", repr(key._label)
label = context.column_labels.get(key._label, key.name).lower()
- print "SO YEAH, NOW WE GOT LABEL", repr(label), "AND PROPS IS", repr(props)
if label in props:
rec = props[label]
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-from sqlalchemy import util
+from sqlalchemy import util, sql
from sqlalchemy.orm import util as orm_util, interfaces, collections
from sqlalchemy.orm.mapper import class_mapper
from sqlalchemy import logging, exceptions
PASSIVE_NORESULT = object()
ATTR_WAS_SET = object()
-class InstrumentedAttribute(object):
- def __init__(self, class_, manager, key, callable_, trackparent=False, extension=None, compare_function=None, mutable_scalars=False, **kwargs):
+class InstrumentedAttribute(sql.Comparator):
+ """attribute access for instrumented classes."""
+
+ def __init__(self, class_, manager, key, callable_, trackparent=False, extension=None, compare_function=None, mutable_scalars=False, comparator=None, **kwargs):
+ """Construct an InstrumentedAttribute.
+
+ class_
+ the class to be instrumented.
+
+ manager
+ AttributeManager managing this class
+
+ key
+ string name of the attribute
+
+ callable_
+ optional function which generates a callable based on a parent
+ instance, which produces the "default" values for a scalar or
+ collection attribute when it's first accessed, if not present already.
+
+ trackparent
+ if True, attempt to track if an instance has a parent attached to it
+ via this attribute
+
+ extension
+ an AttributeExtension object which will receive
+ set/delete/append/remove/etc. events
+
+ compare_function
+ a function that compares two values which are normally assignable to this
+ attribute
+
+ mutable_scalars
+ if True, the values which are normally assignable to this attribute can mutate,
+ and need to be compared against a copy of their original contents in order to
+ detect changes on the parent instance
+
+ comparator
+ a sql.Comparator to which compare/math events will be sent
+
+ """
+
self.class_ = class_
self.manager = manager
self.key = key
self.callable_ = callable_
self.trackparent = trackparent
self.mutable_scalars = mutable_scalars
-
+ self.comparator = comparator
self.copy = None
if compare_function is None:
self.is_equal = lambda x,y: x == y
return self
return self.get(obj)
+ def compare_self(self):
+ return self.comparator.compare_self()
+
+ def operate(self, op, other):
+ return self.comparator.operate(op, other)
+
+ def reverse_operate(self, op, other):
+ return self.comparator.reverse_operate(op, other)
+
def hasparent(self, item, optimistic=False):
"""Return the boolean value of a `hasparent` flag attached to the given item.
class InstrumentedScalarAttribute(InstrumentedAttribute):
+ """represents a scalar-holding InstrumentedAttribute."""
+
def __init__(self, class_, manager, key, callable_, trackparent=False, extension=None, copy_function=None, compare_function=None, mutable_scalars=False, **kwargs):
super(InstrumentedScalarAttribute, self).__init__(class_, manager, key,
callable_, trackparent=trackparent, extension=extension,
obj.__dict__[self.key] = value
self.fire_replace_event(obj, value, old, initiator)
+ type = property(lambda self: self.property.columns[0].type)
+
+
class InstrumentedCollectionAttribute(InstrumentedAttribute):
"""A collection-holding attribute that instruments changes in membership.
return self.attr.hasparent(obj)
class AttributeManager(object):
- """Allow the instrumentation of object attributes.
-
- ``AttributeManager`` is stateless, but can be overridden by
- subclasses to redefine some of its factory operations. Also be
- aware ``AttributeManager`` will cache attributes for a given
- class, allowing not to determine those for each objects (used in
- ``managed_attributes()`` and
- ``noninherited_managed_attributes()``). This cache is cleared for
- a given class while calling ``register_attribute()``, and can be
- cleared using ``clear_attribute_cache()``.
- """
+ """Allow the instrumentation of object attributes."""
def __init__(self):
# will cache attributes, indexed by class objects
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-from sqlalchemy import util, logging
+from sqlalchemy import util, logging, sql
# returned by a MapperExtension method to indicate a "do nothing" response
EXT_PASS = object()
raise NotImplementedError()
+class PropComparator(sql.Comparator):
+ """defines comparison operations for MapperProperty objects"""
+
+ def __init__(self, prop):
+ self.prop = prop
class StrategizedProperty(MapperProperty):
"""A MapperProperty which uses selectable strategies to affect
self.columns = list(columns)
self.group = kwargs.pop('group', None)
self.deferred = kwargs.pop('deferred', False)
+ self.comparator = ColumnProperty.ColumnComparator(self)
def create_strategy(self):
if self.deferred:
def merge(self, session, source, dest, _recursive):
setattr(dest, self.key, getattr(source, self.key, None))
- def compare(self, value):
- return self.columns[0] == value
+ def compare(self, value, op='=='):
+ return self.comparator == value
def get_col_value(self, column, value):
return value
+
+ class ColumnComparator(PropComparator):
+ def compare_self(self):
+ return self.prop.columns[0]
+
+ def operate(self, op, other):
+ return op(self.prop.columns[0], other)
+
+ def reverse_operate(self, op, other):
+ col = self.prop.columns[0]
+ return op(col._bind_param(other), col)
+
ColumnProperty.logger = logging.class_logger(ColumnProperty)
def __init__(self, class_, *columns, **kwargs):
super(CompositeProperty, self).__init__(*columns, **kwargs)
self.composite_class = class_
-
+ self.comparator = None
+
def copy(self):
return CompositeProperty(deferred=self.deferred, group=self.group, composite_class=self.composite_class, *self.columns)
if a is column:
setattr(obj, b, value)
- def compare(self, value):
- return sql.and_([a==b for a, b in zip(self.columns, value.__colset__())])
+ def compare(self, value, op='=='):
+ # TODO: build into operator framework
+ if op == '==':
+ return sql.and_([a==b for a, b in zip(self.columns, value.__colset__())])
+ elif op == '!=':
+ return sql.or_([a!=b for a, b in zip(self.columns, value.__colset__())])
def get_col_value(self, column, value):
for a, b in zip(self.columns, value.__colset__()):
self.remote_side = util.to_set(remote_side)
self.enable_typechecks = enable_typechecks
self._parent_join_cache = {}
+ self.comparator = None
if cascade is not None:
self.cascade = mapperutil.CascadeOptions(cascade)
self.backref = backref
self.is_backref = is_backref
- def compare(self, value):
- return sql.and_(*[x==y for (x, y) in zip(self.mapper.primary_key, self.mapper.primary_key_from_instance(value))])
+ def compare(self, value, value_is_parent=False, op='=='):
+ if op == '==':
+ # optimized operation for ==, uses a lazy clause.
+ (criterion, lazybinds, rev) = strategies.LazyLoader._create_lazy_clause(self, reverse_direction=not value_is_parent)
+ bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds])
+
+ class Visitor(sql.ClauseVisitor):
+ def visit_bindparam(s, bindparam):
+ mapper = value_is_parent and self.parent or self.mapper
+ bindparam.value = mapper.get_attr_by_column(value, bind_to_col[bindparam.key])
+ Visitor().traverse(criterion)
+ return criterion
+ else:
+ # TODO: build expressions like these into operator framework
+ return sql.and_(*[x==y for (x, y) in zip(self.mapper.primary_key, self.mapper.primary_key_from_instance(value))])
private = property(lambda s:s.cascade.delete_orphan)
if instance is None:
raise exceptions.InvalidRequestError("No instance found for identity %s" % repr(ident))
return instance
-
-
- def _with_lazy_criterion(cls, instance, prop, reverse=False):
- """extract query criterion from a LazyLoader strategy given a Mapper,
- source persisted/detached instance and PropertyLoader.
-
- """
-
- from sqlalchemy.orm import strategies
- (criterion, lazybinds, rev) = strategies.LazyLoader._create_lazy_clause(prop, reverse_direction=reverse)
- bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds])
-
- class Visitor(sql.ClauseVisitor):
- def visit_bindparam(self, bindparam):
- mapper = reverse and prop.mapper or prop.parent
- bindparam.value = mapper.get_attr_by_column(instance, bind_to_col[bindparam.key])
- Visitor().traverse(criterion)
- return criterion
- _with_lazy_criterion = classmethod(_with_lazy_criterion)
-
def query_from_parent(cls, instance, property, **kwargs):
"""return a newly constructed Query object, with criterion corresponding to
mapper = object_mapper(instance)
prop = mapper.get_property(property, resolve_synonyms=True)
target = prop.mapper
- criterion = cls._with_lazy_criterion(instance, prop)
+ criterion = prop.compare(instance, value_is_parent=True)
return Query(target, **kwargs).filter(criterion)
query_from_parent = classmethod(query_from_parent)
raise exceptions.InvalidRequestError("Could not locate a property which relates instances of class '%s' to instances of class '%s'" % (self.mapper.class_.__name__, instance.__class__.__name__))
else:
prop = mapper.get_property(property, resolve_synonyms=True)
- return self.filter(Query._with_lazy_criterion(instance, prop))
+ return self.filter(prop.compare(instance, value_is_parent=True))
def add_entity(self, entity):
"""add a mapped entity to the list of result columns to be returned.
for key, value in kwargs.iteritems():
prop = joinpoint.get_property(key, resolve_synonyms=True)
- if isinstance(prop, properties.PropertyLoader):
- c = self._with_lazy_criterion(value, prop, True) # & self.join_via(keys[:-1]) - use aliasized join feature
- else:
- c = prop.compare(value) # & self.join_via(keys) - use aliasized join feature
+ c = prop.compare(value)
+
if alias is not None:
sql_util.ClauseAdapter(alias).traverse(c)
if clause is None:
for key, value in params.iteritems():
(keys, prop) = self._locate_prop(key, start=start)
if isinstance(prop, properties.PropertyLoader):
- c = self._with_lazy_criterion(value, prop, True) & self.join_via(keys[:-1])
+ c = prop.compare(value) & self.join_via(keys[:-1])
else:
c = prop.compare(value) & self.join_via(keys)
if clause is None:
return False
else:
return True
- sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, copy_function=copy, compare_function=compare, mutable_scalars=True)
+ sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, copy_function=copy, compare_function=compare, mutable_scalars=True, comparator=self.parent_property.comparator)
def _init_scalar_attribute(self):
self.logger.info("register managed attribute %s on class %s" % (self.key, self.parent.class_.__name__))
coltype = self.columns[0].type
- sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, copy_function=coltype.copy_value, compare_function=coltype.compare_values, mutable_scalars=self.columns[0].type.is_mutable())
+ sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, copy_function=coltype.copy_value, compare_function=coltype.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator)
def create_row_processor(self, selectcontext, mapper, row):
if self.is_composite:
def init_class_attribute(self):
self.logger.info("register managed attribute %s on class %s" % (self.key, self.parent.class_.__name__))
- sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, callable_=self.setup_loader, copy_function=self.columns[0].type.copy_value, compare_function=self.columns[0].type.compare_values, mutable_scalars=self.columns[0].type.is_mutable())
+ sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, callable_=self.setup_loader, copy_function=self.columns[0].type.copy_value, compare_function=self.columns[0].type.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator)
def setup_query(self, context, **kwargs):
if self.group is not None and context.attributes.get(('undefer', self.group), False):
def _register_attribute(self, class_, callable_=None):
self.logger.info("register managed %s attribute %s on class %s" % ((self.uselist and "list-holding" or "scalar"), self.key, self.parent.class_.__name__))
- sessionlib.attribute_manager.register_attribute(class_, self.key, uselist = self.uselist, extension=self.attributeext, cascade=self.cascade, trackparent=True, typecallable=self.parent_property.collection_class, callable_=callable_)
+ sessionlib.attribute_manager.register_attribute(class_, self.key, uselist = self.uselist, extension=self.attributeext, cascade=self.cascade, trackparent=True, typecallable=self.parent_property.collection_class, callable_=callable_, comparator=self.parent_property.comparator)
class NoLoader(AbstractRelationLoader):
def init_class_attribute(self):
sessionlib.attribute_manager.reset_instance_attribute(instance, self.key)
return (execute, None)
- def _create_lazy_clause(cls, prop, reverse_direction=False):
+ def _create_lazy_clause(cls, prop, reverse_direction=False, op='=='):
(primaryjoin, secondaryjoin, remote_side) = (prop.polymorphic_primaryjoin, prop.polymorphic_secondaryjoin, prop.remote_side)
binds = {}
rightcol = find_column_in_expr(binary.right)
if leftcol is None or rightcol is None:
return
+
+ # TODO: comprehensive negation support for expressions
+ if op == '!=' and binary.operator == '==':
+ binary.operator = '!='
+
if should_bind(leftcol, rightcol):
col = leftcol
binary.left = binds.setdefault(leftcol,
from sqlalchemy import util, exceptions, logging
from sqlalchemy import types as sqltypes
-import string, re, sets
+import string, re, sets, operator
__all__ = ['AbstractDialect', 'Alias', 'ClauseElement', 'ClauseParameters',
'ClauseVisitor', 'ColumnCollection', 'ColumnElement',
def _negate(self):
return _UnaryExpression(self.self_group(against="NOT"), operator="NOT", negate=None)
-class _CompareMixin(object):
- """Defines comparison operations for ``ClauseElement`` instances.
+
+class Comparator(object):
+ """defines comparison and math operations"""
+
+ def like_op(a, b):
+ return a.like(b)
+ like_op = staticmethod(like_op)
- This is a mixin class that adds the capability to produce ``ClauseElement``
- instances based on regular Python operators.
- These operations are achieved using Python's operator overload methods
- (i.e. ``__eq__()``, ``__ne__()``, etc.
+ def between_op(a, b):
+ return a.between(b)
+ between_op = staticmethod(between_op)
- Overridden operators include all comparison operators (i.e. '==', '!=', '<'),
- math operators ('+', '-', '*', etc), the '&' and '|' operators which evaluate
- to ``AND`` and ``OR`` respectively.
-
- Other methods exist to create additional SQL clauses such as ``IN``, ``LIKE``,
- ``DISTINCT``, etc.
+ def in_op(a, b):
+ return a.in_(b)
+ in_op = staticmethod(in_op)
- """
+ def startswith_op(a, b):
+ return a.startswith(b)
+ startswith_op = staticmethod(startswith_op)
+
+ def endswith_op(a, b):
+ return a.endswith(b)
+ endswith_op = staticmethod(endswith_op)
+
+ def compare_self(self):
+ raise NotImplementedError()
+
+ def operate(self, op, other):
+ raise NotImplementedError()
+ def reverse_operate(self, op, other):
+ raise NotImplementedError()
+
def __lt__(self, other):
- return self._compare('<', other)
+ return self.operate(operator.lt, other)
def __le__(self, other):
- return self._compare('<=', other)
+ return self.operate(operator.le, other)
def __eq__(self, other):
- return self._compare('=', other)
+ return self.operate(operator.eq, other)
def __ne__(self, other):
- return self._compare('!=', other)
+ return self.operate(operator.ne, other)
def __gt__(self, other):
- return self._compare('>', other)
+ return self.operate(operator.gt, other)
def __ge__(self, other):
- return self._compare('>=', other)
+ return self.operate(operator.ge, other)
def like(self, other):
- """produce a ``LIKE`` clause."""
- return self._compare('LIKE', other)
+ return self.operate(Comparator.like_op, other)
+
+ def in_(self, *other):
+ return self.operate(Comparator.in_op, other)
+
+ def startswith(self, other):
+ return self.operate(Comparator.startswith_op, other)
+
+ def endswith(self, other):
+ return self.operate(Comparator.endswith_op, other)
+
+ def __radd__(self, other):
+ return self.reverse_operate(operator.add, other)
+
+ def __rsub__(self, other):
+ return self.reverse_operate(operator.sub, other)
+
+ def __rmul__(self, other):
+ return self.reverse_operate(operator.mul, other)
+
+ def __rdiv__(self, other):
+ return self.reverse_operate(operator.div, other)
+
+ def between(self, cleft, cright):
+ return self.operate(Comparator.between_op, (cleft, cright))
+
+ def __add__(self, other):
+ return self.operate(operator.add, other)
+
+ def __sub__(self, other):
+ return self.operate(operator.sub, other)
+
+ def __mul__(self, other):
+ return self.operate(operator.mul, other)
+
+ def __div__(self, other):
+ return self.operate(operator.div, other)
+
+ def __mod__(self, other):
+ return self.operate(operator.mod, other)
+
+ def __truediv__(self, other):
+ return self.operate(operator.truediv, other)
+
+class _CompareMixin(Comparator):
+ """Defines comparison and math operations for ``ClauseElement`` instances."""
+
+ def __compare(self, operator, obj, negate=None):
+ if obj is None or isinstance(obj, _Null):
+ if operator == '=':
+ return _BinaryExpression(self.compare_self(), null(), 'IS', negate='IS NOT')
+ elif operator == '!=':
+ return _BinaryExpression(self.compare_self(), null(), 'IS NOT', negate='IS')
+ else:
+ raise exceptions.ArgumentError("Only '='/'!=' operators can be used with NULL")
+ else:
+ obj = self._check_literal(obj)
+
+ return _BinaryExpression(self.compare_self(), obj, operator, type=sqltypes.Boolean, negate=negate)
+
+ def __operate(self, operator, obj):
+ obj = self._check_literal(obj)
+ return _BinaryExpression(self.compare_self(), obj, operator, type=self._compare_type(obj))
+
+ operators = {
+ operator.add : (__operate, '+'),
+ operator.mul : (__operate, '*'),
+ operator.sub : (__operate, '-'),
+ operator.div : (__operate, '/'),
+ operator.mod : (__operate, '%'),
+ operator.truediv : (__operate, '/'),
+ operator.lt : (__compare, '<', '=>'),
+ operator.le : (__compare, '<=', '>'),
+ operator.ne : (__compare, '!=', '='),
+ operator.gt : (__compare, '>', '<='),
+ operator.ge : (__compare, '>=', '<'),
+ operator.eq : (__compare, '=', '!='),
+ Comparator.like_op : (__compare, 'LIKE', 'NOT LIKE'),
+ }
+
+ def operate(self, op, other):
+ o = _CompareMixin.operators[op]
+ return o[0](self, o[1], other, *o[2:])
+
+ def reverse_operate(self, op, other):
+ return self._bind_param(other).operate(op, self)
def in_(self, *other):
"""produce an ``IN`` clause."""
return self.__eq__( o) #single item -> ==
else:
assert hasattr( o, '_selectable') #better check?
- return self._compare( 'IN', o, negate='NOT IN') #single selectable
+ return self.__compare( 'IN', o, negate='NOT IN') #single selectable
args = []
for o in other:
else:
o = self._bind_param(o)
args.append(o)
- return self._compare( 'IN', ClauseList(*args).self_group(against='IN'), negate='NOT IN')
+ return self.__compare( 'IN', ClauseList(*args).self_group(against='IN'), negate='NOT IN')
def startswith(self, other):
"""produce the clause ``LIKE '<other>%'``"""
perc = isinstance(other,(str,unicode)) and '%' or literal('%',type= sqltypes.String)
- return self._compare('LIKE', other + perc)
+ return self.__compare('LIKE', other + perc)
def endswith(self, other):
"""produce the clause ``LIKE '%<other>'``"""
else:
po = literal('%', type= sqltypes.String) + other
po.type = sqltypes.to_instance( sqltypes.String) #force!
- return self._compare('LIKE', po)
-
- def __radd__(self, other):
- return self._bind_param(other)._operate('+', self)
- def __rsub__(self, other):
- return self._bind_param(other)._operate('-', self)
- def __rmul__(self, other):
- return self._bind_param(other)._operate('*', self)
- def __rdiv__(self, other):
- return self._bind_param(other)._operate('/', self)
+ return self.__compare('LIKE', po)
def label(self, name):
"""produce a column label, i.e. ``<columnname> AS <name>``"""
passed to the generated function.
"""
- return lambda other: self._operate(operator, other)
-
- # and here come the math operators:
-
- def __add__(self, other):
- return self._operate('+', other)
-
- def __sub__(self, other):
- return self._operate('-', other)
-
- def __mul__(self, other):
- return self._operate('*', other)
-
- def __div__(self, other):
- return self._operate('/', other)
-
- def __mod__(self, other):
- return self._operate('%', other)
-
- def __truediv__(self, other):
- return self._operate('/', other)
+ return lambda other: self.__operate(operator, other)
def _bind_param(self, obj):
return _BindParamClause('literal', obj, shortname=None, type=self.type, unique=True)
def _check_literal(self, other):
- if _is_literal(other):
+ if isinstance(other, Comparator):
+ return other.compare_self()
+ elif _is_literal(other):
return self._bind_param(other)
else:
return other
-
- def _compare(self, operator, obj, negate=None):
- if obj is None or isinstance(obj, _Null):
- if operator == '=':
- return _BinaryExpression(self._compare_self(), null(), 'IS', negate='IS NOT')
- elif operator == '!=':
- return _BinaryExpression(self._compare_self(), null(), 'IS NOT', negate='IS')
- else:
- raise exceptions.ArgumentError("Only '='/'!=' operators can be used with NULL")
- else:
- obj = self._check_literal(obj)
-
- return _BinaryExpression(self._compare_self(), obj, operator, type=sqltypes.Boolean, negate=negate)
-
- def _operate(self, operator, obj):
- if _is_literal(obj):
- obj = self._bind_param(obj)
- return _BinaryExpression(self._compare_self(), obj, operator, type=self._compare_type(obj))
-
- def _compare_self(self):
- """Allow ``ColumnImpl`` to return its ``Column`` object for
- usage in ``ClauseElements``, all others to just return self.
- """
+
+ def compare_self(self):
+ """Allow ``_CompareMixins`` to return the appropriate object to be used in expressions."""
return self
_label = property(lambda s: s.name)
orig_set = property(lambda s:s.obj.orig_set)
- def _compare_self(self):
+ def compare_self(self):
return self.obj
def _copy_internals(self):
from sqlalchemy import *
+from sqlalchemy import ansisql
from sqlalchemy.orm import *
import testbase
from testbase import Table, Column
from fixtures import *
+import operator
class Base(object):
def __init__(self, **kwargs):
mapper(LocalFoo, table)
assert create_session().query(LocalFoo).get(ustring) == LocalFoo(id=ustring, data=ustring)
+class OperatorTest(QueryTest):
+ """test sql.Comparator implementation for MapperProperties"""
+
+ def _test(self, clause, expected):
+ c = str(clause.compile(dialect=ansisql.ANSIDialect()))
+ assert c == expected, "%s != %s" % (c, expected)
+
+ def test_arithmetic(self):
+ create_session().query(User)
+ for (py_op, sql_op) in ((operator.add, '+'), (operator.mul, '*'),
+ (operator.sub, '-'), (operator.div, '/'),
+ ):
+ for (lhs, rhs, res) in (
+ ('a', User.id, ':users_id %s users.id'),
+ ('a', literal('b'), ':literal %s :literal_1'),
+ (User.id, 'b', 'users.id %s :users_id'),
+ (User.id, literal('b'), 'users.id %s :literal'),
+ (User.id, User.id, 'users.id %s users.id'),
+ (literal('a'), 'b', ':literal %s :literal_1'),
+ (literal('a'), User.id, ':literal %s users.id'),
+ (literal('a'), literal('b'), ':literal %s :literal_1'),
+ ):
+ self._test(py_op(lhs, rhs), res % sql_op)
+
+ def test_comparison(self):
+ create_session().query(User)
+ for (py_op, fwd_op, rev_op) in ((operator.lt, '<', '>'),
+ (operator.gt, '>', '<'),
+ (operator.eq, '=', '='),
+ (operator.ne, '!=', '!='),
+ (operator.le, '<=', '>='),
+ (operator.ge, '>=', '<=')):
+ for (lhs, rhs, l_sql, r_sql) in (
+ ('a', User.id, ':users_id', 'users.id'),
+ ('a', literal('b'), ':literal_1', ':literal'), # note swap!
+ (User.id, 'b', 'users.id', ':users_id'),
+ (User.id, literal('b'), 'users.id', ':literal'),
+ (User.id, User.id, 'users.id', 'users.id'),
+ (literal('a'), 'b', ':literal', ':literal_1'),
+ (literal('a'), User.id, ':literal', 'users.id'),
+ (literal('a'), literal('b'), ':literal', ':literal_1'),
+ ):
+
+ # the compiled clause should match either (e.g.):
+ # 'a' < 'b' -or- 'b' > 'a'.
+ compiled = str(py_op(lhs, rhs).compile(dialect=ansisql.ANSIDialect()))
+ fwd_sql = "%s %s %s" % (l_sql, fwd_op, r_sql)
+ rev_sql = "%s %s %s" % (r_sql, rev_op, l_sql)
+
+ self.assert_(compiled == fwd_sql or compiled == rev_sql,
+ "\n'" + compiled + "'\n does not match\n'" +
+ fwd_sql + "'\n or\n'" + rev_sql + "'")
+
class CompileTest(QueryTest):
def test_deferred(self):
session = create_session()
- s = session.query(User).filter(and_(addresses.c.email_address == bindparam('emailad'), addresses.c.user_id==users.c.id)).compile()
+ s = session.query(User).filter(and_(addresses.c.email_address == bindparam('emailad'), Address.user_id==User.id)).compile()
l = session.query(User).instances(s.execute(emailad = 'jack@bean.com'))
assert [User(id=7)] == l
def test_first(self):
assert User(id=7) == create_session().query(User).first()
- assert create_session().query(User).filter(users.c.id==27).first() is None
+ assert create_session().query(User).filter(User.id==27).first() is None
# more slice tests are available in test/orm/generative.py
assert [User(id=9)] == create_session().query(User).filter("name='fred'").filter("id=9").all()
- assert [User(id=9)] == create_session().query(User).filter("name='fred'").filter(users.c.id==9).all()
+ assert [User(id=9)] == create_session().query(User).filter("name='fred'").filter(User.id==9).all()
def test_binds(self):
assert [User(id=8), User(id=9)] == create_session().query(User).filter("id in (:id1, :id2)").params(id1=8, id2=9).all()
assert User(id=8) == create_session().query(User)[1]
def test_onefilter(self):
- assert [User(id=8), User(id=9)] == create_session().query(User).filter(users.c.name.endswith('ed')).all()
+ assert [User(id=8), User(id=9)] == create_session().query(User).filter(User.name.endswith('ed')).all()
- def test_typecheck(self):
- try:
- create_session().query(User).filter(User.name==5)
- assert False
- except exceptions.ArgumentError, e:
- assert str(e) == "filter() argument must be of type sqlalchemy.sql.ClauseElement or string"
class CountTest(QueryTest):
def test_basic(self):
assert [User(id=9)] == create_session().query(User).filter("name='fred'").filter("id=9").all()
- assert [User(id=9)] == create_session().query(User).filter("name='fred'").filter(users.c.id==9).all()
+ assert [User(id=9)] == create_session().query(User).filter("name='fred'").filter(User.id==9).all()
def test_binds(self):
assert [User(id=8), User(id=9)] == create_session().query(User).filter("id in (:id1, :id2)").params(id1=8, id2=9).all()
)
def testoperators(self):
- self.runtest(
- table1.select((table1.c.myid != 12) & ~(table1.c.name=='john')),
- "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid != :mytable_myid AND NOT mytable.name = :mytable_name"
- )
-
- self.runtest(
- table1.select((table1.c.myid != 12) & ~table1.c.name),
- "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid != :mytable_myid AND NOT mytable.name"
- )
-
- self.runtest(
- literal("a") + literal("b") * literal("c"), ":literal + :literal_1 * :literal_2"
- )
# exercise arithmetic operators
for (py_op, sql_op) in ((operator.add, '+'), (operator.mul, '*'),
"\n'" + compiled + "'\n does not match\n'" +
fwd_sql + "'\n or\n'" + rev_sql + "'")
+ self.runtest(
+ table1.select((table1.c.myid != 12) & ~(table1.c.name=='john')),
+ "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid != :mytable_myid AND mytable.name != :mytable_name"
+ )
+
+ self.runtest(
+ table1.select((table1.c.myid != 12) & ~and_(table1.c.name=='john', table1.c.name=='ed', table1.c.name=='fred')),
+ "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid != :mytable_myid AND NOT (mytable.name = :mytable_name AND mytable.name = :mytable_name_1 AND mytable.name = :mytable_name_2)"
+ )
+
+ self.runtest(
+ table1.select((table1.c.myid != 12) & ~table1.c.name),
+ "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid != :mytable_myid AND NOT mytable.name"
+ )
+
+ self.runtest(
+ literal("a") + literal("b") * literal("c"), ":literal + :literal_1 * :literal_2"
+ )
+
# test the op() function, also that its results are further usable in expressions
self.runtest(
table1.select(table1.c.myid.op('hoho')(12)==14),
"SELECT op.field FROM op WHERE :literal + (op.field IN (:op_field, :op_field_1))")
self.runtest(table.select((5 + table.c.field).in_(5,6)),
"SELECT op.field FROM op WHERE :op_field + op.field IN (:literal, :literal_1)")
- self.runtest(table.select(not_(table.c.field == 5)),
- "SELECT op.field FROM op WHERE NOT op.field = :op_field")
+ self.runtest(table.select(not_(and_(table.c.field == 5, table.c.field == 7))),
+ "SELECT op.field FROM op WHERE NOT (op.field = :op_field AND op.field = :op_field_1)")
self.runtest(table.select(not_(table.c.field) == 5),
"SELECT op.field FROM op WHERE (NOT op.field) = :literal")
self.runtest(table.select((table.c.field == table.c.field).between(False, True)),