]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- modified SQL operator functions to be module-level operators, allowing
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 18 Aug 2007 01:00:44 +0000 (01:00 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 18 Aug 2007 01:00:44 +0000 (01:00 +0000)
  SQL expressions to be pickleable [ticket:735]

- small adjustment to mapper class.__init__ to allow for Py2.6 object.__init__()
  behavior

CHANGES
examples/sharding/attribute_shard.py
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/operators.py [new file with mode: 0644]
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/sql.py
test/orm/sharding/shard.py
test/sql/select.py

diff --git a/CHANGES b/CHANGES
index d88a158e54c8ed9281b9f514dea6c2988e20c4a1..b52de26136af8f18f76ffb9d337bc9fd8cfc572d 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -24,6 +24,12 @@ CHANGES
   connection; a close() will ensure that connection transactional state is 
   the same as that which existed on it before being bound to the Session.
 
+- modified SQL operator functions to be module-level operators, allowing
+  SQL expressions to be pickleable [ticket:735]
+
+- small adjustment to mapper class.__init__ to allow for Py2.6 object.__init__()
+  behavior
+    
 0.4.0beta3
 ----------
 
index df3f7467f736c30367c84b32bd5b9baa69abffc9..25da98872927c83db8e5bb2a6dd0a0765ed16575 100644 (file)
@@ -21,8 +21,8 @@ To set up a sharding system, you need:
 from sqlalchemy import *
 from sqlalchemy.orm import *
 from sqlalchemy.orm.shard import ShardedSession
-from sqlalchemy.sql import ColumnOperators
-import datetime, operator
+from sqlalchemy.sql import operators
+import datetime
 
 # step 2. databases
 echo = True
@@ -133,9 +133,9 @@ def query_chooser(query):
     class FindContinent(sql.ClauseVisitor):
         def visit_binary(self, binary):
             if binary.left is weather_locations.c.continent:
-                if binary.operator == operator.eq:
+                if binary.operator == operators.eq:
                     ids.append(shard_lookup[binary.right.value])
-                elif binary.operator == ColumnOperators.in_op:
+                elif binary.operator == operators.in_op:
                     for bind in binary.right.clauses:
                         ids.append(shard_lookup[bind.value])
                     
index dd4065f3916b6628a551fd29a0739916df8766a8..5f5e1c1713ec395a50406463888af87cc8e84af5 100644 (file)
@@ -12,7 +12,7 @@ module.
 
 import string, re, sets, operator
 
-from sqlalchemy import schema, sql, engine, util, exceptions
+from sqlalchemy import schema, sql, engine, util, exceptions, operators
 from  sqlalchemy.engine import default
 
 
@@ -50,39 +50,39 @@ BIND_PARAMS_ESC = re.compile(r'\x5c(:[\w\$]+)(?![:\w\$])', re.UNICODE)
 ANONYMOUS_LABEL = re.compile(r'{ANON (-?\d+) (.*)}')
 
 OPERATORS =  {
-    operator.and_ : 'AND',
-    operator.or_ : 'OR',
-    operator.inv : 'NOT',
-    operator.add : '+',
-    operator.mul : '*',
-    operator.sub : '-',
-    operator.div : '/',
-    operator.mod : '%',
-    operator.truediv : '/',
-    operator.lt : '<',
-    operator.le : '<=',
-    operator.ne : '!=',
-    operator.gt : '>',
-    operator.ge : '>=',
-    operator.eq : '=',
-    sql.ColumnOperators.distinct_op : 'DISTINCT',
-    sql.ColumnOperators.concat_op : '||',
-    sql.ColumnOperators.like_op : 'LIKE',
-    sql.ColumnOperators.notlike_op : 'NOT LIKE',
-    sql.ColumnOperators.ilike_op : 'ILIKE',
-    sql.ColumnOperators.notilike_op : 'NOT ILIKE',
-    sql.ColumnOperators.between_op : 'BETWEEN',
-    sql.ColumnOperators.in_op : 'IN',
-    sql.ColumnOperators.notin_op : 'NOT IN',
-    sql.ColumnOperators.comma_op : ', ',
-    sql.ColumnOperators.desc_op : 'DESC',
-    sql.ColumnOperators.asc_op : 'ASC',
+    operators.and_ : 'AND',
+    operators.or_ : 'OR',
+    operators.inv : 'NOT',
+    operators.add : '+',
+    operators.mul : '*',
+    operators.sub : '-',
+    operators.div : '/',
+    operators.mod : '%',
+    operators.truediv : '/',
+    operators.lt : '<',
+    operators.le : '<=',
+    operators.ne : '!=',
+    operators.gt : '>',
+    operators.ge : '>=',
+    operators.eq : '=',
+    operators.distinct_op : 'DISTINCT',
+    operators.concat_op : '||',
+    operators.like_op : 'LIKE',
+    operators.notlike_op : 'NOT LIKE',
+    operators.ilike_op : 'ILIKE',
+    operators.notilike_op : 'NOT ILIKE',
+    operators.between_op : 'BETWEEN',
+    operators.in_op : 'IN',
+    operators.notin_op : 'NOT IN',
+    operators.comma_op : ', ',
+    operators.desc_op : 'DESC',
+    operators.asc_op : 'ASC',
     
-    sql.Operators.from_ : 'FROM',
-    sql.Operators.as_ : 'AS',
-    sql.Operators.exists : 'EXISTS',
-    sql.Operators.is_ : 'IS',
-    sql.Operators.isnot : 'IS NOT'
+    operators.from_ : 'FROM',
+    operators.as_ : 'AS',
+    operators.exists : 'EXISTS',
+    operators.is_ : 'IS',
+    operators.isnot : 'IS NOT'
 }
 
 class ANSIDialect(default.DefaultDialect):
@@ -284,7 +284,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
             if isinstance(label.obj, sql._ColumnClause):
                 self.column_labels[label.obj._label] = labelname
             self.column_labels[label.name] = labelname
-        return " ".join([self.process(label.obj), self.operator_string(sql.ColumnOperators.as_), self.preparer.format_label(label, labelname)])
+        return " ".join([self.process(label.obj), self.operator_string(operators.as_), self.preparer.format_label(label, labelname)])
         
     def visit_column(self, column, **kwargs):
         # there is actually somewhat of a ruleset when you would *not* necessarily
@@ -343,7 +343,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
         sep = clauselist.operator
         if sep is None:
             sep = " "
-        elif sep == sql.ColumnOperators.comma_op:
+        elif sep == operators.comma_op:
             sep = ', '
         else:
             sep = " " + self.operator_string(clauselist.operator) + " "
index 4cdc7962a458e8319ac49de8bc7b5322e84be910..303a445525e691efa8770ed431754f70e84e31b0 100644 (file)
@@ -123,10 +123,12 @@ Notes page on the wiki at http://sqlalchemy.org is a good resource for timely
 information affecting MySQL in SQLAlchemy.
 """
 
-import re, datetime, inspect, warnings, operator, sys
+import re, datetime, inspect, warnings, sys
 from array import array as _array
 
 from sqlalchemy import ansisql, exceptions, logging, schema, sql, util
+from sqlalchemy import operators as sql_operators
+
 from sqlalchemy.engine import base as engine_base, default
 import sqlalchemy.types as sqltypes
 
@@ -1735,9 +1737,9 @@ class MySQLCompiler(ansisql.ANSICompiler):
     operators = ansisql.ANSICompiler.operators.copy()
     operators.update(
         {
-            sql.ColumnOperators.concat_op: \
+            sql_operators.concat_op: \
               lambda x, y: "concat(%s, %s)" % (x, y),
-            operator.mod: '%%'
+            sql_operators.mod: '%%'
         }
     )
 
diff --git a/lib/sqlalchemy/operators.py b/lib/sqlalchemy/operators.py
new file mode 100644 (file)
index 0000000..b8aca3d
--- /dev/null
@@ -0,0 +1,61 @@
+"""define opeators used in SQL expressions"""
+
+from operator import and_, or_, inv, add, mul, sub, div, mod, truediv, lt, le, ne, gt, ge, eq
+
+def from_():
+    raise NotImplementedError()
+
+def as_():
+    raise NotImplementedError()
+
+def exists():
+    raise NotImplementedError()
+
+def is_():
+    raise NotImplementedError()
+
+def isnot():
+    raise NotImplementedError()
+
+def like_op(a, b):
+    return a.like(b)
+
+def notlike_op(a, b):
+    raise NotImplementedError()
+
+def ilike_op(a, b):
+    return a.ilike(b)
+
+def notilike_op(a, b):
+    raise NotImplementedError()
+
+def between_op(a, b):
+    return a.between(b)
+
+def in_op(a, b):
+    return a.in_(*b)
+
+def notin_op(a, b):
+    raise NotImplementedError()
+
+def distinct_op(a):
+    return a.distinct()
+
+def startswith_op(a, b):
+    return a.startswith(b)
+
+def endswith_op(a, b):
+    return a.endswith(b)
+
+def comma_op(a, b):
+    raise NotImplementedError()
+
+def concat_op(a, b):
+    return a.concat(b)
+
+def desc_op(a):
+    return a.desc()
+
+def asc_op(a):
+    return a.asc()
+
index 30a9525f17ecde25bdc7f510076e0bee5298f579..b4836841774503f790e74ff9ff6c11da5fe9abc0 100644 (file)
@@ -670,14 +670,13 @@ class Mapper(object):
         attribute_manager.reset_class_managed(self.class_)
 
         oldinit = self.class_.__init__
-        if oldinit is object.__init__:
-            oldinit = None
+        doinit = oldinit is not None and oldinit is not object.__init__
             
         def init(instance, *args, **kwargs):
             self.compile()
             self.extension.init_instance(self, self.class_, oldinit, instance, args, kwargs)
 
-            if oldinit is not None:
+            if doinit:
                 try:
                     oldinit(instance, *args, **kwargs)
                 except:
index bad60bb3c014f58525169eebd4edf3310ac86372..2f1cd4b91a89206c3b572b0fc802cd4d04090b61 100644 (file)
@@ -24,9 +24,9 @@ are less guaranteed to stay the same in future releases.
 
 """
 
-from sqlalchemy import util, exceptions
+from sqlalchemy import util, exceptions, operators
 from sqlalchemy import types as sqltypes
-import re, operator
+import re
 
 __all__ = ['Alias', 'ClauseElement', 'ClauseParameters',
            'ClauseVisitor', 'ColumnCollection', 'ColumnElement',
@@ -47,7 +47,7 @@ def desc(column):
 
       order_by = [desc(table1.mycol)]
     """
-    return _UnaryExpression(column, modifier=ColumnOperators.desc_op)
+    return _UnaryExpression(column, modifier=operators.desc_op)
 
 def asc(column):
     """Return an ascending ``ORDER BY`` clause element.
@@ -56,7 +56,7 @@ def asc(column):
 
       order_by = [asc(table1.mycol)]
     """
-    return _UnaryExpression(column, modifier=ColumnOperators.asc_op)
+    return _UnaryExpression(column, modifier=operators.asc_op)
 
 def outerjoin(left, right, onclause=None, **kwargs):
     """Return an ``OUTER JOIN`` clause element.
@@ -337,7 +337,7 @@ def and_(*clauses):
     """
     if len(clauses) == 1:
         return clauses[0]
-    return ClauseList(operator=operator.and_, *clauses)
+    return ClauseList(operator=operators.and_, *clauses)
 
 def or_(*clauses):
     """Join a list of clauses together using the ``OR`` operator.
@@ -348,7 +348,7 @@ def or_(*clauses):
 
     if len(clauses) == 1:
         return clauses[0]
-    return ClauseList(operator=operator.or_, *clauses)
+    return ClauseList(operator=operators.or_, *clauses)
 
 def not_(clause):
     """Return a negation of the given clause, i.e. ``NOT(clause)``.
@@ -357,12 +357,12 @@ def not_(clause):
     subclasses to produce the same result.
     """
 
-    return operator.inv(clause)
+    return operators.inv(clause)
 
 def distinct(expr):
     """return a ``DISTINCT`` clause."""
     
-    return _UnaryExpression(expr, operator=ColumnOperators.distinct_op)
+    return _UnaryExpression(expr, operator=operators.distinct_op)
 
 def between(ctest, cleft, cright):
     """Return a ``BETWEEN`` predicate clause.
@@ -374,7 +374,7 @@ def between(ctest, cleft, cright):
     """
 
     ctest = _literal_as_binds(ctest)
-    return _BinaryExpression(ctest, ClauseList(_literal_as_binds(cleft, type_=ctest.type), _literal_as_binds(cright, type_=ctest.type), operator=operator.and_, group=False), ColumnOperators.between_op)
+    return _BinaryExpression(ctest, ClauseList(_literal_as_binds(cleft, type_=ctest.type), _literal_as_binds(cright, type_=ctest.type), operator=operators.and_, group=False), operators.between_op)
 
 
 def case(whens, value=None, else_=None):
@@ -420,7 +420,7 @@ def cast(clause, totype, **kwargs):
 def extract(field, expr):
     """Return the clause ``extract(field FROM expr)``."""
 
-    expr = _BinaryExpression(text(field), expr, Operators.from_)
+    expr = _BinaryExpression(text(field), expr, operators.from_)
     return func.extract(expr)
 
 def exists(*args, **kwargs):
@@ -1195,38 +1195,18 @@ class ClauseElement(object):
         if hasattr(self, 'negation_clause'):
             return self.negation_clause
         else:
-            return _UnaryExpression(self.self_group(against=operator.inv), operator=operator.inv, negate=None)
+            return _UnaryExpression(self.self_group(against=operators.inv), operator=operators.inv, negate=None)
 
 
 class Operators(object):
-    def from_():
-        raise NotImplementedError()
-    from_ = staticmethod(from_)
-    
-    def as_():
-        raise NotImplementedError()
-    as_ = staticmethod(as_)
-    
-    def exists():
-        raise NotImplementedError()
-    exists = staticmethod(exists)
-
-    def is_():
-        raise NotImplementedError()
-    is_ = staticmethod(is_)
-    
-    def isnot():
-        raise NotImplementedError()
-    isnot = staticmethod(isnot)
-    
     def __and__(self, other):
-        return self.operate(operator.and_, other)
+        return self.operate(operators.and_, other)
 
     def __or__(self, other):
-        return self.operate(operator.or_, other)
+        return self.operate(operators.or_, other)
 
     def __invert__(self):
-        return self.operate(operator.inv)
+        return self.operate(operators.inv)
 
     def clause_element(self):
         raise NotImplementedError()
@@ -1239,137 +1219,80 @@ class Operators(object):
 
 class ColumnOperators(Operators):
     """defines comparison and math operations"""
-
-    def like_op(a, b):
-        return a.like(b)
-    like_op = staticmethod(like_op)
-    
-    def notlike_op(a, b):
-        raise NotImplementedError()
-    notlike_op = staticmethod(notlike_op)
-
-    def ilike_op(a, b):
-        return a.ilike(b)
-    ilike_op = staticmethod(ilike_op)
-    
-    def notilike_op(a, b):
-        raise NotImplementedError()
-    notilike_op = staticmethod(notilike_op)
-    
-    def between_op(a, b):
-        return a.between(b)
-    between_op = staticmethod(between_op)
-    
-    def in_op(a, b):
-        return a.in_(*b)
-    in_op = staticmethod(in_op)
-
-    def notin_op(a, b):
-        raise NotImplementedError()
-    notin_op = staticmethod(notin_op)
-    
-    def distinct_op(a):
-        return a.distinct()
-    distinct_op = staticmethod(distinct_op)
-    
-    def startswith_op(a, b):
-        return a.startswith(b)
-    startswith_op = staticmethod(startswith_op)
-    
-    def endswith_op(a, b):
-        return a.endswith(b)
-    endswith_op = staticmethod(endswith_op)
-
-    def comma_op(a, b):
-        raise NotImplementedError()
-    comma_op = staticmethod(comma_op)
-
-    def concat_op(a, b):
-        return a.concat(b)
-    concat_op = staticmethod(concat_op)
-    
-    def desc_op(a):
-        return a.desc()
-    desc_op = staticmethod(desc_op)
-
-    def asc_op(a):
-        return a.asc()
-    asc_op = staticmethod(asc_op)
-    
     def __lt__(self, other):
-        return self.operate(operator.lt, other)
+        return self.operate(operators.lt, other)
 
     def __le__(self, other):
-        return self.operate(operator.le, other)
+        return self.operate(operators.le, other)
 
     def __eq__(self, other):
-        return self.operate(operator.eq, other)
+        return self.operate(operators.eq, other)
 
     def __ne__(self, other):
-        return self.operate(operator.ne, other)
+        return self.operate(operators.ne, other)
 
     def __gt__(self, other):
-        return self.operate(operator.gt, other)
+        return self.operate(operators.gt, other)
 
     def __ge__(self, other):
-        return self.operate(operator.ge, other)
+        return self.operate(operators.ge, other)
 
     def concat(self, other):
-        return self.operate(ColumnOperators.concat_op, other)
+        return self.operate(operators.concat_op, other)
         
     def like(self, other):
-        return self.operate(ColumnOperators.like_op, other)
+        return self.operate(operators.like_op, other)
     
     def in_(self, *other):
-        return self.operate(ColumnOperators.in_op, other)
+        return self.operate(operators.in_op, other)
     
     def startswith(self, other):
-        return self.operate(ColumnOperators.startswith_op, other)
+        return self.operate(operators.startswith_op, other)
 
     def endswith(self, other):
-        return self.operate(ColumnOperators.endswith_op, other)
+        return self.operate(operators.endswith_op, other)
     
     def desc(self):
-        return self.operate(ColumnOperators.desc_op)
+        return self.operate(operators.desc_op)
         
     def asc(self):
-        return self.operate(ColumnOperators.asc_op)
+        return self.operate(operators.asc_op)
         
     def __radd__(self, other):
-        return self.reverse_operate(operator.add, other)
+        return self.reverse_operate(operators.add, other)
 
     def __rsub__(self, other):
-        return self.reverse_operate(operator.sub, other)
+        return self.reverse_operate(operators.sub, other)
 
     def __rmul__(self, other):
-        return self.reverse_operate(operator.mul, other)
+        return self.reverse_operate(operators.mul, other)
 
     def __rdiv__(self, other):
-        return self.reverse_operate(operator.div, other)
+        return self.reverse_operate(operators.div, other)
 
     def between(self, cleft, cright):
-        return self.operate(ColumnOperators.between_op, cleft, cright)
+        return self.operate(operators.between_op, cleft, cright)
 
     def distinct(self):
-        return self.operate(ColumnOperators.distinct_op)
+        return self.operate(operators.distinct_op)
         
     def __add__(self, other):
-        return self.operate(operator.add, other)
+        return self.operate(operators.add, other)
 
     def __sub__(self, other):
-        return self.operate(operator.sub, other)
+        return self.operate(operators.sub, other)
 
     def __mul__(self, other):
-        return self.operate(operator.mul, other)
+        return self.operate(operators.mul, other)
 
     def __div__(self, other):
-        return self.operate(operator.div, other)
+        return self.operate(operators.div, other)
 
     def __mod__(self, other):
-        return self.operate(operator.mod, other)
+        return self.operate(operators.mod, other)
 
     def __truediv__(self, other):
-        return self.operate(operator.truediv, other)
+        return self.operate(operators.truediv, other)
 
 # precedence ordering for common operators.  if an operator is not present in this list,
 # it will be parenthesized when grouped against other operators
@@ -1377,35 +1300,35 @@ _smallest = object()
 _largest = object()
 
 PRECEDENCE = {
-    Operators.from_:15,
-    operator.mul:7,
-    operator.div:7,
-    operator.mod:7,
-    operator.add:6,
-    operator.sub:6,
-    ColumnOperators.concat_op:6,
-    ColumnOperators.ilike_op:5,
-    ColumnOperators.notilike_op:5,
-    ColumnOperators.like_op:5,
-    ColumnOperators.notlike_op:5,
-    ColumnOperators.in_op:5,
-    ColumnOperators.notin_op:5,
-    Operators.is_:5,
-    Operators.isnot:5,
-    operator.eq:5,
-    operator.ne:5,
-    operator.gt:5,
-    operator.lt:5,
-    operator.ge:5,
-    operator.le:5,
-    ColumnOperators.between_op:5,
-    ColumnOperators.distinct_op:5,
-    operator.inv:4,
-    operator.and_:3,
-    operator.or_:2,
-    ColumnOperators.comma_op:-1,
-    Operators.as_:-1,
-    Operators.exists:0,
+    operators.from_:15,
+    operators.mul:7,
+    operators.div:7,
+    operators.mod:7,
+    operators.add:6,
+    operators.sub:6,
+    operators.concat_op:6,
+    operators.ilike_op:5,
+    operators.notilike_op:5,
+    operators.like_op:5,
+    operators.notlike_op:5,
+    operators.in_op:5,
+    operators.notin_op:5,
+    operators.is_:5,
+    operators.isnot:5,
+    operators.eq:5,
+    operators.ne:5,
+    operators.gt:5,
+    operators.lt:5,
+    operators.ge:5,
+    operators.le:5,
+    operators.between_op:5,
+    operators.distinct_op:5,
+    operators.inv:4,
+    operators.and_:3,
+    operators.or_:2,
+    operators.comma_op:-1,
+    operators.as_:-1,
+    operators.exists:0,
     _smallest: -1000,
     _largest: 1000
 }
@@ -1415,10 +1338,10 @@ class _CompareMixin(ColumnOperators):
 
     def __compare(self, op, obj, negate=None):
         if obj is None or isinstance(obj, _Null):
-            if op == operator.eq:
-                return _BinaryExpression(self.expression_element(), null(), Operators.is_, negate=Operators.isnot)
-            elif op == operator.ne:
-                return _BinaryExpression(self.expression_element(), null(), Operators.isnot, negate=Operators.is_)
+            if op == operators.eq:
+                return _BinaryExpression(self.expression_element(), null(), operators.is_, negate=operators.isnot)
+            elif op == operators.ne:
+                return _BinaryExpression(self.expression_element(), null(), operators.isnot, negate=operators.is_)
             else:
                 raise exceptions.ArgumentError("Only '='/'!=' operators can be used with NULL")
         else:
@@ -1433,25 +1356,25 @@ class _CompareMixin(ColumnOperators):
         type_ = self._compare_type(obj)
         
         # TODO: generalize operator overloading like this out into the types module
-        if op == operator.add and isinstance(type_, (sqltypes.Concatenable)):
-            op = ColumnOperators.concat_op
+        if op == operators.add and isinstance(type_, (sqltypes.Concatenable)):
+            op = operators.concat_op
         
         return _BinaryExpression(self.expression_element(), obj, op, type_=type_)
 
     operators = {
-        operator.add : (__operate,),
-        operator.mul : (__operate,),
-        operator.sub : (__operate,),
-        operator.div : (__operate,),
-        operator.mod : (__operate,),
-        operator.truediv : (__operate,),
-        operator.lt : (__compare, operator.ge),
-        operator.le : (__compare, operator.gt),
-        operator.ne : (__compare, operator.eq),
-        operator.gt : (__compare, operator.le),
-        operator.ge : (__compare, operator.lt),
-        operator.eq : (__compare, operator.ne),
-        ColumnOperators.like_op : (__compare, ColumnOperators.notlike_op),
+        operators.add : (__operate,),
+        operators.mul : (__operate,),
+        operators.sub : (__operate,),
+        operators.div : (__operate,),
+        operators.mod : (__operate,),
+        operators.truediv : (__operate,),
+        operators.lt : (__compare, operators.ge),
+        operators.le : (__compare, operators.gt),
+        operators.ne : (__compare, operators.eq),
+        operators.gt : (__compare, operators.le),
+        operators.ge : (__compare, operators.lt),
+        operators.eq : (__compare, operators.ne),
+        operators.like_op : (__compare, operators.notlike_op),
     }
 
     def operate(self, op, *other):
@@ -1462,7 +1385,7 @@ class _CompareMixin(ColumnOperators):
         return self._bind_param(other).operate(op, self)
 
     def in_(self, *other):
-        return self._in_impl(ColumnOperators.in_op, ColumnOperators.notin_op, *other)
+        return self._in_impl(operators.in_op, operators.notin_op, *other)
         
     def _in_impl(self, op, negate_op, *other):
         if len(other) == 0:
@@ -1489,7 +1412,7 @@ class _CompareMixin(ColumnOperators):
         """produce the clause ``LIKE '<other>%'``"""
 
         perc = isinstance(other,(str,unicode)) and '%' or literal('%',type_= sqltypes.String)
-        return self.__compare(ColumnOperators.like_op, other + perc)
+        return self.__compare(operators.like_op, other + perc)
 
     def endswith(self, other):
         """produce the clause ``LIKE '%<other>'``"""
@@ -1498,7 +1421,7 @@ class _CompareMixin(ColumnOperators):
         else:
             po = literal('%', type_=sqltypes.String) + other
             po.type = sqltypes.to_instance(sqltypes.String)     #force!
-        return self.__compare(ColumnOperators.like_op, po)
+        return self.__compare(operators.like_op, po)
 
     def label(self, name):
         """produce a column label, i.e. ``<columnname> AS <name>``"""
@@ -1516,12 +1439,12 @@ class _CompareMixin(ColumnOperators):
         
     def distinct(self):
         """produce a DISTINCT clause, i.e. ``DISTINCT <columnname>``"""
-        return _UnaryExpression(self, operator=ColumnOperators.distinct_op)
+        return _UnaryExpression(self, operator=operators.distinct_op)
 
     def between(self, cleft, cright):
         """produce a BETWEEN clause, i.e. ``<column> BETWEEN <cleft> AND <cright>``"""
 
-        return _BinaryExpression(self, ClauseList(self._check_literal(cleft), self._check_literal(cright), operator=operator.and_, group=False), ColumnOperators.between_op)
+        return _BinaryExpression(self, ClauseList(self._check_literal(cleft), self._check_literal(cright), operator=operators.and_, group=False), operators.between_op)
 
     def op(self, operator):
         """produce a generic operator function.
@@ -2132,7 +2055,7 @@ class ClauseList(ClauseElement):
     
     def __init__(self, *clauses, **kwargs):
         self.clauses = []
-        self.operator = kwargs.pop('operator', ColumnOperators.comma_op)
+        self.operator = kwargs.pop('operator', operators.comma_op)
         self.group = kwargs.pop('group', True)
         self.group_contents = kwargs.pop('group_contents', True)
         for c in clauses:
@@ -2248,7 +2171,7 @@ class _Function(_CalculatedClause, FromClause):
 
     def __init__(self, name, *clauses, **kwargs):
         self.packagenames = kwargs.get('packagenames', None) or []
-        kwargs['operator'] = ColumnOperators.comma_op
+        kwargs['operator'] = operators.comma_op
         _CalculatedClause.__init__(self, name, **kwargs)
         for c in clauses:
             self.append(c)
@@ -2365,7 +2288,7 @@ class _BinaryExpression(ColumnElement):
                 (
                     self.left.compare(other.left) and self.right.compare(other.right)
                     or (
-                        self.operator in [operator.eq, operator.ne, operator.add, operator.mul] and
+                        self.operator in [operators.eq, operators.ne, operators.add, operators.mul] and
                         self.left.compare(other.right) and self.right.compare(other.left)
                     )
                 )
@@ -2390,7 +2313,7 @@ class _Exists(_UnaryExpression):
     def __init__(self, *args, **kwargs):
         kwargs['correlate'] = True
         s = select(*args, **kwargs).as_scalar().self_group()
-        _UnaryExpression.__init__(self, s, operator=Operators.exists)
+        _UnaryExpression.__init__(self, s, operator=operators.exists)
 
     def select(self, whereclauses = None, **params):
         return select([self], whereclauses, **params)
@@ -2444,7 +2367,7 @@ class Join(FromClause):
                     
         class BinaryVisitor(ClauseVisitor):
             def visit_binary(self, binary):
-                if binary.operator == operator.eq:
+                if binary.operator == operators.eq:
                     add_equiv(binary.left, binary.right)
         BinaryVisitor().traverse(self.onclause)
         
@@ -2526,7 +2449,7 @@ class Join(FromClause):
             equivs = util.Set()
         class LocateEquivs(NoColumnVisitor):
             def visit_binary(self, binary):
-                if binary.operator == operator.eq and binary.left.name == binary.right.name:
+                if binary.operator == operators.eq and binary.left.name == binary.right.name:
                     equivs.add(binary.right)
                     equivs.add(binary.left)
         LocateEquivs().traverse(self.onclause)
@@ -2739,7 +2662,7 @@ class _Label(ColumnElement):
             obj = obj.obj
         self.name = name or "{ANON %d %s}" % (id(self), getattr(obj, 'name', 'anon'))
 
-        self.obj = obj.self_group(against=Operators.as_)
+        self.obj = obj.self_group(against=operators.as_)
         self.type = sqltypes.to_instance(type_ or getattr(obj, 'type', None))
 
     key = property(lambda s: s.name)
@@ -3314,7 +3237,7 @@ class Select(_SelectBaseMixin, FromClause):
         column = _literal_as_column(column)
 
         if isinstance(column, _ScalarSelect):
-            column = column.self_group(against=ColumnOperators.comma_op)
+            column = column.self_group(against=operators.comma_op)
 
         self._raw_columns.append(column)
     
index e80f8ab057e90d86ceb88b856776d05671b304d0..d4b4c249a7f3dee8fed3425a9cb0ad5e7c91ece0 100644 (file)
@@ -3,8 +3,8 @@ from sqlalchemy import *
 from sqlalchemy.orm import *
 
 from sqlalchemy.orm.shard import ShardedSession
-from sqlalchemy.sql import ColumnOperators
-import datetime, operator, os
+from sqlalchemy.sql import operators
+import datetime, os
 from testlib import PersistTest
 
 # TODO: ShardTest can be turned into a base for further subclasses
@@ -81,9 +81,9 @@ class ShardTest(PersistTest):
             class FindContinent(sql.ClauseVisitor):
                 def visit_binary(self, binary):
                     if binary.left is weather_locations.c.continent:
-                        if binary.operator == operator.eq:
+                        if binary.operator == operators.eq:
                             ids.append(shard_lookup[binary.right.value])
-                        elif binary.operator == ColumnOperators.in_op:
+                        elif binary.operator == operators.in_op:
                             for bind in binary.right.clauses:
                                 ids.append(shard_lookup[bind.value])
 
index 865f1ec48b5f3a0a48e1bc692fbc7bfec1b0fea5..c075d4e3b462fd20bdc3c85c7d3fa9a1b151d125 100644 (file)
@@ -1,6 +1,7 @@
 import testbase
 import re, operator
 from sqlalchemy import *
+from sqlalchemy import util
 from sqlalchemy.databases import sqlite, postgres, mysql, oracle, firebird, mssql
 from testlib import *
 
@@ -406,6 +407,10 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
             table1.select(table1.c.myid.op('hoho')(12)==14),
             "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE (mytable.myid hoho :mytable_myid) = :literal"
         )
+        
+        # test that clauses can be pickled (operators need to be module-level, etc.)
+        clause = (table1.c.myid == 12) & table1.c.myid.between(15, 20) & table1.c.myid.like('hoho')
+        assert str(clause) == str(util.pickle.loads(util.pickle.dumps(clause)))
 
     def testunicodestartswith(self):
         string = u"hi \xf6 \xf5"