]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- all "type" keyword arguments, such as those to bindparam(), column(),
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 19 Jul 2007 07:11:55 +0000 (07:11 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 19 Jul 2007 07:11:55 +0000 (07:11 +0000)
  Column(), and func.<something>(), renamed to "type_".  those objects
  still name their "type" attribute as "type".
- new SQL operator implementation which removes all hardcoded operators
  from expression structures and moves them into compilation;
  allows greater flexibility of operator compilation; for example, "+"
  compiles to "||" when used in a string context, or "concat(a,b)" on
  MySQL; whereas in a numeric context it compiles to "+".  fixes [ticket:475].
- major cruft cleanup in ANSICompiler regarding its processing of update/insert
  bind parameters.  code is actually readable !
- a clause element embedded in an UPDATE, i.e. for a correlated update, uses
  standard "grouping" rules now to place parenthesis.  Doesn't change much, except
  if you embed a text() clause in there, it will not be automatically parenthesized
  (place parens in the text() manually).

24 files changed:
CHANGES
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/databases/sqlite.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/orm/dependency.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/sync.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql.py
lib/sqlalchemy/types.py
test/engine/bind.py
test/engine/parseconnect.py
test/orm/inheritance/polymorph2.py
test/orm/query.py
test/sql/case_statement.py
test/sql/defaults.py
test/sql/query.py
test/sql/select.py

diff --git a/CHANGES b/CHANGES
index 235e77644a24bb82e475edd19a5cf1fa75299b9e..8aa7f6dd6963437e6714174ab6ad8d8c9eee89d0 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -87,6 +87,9 @@
     style of Hibernate
     
 - sql
+  - all "type" keyword arguments, such as those to bindparam(), column(),
+    Column(), and func.<something>(), renamed to "type_".  those objects
+    still name their "type" attribute as "type".
   - transactions:
     - added context manager (with statement) support for transactions
     - added support for two phase commit, works with mysql and postgres so far.
   - MetaData:
     - DynamicMetaData has been renamed to ThreadLocalMetaData
     - BoundMetaData has been removed- regular MetaData is equivalent
+  - new SQL operator implementation which removes all hardcoded operators
+    from expression structures and moves them into compilation; 
+    allows greater flexibility of operator compilation; for example, "+" 
+    compiles to "||" when used in a string context, or "concat(a,b)" on 
+    MySQL; whereas in a numeric context it compiles to "+".  fixes [ticket:475].
   - "anonymous" alias and label names are now generated at SQL compilation
     time in a completely deterministic fashion...no more random hex IDs
   - significant architectural overhaul to SQL elements (ClauseElement).  
index 361fd7b1ea5e4d22d6f2a00dd3b3cfd27e63ed73..24ee13e47264c2b7501cd8fdd18cd59beefaa379 100644 (file)
@@ -12,7 +12,7 @@ module.
 
 from sqlalchemy import schema, sql, engine, util, sql_util, exceptions
 from  sqlalchemy.engine import default
-import string, re, sets, weakref, random
+import string, re, sets, random, operator
 
 ANSI_FUNCS = sets.ImmutableSet(['CURRENT_DATE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP',
                                 'CURRENT_USER', 'LOCALTIME', 'LOCALTIMESTAMP',
@@ -43,6 +43,38 @@ ILLEGAL_INITIAL_CHARACTERS = util.Set(string.digits + '$')
 BIND_PARAMS = re.compile(r'(?<![:\w\x5c]):(\w+)(?!:)', re.UNICODE)
 BIND_PARAMS_ESC = re.compile(r'\x5c(:\w+)(?!:)', re.UNICODE)
 
+OPERATORS =  {
+    operator.and_ : 'AND',
+    operator.or_ : 'OR',
+    operator.inv : 'NOT',
+    operator.add : '+',
+    operator.mul : '*',
+    operator.sub : '-',
+    operator.div : '/',
+    operator.mod : '%',
+    operator.truediv : '/',
+    operator.lt : '<',
+    operator.le : '<=',
+    operator.ne : '!=',
+    operator.gt : '>',
+    operator.ge : '>=',
+    operator.eq : '=',
+    sql.ColumnOperators.concat_op : '||',
+    sql.ColumnOperators.like_op : 'LIKE',
+    sql.ColumnOperators.notlike_op : 'NOT LIKE',
+    sql.ColumnOperators.ilike_op : 'ILIKE',
+    sql.ColumnOperators.notilike_op : 'NOT ILIKE',
+    sql.ColumnOperators.between_op : 'BETWEEN',
+    sql.ColumnOperators.in_op : 'IN',
+    sql.ColumnOperators.notin_op : 'NOT IN',
+    sql.ColumnOperators.comma_op : ', ',
+    sql.Operators.from_ : 'FROM',
+    sql.Operators.as_ : 'AS',
+    sql.Operators.exists : 'EXISTS',
+    sql.Operators.is_ : 'IS',
+    sql.Operators.isnot : 'IS NOT'
+}
+
 class ANSIDialect(default.DefaultDialect):
     def __init__(self, cache_identifiers=True, **kwargs):
         super(ANSIDialect,self).__init__(**kwargs)
@@ -77,6 +109,8 @@ class ANSICompiler(engine.Compiled):
 
     __traverse_options__ = {'column_collections':False, 'entry':True}
 
+    operators = OPERATORS
+    
     def __init__(self, dialect, statement, parameters=None, **kwargs):
         """Construct a new ``ANSICompiler`` object.
 
@@ -264,7 +298,7 @@ class ANSICompiler(engine.Compiled):
             if isinstance(label.obj, sql._ColumnClause):
                 self.column_labels[label.obj._label] = labelname
             self.column_labels[label.name] = labelname
-        self.strings[label] = self.strings[label.obj] + " AS "  + self.preparer.format_label(label, labelname)
+        self.strings[label] = " ".join([self.strings[label.obj], self.operator_string(sql.ColumnOperators.as_), self.preparer.format_label(label, labelname)])
         
     def visit_column(self, column):
         # there is actually somewhat of a ruleset when you would *not* necessarily
@@ -317,15 +351,15 @@ class ANSICompiler(engine.Compiled):
     def visit_null(self, null):
         self.strings[null] = 'NULL'
 
-    def visit_clauselist(self, list):
-        sep = list.operator
-        if sep == ',':
-            sep = ', '
-        elif sep is None or sep == " ":
+    def visit_clauselist(self, clauselist):
+        sep = clauselist.operator
+        if sep is None:
             sep = " "
+        elif sep == sql.ColumnOperators.comma_op:
+            sep = ', '
         else:
-            sep = " " + sep + " "
-        self.strings[list] = string.join([s for s in [self.strings[c] for c in list.clauses] if s is not None], sep)
+            sep = " " + self.operator_string(clauselist.operator) + " "
+        self.strings[clauselist] = string.join([s for s in [self.strings[c] for c in clauselist.clauses] if s is not None], sep)
 
     def apply_function_parens(self, func):
         return func.name.upper() not in ANSI_FUNCS or len(func.clauses) > 0
@@ -362,20 +396,20 @@ class ANSICompiler(engine.Compiled):
     def visit_unary(self, unary):
         s = self.strings[unary.element]
         if unary.operator:
-            s = unary.operator + " " + s
+            s = self.operator_string(unary.operator) + " " + s
         if unary.modifier:
             s = s + " " + unary.modifier
         self.strings[unary] = s
         
     def visit_binary(self, binary):
-        result = self.strings[binary.left]
-        if binary.operator is not None:
-            result += " " + self.binary_operator_string(binary)
-        result += " " + self.strings[binary.right]
-        self.strings[binary] = result
-
-    def binary_operator_string(self, binary):
-        return binary.operator
+        op = self.operator_string(binary.operator)
+        if callable(op):
+            self.strings[binary] = op(binary.left, binary.right)
+        else:
+            self.strings[binary] = self.strings[binary.left] + " " + op + " " + self.strings[binary.right]
+        
+    def operator_string(self, operator):
+        return self.operators.get(operator, str(operator))
 
     def visit_bindparam(self, bindparam):
         # apply truncation to the ultimate generated name
@@ -610,151 +644,86 @@ class ANSICompiler(engine.Compiled):
             " ON " + self.strings[join.onclause])
         self.strings[join] = self.froms[join]
 
-    def visit_insert_column_default(self, column, default, parameters):
-        """Called when visiting an ``Insert`` statement.
-
-        For each column in the table that contains a ``ColumnDefault``
-        object, add a blank *placeholder* parameter so the ``Insert``
-        gets compiled with this column's name in its column and
-        ``VALUES`` clauses.
-        """
-
-        parameters.setdefault(column.key, None)
-
-    def visit_update_column_default(self, column, default, parameters):
-        """Called when visiting an ``Update`` statement.
-
-        For each column in the table that contains a ``ColumnDefault``
-        object as an onupdate, add a blank *placeholder* parameter so
-        the ``Update`` gets compiled with this column's name as one of
-        its ``SET`` clauses.
-        """
-
-        parameters.setdefault(column.key, None)
-
-    def visit_insert_sequence(self, column, sequence, parameters):
-        """Called when visiting an ``Insert`` statement.
-
-        This may be overridden compilers that support sequences to
-        place a blank *placeholder* parameter for each column in the
-        table that contains a Sequence object, so the Insert gets
-        compiled with this column's name in its column and ``VALUES``
-        clauses.
-        """
-
-        pass
-
-    def visit_insert_column(self, column, parameters):
-        """Called when visiting an ``Insert`` statement.
-
-        This may be overridden by compilers who disallow NULL columns
-        being set in an ``Insert`` where there is a default value on
-        the column (i.e. postgres), to remove the column for which
-        there is a NULL insert from the parameter list.
-        """
-
-        pass
-
+    def uses_sequences_for_inserts(self):
+        return False
+        
     def visit_insert(self, insert_stmt):
-        # scan the table's columns for defaults that have to be pre-set for an INSERT
-        # add these columns to the parameter list via visit_insert_XXX methods
-        default_params = {}
+
+        # search for columns who will be required to have an explicit bound value.
+        # for inserts, this includes Python-side defaults, columns with sequences for dialects
+        # that support sequences, and primary key columns for dialects that explicitly insert
+        # pre-generated primary key values
+        required_cols = util.Set()
         class DefaultVisitor(schema.SchemaVisitor):
-            def visit_column(s, c):
-                self.visit_insert_column(c, default_params)
+            def visit_column(s, cd):
+                if c.primary_key and self.uses_sequences_for_inserts():
+                    required_cols.add(c)
             def visit_column_default(s, cd):
-                self.visit_insert_column_default(c, cd, default_params)
+                required_cols.add(c)
             def visit_sequence(s, seq):
-                self.visit_insert_sequence(c, seq, default_params)
+                if self.uses_sequences_for_inserts():
+                    required_cols.add(c)
         vis = DefaultVisitor()
         for c in insert_stmt.table.c:
             if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)):
                 vis.traverse(c)
 
         self.isinsert = True
-        colparams = self._get_colparams(insert_stmt, default_params)
-
-        self.inline_params = util.Set()
-        def create_param(col, p):
-            if isinstance(p, sql._BindParamClause):
-                self.binds[p.key] = p
-                if p.shortname is not None:
-                    self.binds[p.shortname] = p
-                return self.bindparam_string(self._truncate_bindparam(p))
-            else:
-                self.inline_params.add(col)
-                self.traverse(p)
-                if isinstance(p, sql.ClauseElement) and not isinstance(p, sql.ColumnElement):
-                    return "(" + self.strings[p] + ")"
-                else:
-                    return self.strings[p]
+        colparams = self._get_colparams(insert_stmt, required_cols)
 
         text = ("INSERT INTO " + self.preparer.format_table(insert_stmt.table) + " (" + string.join([self.preparer.format_column(c[0]) for c in colparams], ', ') + ")" +
-         " VALUES (" + string.join([create_param(*c) for c in colparams], ', ') + ")")
+         " VALUES (" + string.join([c[1] for c in colparams], ', ') + ")")
 
         self.strings[insert_stmt] = text
 
     def visit_update(self, update_stmt):
-        # scan the table's columns for onupdates that have to be pre-set for an UPDATE
-        # add these columns to the parameter list via visit_update_XXX methods
-        default_params = {}
+        
+        # search for columns who will be required to have an explicit bound value.
+        # for updates, this includes Python-side "onupdate" defaults.
+        required_cols = util.Set()
         class OnUpdateVisitor(schema.SchemaVisitor):
             def visit_column_onupdate(s, cd):
-                self.visit_update_column_default(c, cd, default_params)
+                required_cols.add(c)
         vis = OnUpdateVisitor()
         for c in update_stmt.table.c:
             if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)):
                 vis.traverse(c)
 
         self.isupdate = True
-        colparams = self._get_colparams(update_stmt, default_params)
-
-        self.inline_params = util.Set()
-        def create_param(col, p):
-            if isinstance(p, sql._BindParamClause):
-                self.binds[p.key] = p
-                self.binds[p.shortname] = p
-                return self.bindparam_string(self._truncate_bindparam(p))
-            else:
-                self.traverse(p)
-                self.inline_params.add(col)
-                if isinstance(p, sql.ClauseElement) and not isinstance(p, sql.ColumnElement):
-                    return "(" + self.strings[p] + ")"
-                else:
-                    return self.strings[p]
+        colparams = self._get_colparams(update_stmt, required_cols)
 
-        text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.format_column(c[0]), create_param(*c)) for c in colparams], ', ')
+        text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams], ', ')
 
         if update_stmt._whereclause:
             text += " WHERE " + self.strings[update_stmt._whereclause]
 
         self.strings[update_stmt] = text
 
-
-    def _get_colparams(self, stmt, default_params):
-        """Organize ``UPDATE``/``INSERT`` ``SET``/``VALUES`` parameters into a list of tuples.
-
-        Each tuple will contain the ``Column`` and a ``ClauseElement``
-        representing the value to be set (usually a ``_BindParamClause``,
-        but could also be other SQL expressions.)
-
-        The list of tuples will determine the columns that are
-        actually rendered into the ``SET``/``VALUES`` clause of the
-        rendered ``UPDATE``/``INSERT`` statement.  It will also
-        determine how to generate the list/dictionary of bind
-        parameters at execution time (i.e. ``get_params()``).
-
-        This list takes into account the `values` keyword specified
-        to the statement, the parameters sent to this Compiled
-        instance, and the default bind parameter values corresponding
-        to the dialect's behavior for otherwise unspecified primary
-        key columns.
+    def _get_colparams(self, stmt, required_cols):
+        """create a set of tuples representing column/string pairs for use 
+        in an INSERT or UPDATE statement.
+        
+        This method may generate new bind params within this compiled
+        based on the given set of "required columns", which are required
+        to have a value set in the statement.
         """
 
+        def create_bind_param(col, value):
+            bindparam = sql.bindparam(col.key, value, type_=col.type, unique=True)
+            self.binds[col.key] = bindparam
+            return self.bindparam_string(self._truncate_bindparam(bindparam))
+
         # no parameters in the statement, no parameters in the
         # compiled params - return binds for all columns
         if self.parameters is None and stmt.parameters is None:
-            return [(c, sql.bindparam(c.key, type=c.type)) for c in stmt.table.columns]
+            return [(c, create_bind_param(c, None)) for c in stmt.table.columns]
+
+        def create_clause_param(col, value):
+            self.traverse(value)
+            self.inline_params.add(col)
+            return self.strings[value]
+
+        self.inline_params = util.Set()
 
         def to_col(key):
             if not isinstance(key, sql._ColumnClause):
@@ -773,18 +742,20 @@ class ANSICompiler(engine.Compiled):
             for k, v in stmt.parameters.iteritems():
                 parameters.setdefault(to_col(k), v)
 
-        for k, v in default_params.iteritems():
-            parameters.setdefault(to_col(k), v)
+        for col in required_cols:
+            parameters.setdefault(col, None)
 
         # create a list of column assignment clauses as tuples
         values = []
         for c in stmt.table.columns:
-            if parameters.has_key(c):
+            if c in parameters:
                 value = parameters[c]
                 if sql._is_literal(value):
-                    value = sql.bindparam(c.key, value, type=c.type, unique=True)
+                    value = create_bind_param(c, value)
+                else:
+                    value = create_clause_param(c, value)
                 values.append((c, value))
-
+        
         return values
 
     def visit_delete(self, delete_stmt):
@@ -846,8 +817,6 @@ class ANSISchemaGenerator(ANSISchemaBase):
         for column in table.columns:
             if column.default is not None:
                 self.traverse_single(column.default)
-            #if column.onupdate is not None:
-            #    column.onupdate.accept_visitor(visitor)
 
         self.append("\nCREATE TABLE " + self.preparer.format_table(table) + " (")
 
index 3ff46e09425107eadfea17533a9ebd5fad583aa2..d3f49544dcc955aafbfe98e7e8c90519c54c7707 100644 (file)
@@ -4,7 +4,7 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-import re, datetime, inspect, warnings, weakref
+import re, datetime, inspect, warnings, weakref, operator
 
 from sqlalchemy import sql, schema, ansisql
 from sqlalchemy.engine import default
@@ -1284,6 +1284,14 @@ class _MySQLPythonRowProxy(object):
 
 
 class MySQLCompiler(ansisql.ANSICompiler):
+    operators = ansisql.ANSICompiler.operators.copy()
+    operators.update(
+        {
+            sql.ColumnOperators.concat_op : lambda x, y:"concat(%s, %s)" % (x, y),
+            operator.mod : '%%'
+        }
+    )
+
     def visit_cast(self, cast):
         if isinstance(cast.type, (sqltypes.Date, sqltypes.Time, sqltypes.DateTime)):
             return super(MySQLCompiler, self).visit_cast(cast)
@@ -1309,11 +1317,6 @@ class MySQLCompiler(ansisql.ANSICompiler):
             text += " OFFSET " + str(select._offset)
         return text
         
-    def binary_operator_string(self, binary):
-        if binary.operator == '%':
-            return '%%'
-        else:
-            return ansisql.ANSICompiler.binary_operator_string(self, binary)   
 
 class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator):
     def get_column_specification(self, column, override_pk=False, first_pk=False):
index a2b469a304e8f2e64f252ec6f940fde8a9e12ca5..82388ef871d1441703ee3a31a3aab2fbfa279ea2 100644 (file)
@@ -5,7 +5,7 @@
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
 
-import sys, StringIO, string, re, warnings
+import sys, StringIO, string, re, warnings, operator
 
 from sqlalchemy import util, sql, engine, schema, ansisql, exceptions, logging
 from sqlalchemy.engine import default, base
@@ -460,6 +460,13 @@ class OracleCompiler(ansisql.ANSICompiler):
     the use_ansi flag is False.
     """
 
+    operators = ansisql.ANSICompiler.operators.copy()
+    operators.update(
+        {
+            operator.mod : lambda x, y:"mod(%s, %s)" % (x, y)
+        }
+    )
+
     def default_from(self):
         """Called when a ``SELECT`` statement has no froms, and no ``FROM`` clause is to be appended.
 
@@ -496,13 +503,8 @@ class OracleCompiler(ansisql.ANSICompiler):
 
         self.traverse_single(self.wheres[join])
 
-    def visit_insert_sequence(self, column, sequence, parameters):
-        """This is the `sequence` equivalent to ``ANSICompiler``'s
-        `visit_insert_column_default` which ensures that the column is
-        present in the generated column list.
-        """
-
-        parameters.setdefault(column.key, None)
+    def uses_sequences_for_inserts(self):
+        return True
 
     def visit_alias(self, alias):
         """Oracle doesn't like ``FROM table AS alias``.  Is the AS standard SQL??"""
@@ -571,12 +573,6 @@ class OracleCompiler(ansisql.ANSICompiler):
         else:
             return super(OracleCompiler, self).for_update_clause(select)
 
-    def visit_binary(self, binary):
-        if binary.operator == '%': 
-            self.strings[binary] = ("MOD(%s,%s)"%(self.strings[binary.left], self.strings[binary.right]))
-        else:
-            return ansisql.ANSICompiler.visit_binary(self, binary)
-        
 
 class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
     def get_column_specification(self, column, **kwargs):
index 469614fbb3e618df3a986a9017461f8433985c8f..80a56a3ca64e6936ea3fd7aef8ad3e9a93dd7b79 100644 (file)
@@ -4,7 +4,7 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-import datetime, string, types, re, random, warnings
+import datetime, string, types, re, random, warnings, operator
 
 from sqlalchemy import util, sql, schema, ansisql, exceptions
 from sqlalchemy.engine import base, default
@@ -83,7 +83,7 @@ class PGBoolean(sqltypes.Boolean):
     def get_col_spec(self):
         return "BOOLEAN"
 
-class PGArray(sqltypes.TypeEngine):
+class PGArray(sqltypes.TypeEngine, sqltypes.Concatenable):
     def __init__(self, item_type):
         if isinstance(item_type, type):
             item_type = item_type()
@@ -355,7 +355,7 @@ class PGDialect(ansisql.ANSIDialect):
                 ORDER BY a.attnum
             """ % schema_where_clause
 
-            s = sql.text(SQL_COLS, bindparams=[sql.bindparam('table_name', type=sqltypes.Unicode), sql.bindparam('schema', type=sqltypes.Unicode)], typemap={'attname':sqltypes.Unicode})
+            s = sql.text(SQL_COLS, bindparams=[sql.bindparam('table_name', type_=sqltypes.Unicode), sql.bindparam('schema', type_=sqltypes.Unicode)], typemap={'attname':sqltypes.Unicode})
             c = connection.execute(s, table_name=table.name,
                                       schema=table.schema)
             rows = c.fetchall()
@@ -525,15 +525,15 @@ class PGDialect(ansisql.ANSIDialect):
         
         
 class PGCompiler(ansisql.ANSICompiler):
-    def visit_insert_column(self, column, parameters):
-        # all column primary key inserts must be explicitly present
-        if column.primary_key:
-            parameters[column.key] = None
+    operators = ansisql.ANSICompiler.operators.copy()
+    operators.update(
+        {
+            operator.mod : '%%'
+        }
+    )
 
-    def visit_insert_sequence(self, column, sequence, parameters):
-        """this is the 'sequence' equivalent to ANSICompiler's 'visit_insert_column_default' which ensures
-        that the column is present in the generated column list"""
-        parameters.setdefault(column.key, None)
+    def uses_sequences_for_inserts(self):
+        return True
 
     def limit_clause(self, select):
         text = ""
@@ -565,14 +565,6 @@ class PGCompiler(ansisql.ANSICompiler):
         else:
             return super(PGCompiler, self).for_update_clause(select)
 
-    def binary_operator_string(self, binary):
-        if isinstance(binary.type, (sqltypes.String, PGArray)) and binary.operator == '+':
-            return '||'
-        elif binary.operator == '%':
-            return '%%'
-        else:
-            return ansisql.ANSICompiler.binary_operator_string(self, binary)
-
 class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
     def get_column_specification(self, column, **kwargs):
         colspec = self.preparer.format_column(column)
index 70cbd0c0e175dbeeb87e328989a39a40bbf9edea..e7abc1f32b875a1a211f3a7e153429e85680e541 100644 (file)
@@ -347,12 +347,6 @@ class SQLiteCompiler(ansisql.ANSICompiler):
         # sqlite has no "FOR UPDATE" AFAICT
         return ''
 
-    def binary_operator_string(self, binary):
-        if isinstance(binary.type, sqltypes.String) and binary.operator == '+':
-            return '||'
-        else:
-            return ansisql.ANSICompiler.binary_operator_string(self, binary)
-
 class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator):
 
     def get_column_specification(self, column, **kwargs):
index 832b56f74a72ed605138dbe6cf0932bcc247d3a3..075d51a538bc1ed0a978ea45755470d68ddfbd61 100644 (file)
@@ -296,6 +296,9 @@ class DefaultExecutionContext(base.ExecutionContext):
         statement.
         """
 
+        # TODO: this calculation of defaults is one of the places SA slows down inserts.
+        # look into optimizing this for a list of params where theres no defaults defined
+        # (i.e. analyze the first batch of params).
         if self.compiled.isinsert:
             if isinstance(self.compiled_parameters, list):
                 plist = self.compiled_parameters
@@ -323,6 +326,7 @@ class DefaultExecutionContext(base.ExecutionContext):
                             self._lastrow_has_defaults = True
                         newid = drunner.get_column_default(c)
                         if newid is not None:
+                            print "GOT GENERATED DEFAULT", c, repr(newid)
                             param.set_value(c.key, newid)
                             if c.primary_key:
                                 last_inserted_ids.append(param.get_processed(c.key))
index e8f3d4e2454f03365a56372b40c7459a36fc6292..c06db69631e7768b489e63d5fddafe7411dddead 100644 (file)
@@ -366,7 +366,7 @@ class ManyToManyDP(DependencyProcessor):
         if len(secondary_delete):
             secondary_delete.sort()
             # TODO: precompile the delete/insert queries?
-            statement = self.secondary.delete(sql.and_(*[c == sql.bindparam(c.key, type=c.type) for c in self.secondary.c if c.key in associationrow]))
+            statement = self.secondary.delete(sql.and_(*[c == sql.bindparam(c.key, type_=c.type) for c in self.secondary.c if c.key in associationrow]))
             result = connection.execute(statement, secondary_delete)
             if result.supports_sane_rowcount() and result.rowcount != len(secondary_delete):
                 raise exceptions.ConcurrentModificationError("Deleted rowcount %d does not match number of objects deleted %d" % (result.rowcount, len(secondary_delete)))
index f353575d90088a050784b6ebb020c6b8405213e6..e1209fabf9619142ffb72a3ef6facb9df0b6b51f 100644 (file)
@@ -337,7 +337,7 @@ class MapperProperty(object):
 
         return operator(self.comparator, value)
 
-class PropComparator(sql.Comparator):
+class PropComparator(sql.ColumnOperators):
     """defines comparison operations for MapperProperty objects"""
 
     def contains_op(a, b):
index f82f713bb8a1d5f4b31beb2b8065c4ee441b4182..eb69fb32c86ab8438e0c1a4e5b1602d075027a21 100644 (file)
@@ -10,7 +10,7 @@ from sqlalchemy.orm import util as mapperutil
 from sqlalchemy.orm.util import ExtensionCarrier
 from sqlalchemy.orm import sync
 from sqlalchemy.orm.interfaces import MapperProperty, MapperOption, OperationContext, EXT_PASS, MapperExtension, SynonymProperty
-import weakref, warnings
+import weakref, warnings, operator
 
 __all__ = ['Mapper', 'class_mapper', 'object_mapper', 'mapper_registry']
 
@@ -587,7 +587,7 @@ class Mapper(object):
         
         _get_clause = sql.and_()
         for primary_key in self.primary_key:
-            _get_clause.clauses.append(primary_key == sql.bindparam(primary_key._label, type=primary_key.type, unique=True))
+            _get_clause.clauses.append(primary_key == sql.bindparam(primary_key._label, type_=primary_key.type, unique=True))
         self._get_clause = _get_clause
 
     def _get_equivalent_columns(self):
@@ -620,7 +620,7 @@ class Mapper(object):
 
         result = {}
         def visit_binary(binary):
-            if binary.operator == '=':
+            if binary.operator == operator.eq:
                 if binary.left in result:
                     result[binary.left].add(binary.right)
                 else:
@@ -1221,9 +1221,9 @@ class Mapper(object):
                 mapper = table_to_mapper[table]
                 clause = sql.and_()
                 for col in mapper.pks_by_table[table]:
-                    clause.clauses.append(col == sql.bindparam(col._label, type=col.type, unique=True))
+                    clause.clauses.append(col == sql.bindparam(col._label, type_=col.type, unique=True))
                 if mapper.version_id_col is not None:
-                    clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col._label, type=col.type, unique=True))
+                    clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col._label, type_=col.type, unique=True))
                 statement = table.update(clause)
                 rows = 0
                 supports_sane_rowcount = True
@@ -1358,9 +1358,9 @@ class Mapper(object):
                 delete.sort(comparator)
                 clause = sql.and_()
                 for col in mapper.pks_by_table[table]:
-                    clause.clauses.append(col == sql.bindparam(col.key, type=col.type, unique=True))
+                    clause.clauses.append(col == sql.bindparam(col.key, type_=col.type, unique=True))
                 if mapper.version_id_col is not None:
-                    clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col.key, type=mapper.version_id_col.type, unique=True))
+                    clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col.key, type_=mapper.version_id_col.type, unique=True))
                 statement = table.delete(clause)
                 c = connection.execute(statement, delete)
                 if c.supports_sane_rowcount() and c.rowcount != len(delete):
@@ -1567,10 +1567,10 @@ class Mapper(object):
             if leftcol is None or rightcol is None:
                 return
             if leftcol.table not in needs_tables:
-                binary.left = sql.bindparam(leftcol.name, None, type=binary.right.type, unique=True)
+                binary.left = sql.bindparam(leftcol.name, None, type_=binary.right.type, unique=True)
                 param_names.append(leftcol)
             elif rightcol not in needs_tables:
-                binary.right = sql.bindparam(rightcol.name, None, type=binary.right.type, unique=True)
+                binary.right = sql.bindparam(rightcol.name, None, type_=binary.right.type, unique=True)
                 param_names.append(rightcol)
         cond = mapperutil.BinaryVisitor(visit_binary).traverse(cond, clone=True)
         return cond, param_names
index 7a3da1fdd15f68d5db76b206fe4c1741950337b4..99148cf6147c3ab5d3058868ecee73287ab61795 100644 (file)
@@ -384,7 +384,7 @@ class PropertyLoader(StrategizedProperty):
         if len(self.foreign_keys):
             self._opposite_side = util.Set()
             def visit_binary(binary):
-                if binary.operator != '=' or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
+                if binary.operator != operator.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
                     return
                 if binary.left in self.foreign_keys:
                     self._opposite_side.add(binary.right)
@@ -397,7 +397,7 @@ class PropertyLoader(StrategizedProperty):
             self.foreign_keys = util.Set()
             self._opposite_side = util.Set()
             def visit_binary(binary):
-                if binary.operator != '=' or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
+                if binary.operator != operator.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
                     return
 
                 # this check is for when the user put the "view_only" flag on and has tables that have nothing
index fa86e450b255eacc4c859fcc24d5d3bdbeff4ff6..c581b27c0301df9679669f31655ff792c4a5ead5 100644 (file)
@@ -407,7 +407,7 @@ class LazyLoader(AbstractRelationLoader):
             if should_bind(leftcol, rightcol):
                 col = leftcol
                 binary.left = binds.setdefault(leftcol,
-                        sql.bindparam(None, None, shortname=leftcol.name, type=binary.right.type, unique=True))
+                        sql.bindparam(None, None, shortname=leftcol.name, type_=binary.right.type, unique=True))
                 reverse[rightcol] = binds[col]
 
             # the "left is not right" compare is to handle part of a join clause that is "table.c.col1==table.c.col1",
@@ -415,7 +415,7 @@ class LazyLoader(AbstractRelationLoader):
             if leftcol is not rightcol and should_bind(rightcol, leftcol):
                 col = rightcol
                 binary.right = binds.setdefault(rightcol,
-                        sql.bindparam(None, None, shortname=rightcol.name, type=binary.left.type, unique=True))
+                        sql.bindparam(None, None, shortname=rightcol.name, type_=binary.left.type, unique=True))
                 reverse[leftcol] = binds[col]
 
         lazywhere = primaryjoin
index 88fd980ad2a42c1e43bcc0452e52f1caf7118456..cf48202b0f9bcf0d3f2748b1c0bed7f12a16d2c6 100644 (file)
@@ -12,6 +12,7 @@ clause that compares column values.
 from sqlalchemy import sql, schema, exceptions
 from sqlalchemy import logging
 from sqlalchemy.orm import util as mapperutil
+import operator
 
 ONETOMANY = 0
 MANYTOONE = 1
@@ -42,7 +43,7 @@ class ClauseSynchronizer(object):
         def compile_binary(binary):
             """Assemble a SyncRule given a single binary condition."""
 
-            if binary.operator != '=' or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
+            if binary.operator != operator.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
                 return
 
             source_column = None
index afb433e1e4d40b3c2dacc8070ed1367da824326a..20160b0bf785876665a3ca88d042a4ea92204520 100644 (file)
@@ -396,7 +396,7 @@ class Column(SchemaItem, sql._ColumnClause):
     ``TableClause``/``Table``.
     """
 
-    def __init__(self, name, type, *args, **kwargs):
+    def __init__(self, name, type_, *args, **kwargs):
         """Construct a new ``Column`` object.
 
         Arguments are:
@@ -405,7 +405,7 @@ class Column(SchemaItem, sql._ColumnClause):
           The name of this column.  This should be the identical name
           as it appears, or will appear, in the database.
 
-        type
+        type_
           The ``TypeEngine`` for this column.  This can be any
           subclass of ``types.AbstractType``, including the
           database-agnostic types defined in the types module,
@@ -495,7 +495,7 @@ class Column(SchemaItem, sql._ColumnClause):
             identifier contains mixed case.
         """
 
-        super(Column, self).__init__(name, None, type)
+        super(Column, self).__init__(name, None, type_)
         self.args = args
         self.key = kwargs.pop('key', name)
         self._primary_key = kwargs.pop('primary_key', False)
index f38347fc4091aceae053eba370e19896eb7187e0..672d085487fdb09cc9a5855ce4a1a253f3f74f4f 100644 (file)
@@ -32,45 +32,12 @@ __all__ = ['AbstractDialect', 'Alias', 'ClauseElement', 'ClauseParameters',
            'ClauseVisitor', 'ColumnCollection', 'ColumnElement',
            'CompoundSelect', 'Delete', 'FromClause', 'Insert', 'Join', 
            'Select', 'Selectable', 'TableClause', 'Update', 'alias', 'and_', 'asc',
-           'between_', 'between', 'bindparam', 'case', 'cast', 'column', 'delete',
+           'between', 'bindparam', 'case', 'cast', 'column', 'delete',
            'desc', 'distinct', 'except_', 'except_all', 'exists', 'extract', 'func', 'modifier',
            'insert', 'intersect', 'intersect_all', 'join', 'literal',
            'literal_column', 'not_', 'null', 'or_', 'outerjoin', 'select',
            'subquery', 'table', 'text', 'union', 'union_all', 'update',]
 
-# precedence ordering for common operators.  if an operator is not present in this list,
-# it will be parenthesized when grouped against other operators
-PRECEDENCE = {
-    'FROM':15,
-    '*':7,
-    '/':7,
-       '%':7,
-    '+':6,
-    '-':6,
-    'ILIKE':5,
-    'NOT ILIKE':5,
-    'LIKE':5,
-    'NOT LIKE':5,
-    'IN':5,
-    'NOT IN':5,
-    'IS':5,
-    'IS NOT':5,
-    '=':5,
-    '!=':5,
-    '>':5,
-    '<':5,
-    '>=':5,
-    '<=':5,
-    'BETWEEN':5,
-    'NOT':4,
-    'AND':3,
-    'OR':2,
-    ',':-1,
-    'AS':-1,
-    'EXISTS':0,
-    '_smallest': -1000,
-    '_largest': 1000
-}
 BIND_PARAMS = re.compile(r'(?<![:\w\x5c]):(\w+)(?!:)', re.UNICODE)
 
 def desc(column):
@@ -368,7 +335,7 @@ def and_(*clauses):
     """
     if len(clauses) == 1:
         return clauses[0]
-    return ClauseList(operator='AND', *clauses)
+    return ClauseList(operator=operator.and_, *clauses)
 
 def or_(*clauses):
     """Join a list of clauses together using the ``OR`` operator.
@@ -379,7 +346,7 @@ def or_(*clauses):
 
     if len(clauses) == 1:
         return clauses[0]
-    return ClauseList(operator='OR', *clauses)
+    return ClauseList(operator=operator.or_, *clauses)
 
 def not_(clause):
     """Return a negation of the given clause, i.e. ``NOT(clause)``.
@@ -388,7 +355,7 @@ def not_(clause):
     subclasses to produce the same result.
     """
 
-    return clause._negate()
+    return operator.inv(clause)
 
 def distinct(expr):
     """return a ``DISTINCT`` clause."""
@@ -404,12 +371,8 @@ def between(ctest, cleft, cright):
     provides similar functionality.
     """
 
-    return _BinaryExpression(ctest, ClauseList(_literal_as_binds(cleft, type=ctest.type), _literal_as_binds(cright, type=ctest.type), operator='AND', group=False), 'BETWEEN')
+    return _BinaryExpression(ctest, ClauseList(_literal_as_binds(cleft, type_=ctest.type), _literal_as_binds(cright, type_=ctest.type), operator=operator.and_, group=False), ColumnOperators.between_op)
 
-def between_(*args, **kwargs):
-    """synonym for [sqlalchemy.sql#between()] (deprecated)."""
-    
-    return between(*args, **kwargs)
 
 def case(whens, value=None, else_=None):
     """Produce a ``CASE`` statement.
@@ -432,7 +395,7 @@ def case(whens, value=None, else_=None):
         type = list(whenlist[-1])[-1].type
     else:
         type = None
-    cc = _CalculatedClause(None, 'CASE', value, type=type, operator=None, group_contents=False, *whenlist + ['END'])
+    cc = _CalculatedClause(None, 'CASE', value, type_=type, operator=None, group_contents=False, *whenlist + ['END'])
     return cc
 
 def cast(clause, totype, **kwargs):
@@ -454,7 +417,7 @@ def cast(clause, totype, **kwargs):
 def extract(field, expr):
     """Return the clause ``extract(field FROM expr)``."""
 
-    expr = _BinaryExpression(text(field), expr, "FROM")
+    expr = _BinaryExpression(text(field), expr, Operators.from_)
     return func.extract(expr)
 
 def exists(*args, **kwargs):
@@ -584,7 +547,7 @@ def alias(selectable, alias=None):
     return Alias(selectable, alias=alias)
 
 
-def literal(value, type=None):
+def literal(value, type_=None):
     """Return a literal clause, bound to a bind parameter.
 
     Literal clauses are created automatically when non-
@@ -606,7 +569,7 @@ def literal(value, type=None):
 
     """
 
-    return _BindParamClause('literal', value, type=type, unique=True)
+    return _BindParamClause('literal', value, type_=type_, unique=True)
 
 def label(name, obj):
     """Return a [sqlalchemy.sql#_Label] object for the given [sqlalchemy.sql#ColumnElement].
@@ -627,7 +590,7 @@ def label(name, obj):
 
     return _Label(name, obj)
 
-def column(text, type=None):
+def column(text, type_=None):
     """Return a textual column clause, as would be in the columns 
     clause of a ``SELECT`` statement.
     
@@ -647,9 +610,9 @@ def column(text, type=None):
         
     """
 
-    return _ColumnClause(text, type=type)
+    return _ColumnClause(text, type_=type_)
 
-def literal_column(text, type=None):
+def literal_column(text, type_=None):
     """Return a textual column clause, as would be in the columns
     clause of a ``SELECT`` statement.
   
@@ -671,7 +634,7 @@ def literal_column(text, type=None):
       
     """
 
-    return _ColumnClause(text, type=type, is_literal=True)
+    return _ColumnClause(text, type_=type_, is_literal=True)
 
 def table(name, *columns):
     """Return a [sqlalchemy.sql#Table] object.
@@ -682,7 +645,7 @@ def table(name, *columns):
 
     return TableClause(name, *columns)
 
-def bindparam(key, value=None, type=None, shortname=None, unique=False):
+def bindparam(key, value=None, type_=None, shortname=None, unique=False):
     """Create a bind parameter clause with the given key.
 
         value
@@ -704,9 +667,9 @@ def bindparam(key, value=None, type=None, shortname=None, unique=False):
     """
 
     if isinstance(key, _ColumnClause):
-        return _BindParamClause(key.name, value, type=key.type, shortname=shortname, unique=unique)
+        return _BindParamClause(key.name, value, type_=key.type, shortname=shortname, unique=unique)
     else:
-        return _BindParamClause(key, value, type=type, shortname=shortname, unique=unique)
+        return _BindParamClause(key, value, type_=type_, shortname=shortname, unique=unique)
 
 def text(text, bind=None, *args, **kwargs):
     """Create literal text to be inserted into a query.
@@ -781,19 +744,19 @@ def _is_literal(element):
     return not isinstance(element, ClauseElement)
 
 def _literal_as_text(element):
-    if isinstance(element, Comparator):
+    if isinstance(element, Operators):
         return element.clause_element()
     elif _is_literal(element):
         return _TextClause(unicode(element))
     else:
         return element
 
-def _literal_as_binds(element, name='literal', type=None):
+def _literal_as_binds(element, name='literal', type_=None):
     if _is_literal(element):
         if element is None:
             return null()
         else:
-            return _BindParamClause(name, element, shortname=name, type=type, unique=True)
+            return _BindParamClause(name, element, shortname=name, type_=type_, unique=True)
     else:
         return element
         
@@ -1134,16 +1097,67 @@ class ClauseElement(object):
         if hasattr(self, 'negation_clause'):
             return self.negation_clause
         else:
-            return _UnaryExpression(self.self_group(against="NOT"), operator="NOT", negate=None)
+            return _UnaryExpression(self.self_group(against=operator.inv), operator=operator.inv, negate=None)
+
+
+class Operators(object):
+    def from_():
+        raise NotImplementedError()
+    from_ = staticmethod(from_)
+    
+    def as_():
+        raise NotImplementedError()
+    as_ = staticmethod(as_)
+    
+    def exists():
+        raise NotImplementedError()
+    exists = staticmethod(exists)
+
+    def is_():
+        raise NotImplementedError()
+    is_ = staticmethod(is_)
+    
+    def isnot():
+        raise NotImplementedError()
+    isnot = staticmethod(isnot)
+    
+    def __and__(self, other):
+        return self.operate(operator.and_, other)
 
+    def __or__(self, other):
+        return self.operate(operator.or_, other)
+
+    def __invert__(self):
+        return self.operate(operator.inv)
+
+    def clause_element(self):
+        raise NotImplementedError()
+
+    def operate(self, op, *other):
+        raise NotImplementedError()
+
+    def reverse_operate(self, op, *other):
+        raise NotImplementedError()
 
-class Comparator(object):
+class ColumnOperators(Operators):
     """defines comparison and math operations"""
 
     def like_op(a, b):
         return a.like(b)
     like_op = staticmethod(like_op)
     
+    def notlike_op(a, b):
+        raise NotImplementedError()
+    notlike_op = staticmethod(notlike_op)
+
+    def ilike_op(a, b):
+        return a.ilike(b)
+    ilike_op = staticmethod(ilike_op)
+    
+    def notilike_op(a, b):
+        raise NotImplementedError()
+    notilike_op = staticmethod(notilike_op)
+    
     def between_op(a, b):
         return a.between(b)
     between_op = staticmethod(between_op)
@@ -1151,6 +1165,10 @@ class Comparator(object):
     def in_op(a, b):
         return a.in_(*b)
     in_op = staticmethod(in_op)
+
+    def notin_op(a, b):
+        raise NotImplementedError()
+    notin_op = staticmethod(notin_op)
     
     def startswith_op(a, b):
         return a.startswith(b)
@@ -1159,15 +1177,14 @@ class Comparator(object):
     def endswith_op(a, b):
         return a.endswith(b)
     endswith_op = staticmethod(endswith_op)
-    
-    def clause_element(self):
-        raise NotImplementedError()
-        
-    def operate(self, op, other):
-        raise NotImplementedError()
 
-    def reverse_operate(self, op, other):
+    def comma_op(a, b):
         raise NotImplementedError()
+    comma_op = staticmethod(comma_op)
+
+    def concat_op(a, b):
+        return a.concat(b)
+    concat_op = staticmethod(concat_op)
     
     def __lt__(self, other):
         return self.operate(operator.lt, other)
@@ -1187,17 +1204,20 @@ class Comparator(object):
     def __ge__(self, other):
         return self.operate(operator.ge, other)
 
+    def concat(self, other):
+        return self.operate(ColumnOperators.concat_op, other)
+        
     def like(self, other):
-        return self.operate(Comparator.like_op, other)
-
+        return self.operate(ColumnOperators.like_op, other)
+    
     def in_(self, *other):
-        return self.operate(Comparator.in_op, other)
-
+        return self.operate(ColumnOperators.in_op, other)
+    
     def startswith(self, other):
-        return self.operate(Comparator.startswith_op, other)
+        return self.operate(ColumnOperators.startswith_op, other)
 
     def endswith(self, other):
-        return self.operate(Comparator.endswith_op, other)
+        return self.operate(ColumnOperators.endswith_op, other)
 
     def __radd__(self, other):
         return self.reverse_operate(operator.add, other)
@@ -1212,7 +1232,7 @@ class Comparator(object):
         return self.reverse_operate(operator.div, other)
 
     def between(self, cleft, cright):
-        return self.operate(Comparator.between_op, (cleft, cright))
+        return self.operate(Operators.between_op, (cleft, cright))
 
     def __add__(self, other):
         return self.operate(operator.add, other)
@@ -1232,51 +1252,97 @@ class Comparator(object):
     def __truediv__(self, other):
         return self.operate(operator.truediv, other)
 
-class _CompareMixin(Comparator):
+# precedence ordering for common operators.  if an operator is not present in this list,
+# it will be parenthesized when grouped against other operators
+_smallest = object()
+_largest = object()
+
+PRECEDENCE = {
+    Operators.from_:15,
+    operator.mul:7,
+    operator.div:7,
+    operator.mod:7,
+    operator.add:6,
+    operator.sub:6,
+    ColumnOperators.concat_op:6,
+    ColumnOperators.ilike_op:5,
+    ColumnOperators.notilike_op:5,
+    ColumnOperators.like_op:5,
+    ColumnOperators.notlike_op:5,
+    ColumnOperators.in_op:5,
+    ColumnOperators.notin_op:5,
+    Operators.is_:5,
+    Operators.isnot:5,
+    operator.eq:5,
+    operator.ne:5,
+    operator.gt:5,
+    operator.lt:5,
+    operator.ge:5,
+    operator.le:5,
+    ColumnOperators.between_op:5,
+    operator.inv:4,
+    operator.and_:3,
+    operator.or_:2,
+    ColumnOperators.comma_op:-1,
+    Operators.as_:-1,
+    Operators.exists:0,
+    _smallest: -1000,
+    _largest: 1000
+}
+
+class _CompareMixin(ColumnOperators):
     """Defines comparison and math operations for ``ClauseElement`` instances."""
 
-    def __compare(self, operator, obj, negate=None):
+    def __compare(self, op, obj, negate=None):
         if obj is None or isinstance(obj, _Null):
-            if operator == '=':
-                return _BinaryExpression(self.clause_element(), null(), 'IS', negate='IS NOT')
-            elif operator == '!=':
-                return _BinaryExpression(self.clause_element(), null(), 'IS NOT', negate='IS')
+            if op == operator.eq:
+                return _BinaryExpression(self.clause_element(), null(), Operators.is_, negate=Operators.isnot)
+            elif op == operator.ne:
+                return _BinaryExpression(self.clause_element(), null(), Operators.isnot, negate=Operators.is_)
             else:
                 raise exceptions.ArgumentError("Only '='/'!=' operators can be used with NULL")
         else:
             obj = self._check_literal(obj)
 
-        return _BinaryExpression(self.clause_element(), obj, operator, type=sqltypes.Boolean, negate=negate)
+            
+        return _BinaryExpression(self.clause_element(), obj, op, type_=sqltypes.Boolean, negate=negate)
 
-    def __operate(self, operator, obj):
+    def __operate(self, op, obj):
         obj = self._check_literal(obj)
-        return _BinaryExpression(self.clause_element(), obj, operator, type=self._compare_type(obj))
+
+        type_ = self._compare_type(obj)
+        if op == operator.add and isinstance(type_, (sqltypes.Concatenable)):
+            op = ColumnOperators.concat_op
+        
+        return _BinaryExpression(self.clause_element(), obj, op, type_=type_)
 
     operators = {
-        operator.add : (__operate, '+'),
-        operator.mul : (__operate, '*'),
-        operator.sub : (__operate, '-'),
-        operator.div : (__operate, '/'),
-        operator.mod : (__operate, '%'),
-        operator.truediv : (__operate, '/'),
-        operator.lt : (__compare, '<', '=>'),
-        operator.le : (__compare, '<=', '>'),
-        operator.ne : (__compare, '!=', '='),
-        operator.gt : (__compare, '>', '<='),
-        operator.ge : (__compare, '>=', '<'),
-        operator.eq : (__compare, '=', '!='),
-        Comparator.like_op : (__compare, 'LIKE', 'NOT LIKE'),
+        operator.add : (__operate,),
+        operator.mul : (__operate,),
+        operator.sub : (__operate,),
+        operator.div : (__operate,),
+        operator.mod : (__operate,),
+        operator.truediv : (__operate,),
+        operator.lt : (__compare, operator.ge),
+        operator.le : (__compare, operator.gt),
+        operator.ne : (__compare, operator.eq),
+        operator.gt : (__compare, operator.le),
+        operator.ge : (__compare, operator.lt),
+        operator.eq : (__compare, operator.ne),
+        ColumnOperators.like_op : (__compare, ColumnOperators.notlike_op),
     }
 
     def operate(self, op, other):
         o = _CompareMixin.operators[op]
-        return o[0](self, o[1], other, *o[2:])
+        return o[0](self, op, other, *o[1:])
     
     def reverse_operate(self, op, other):
         return self._bind_param(other).operate(op, self)
 
     def in_(self, *other):
-        """produce an ``IN`` clause."""
+        return self._in_impl(ColumnOperators.in_op, ColumnOperators.notin_op, *other)
+        
+    def _in_impl(self, op, negate_op, *other):
         if len(other) == 0:
             return _Grouping(case([(self.__eq__(None), text('NULL'))], else_=text('0')).__eq__(text('1')))
         elif len(other) == 1:
@@ -1285,7 +1351,7 @@ class _CompareMixin(Comparator):
                 return self.__eq__( o)    #single item -> ==
             else:
                 assert hasattr( o, '_selectable')   #better check?
-                return self.__compare( 'IN', o, negate='NOT IN')   #single selectable
+                return self.__compare( op, o, negate=negate_op)   #single selectable
 
         args = []
         for o in other:
@@ -1295,19 +1361,21 @@ class _CompareMixin(Comparator):
             else:
                 o = self._bind_param(o)
             args.append(o)
-        return self.__compare( 'IN', ClauseList(*args).self_group(against='IN'), negate='NOT IN')
+        return self.__compare(op, ClauseList(*args).self_group(against=op), negate=negate_op)
 
     def startswith(self, other):
         """produce the clause ``LIKE '<other>%'``"""
-        perc = isinstance(other,(str,unicode)) and '%' or literal('%',type= sqltypes.String)
+        
+        perc = isinstance(other,(str,unicode)) and '%' or literal('%',type_= sqltypes.String)
         return self.__compare('LIKE', other + perc)
 
     def endswith(self, other):
         """produce the clause ``LIKE '%<other>'``"""
+        
         if isinstance(other,(str,unicode)): po = '%' + other
         else:
-            po = literal('%', typesqltypes.String) + other
-            po.type = sqltypes.to_instance( sqltypes.String)     #force!
+            po = literal('%', type_=sqltypes.String) + other
+            po.type = sqltypes.to_instance(sqltypes.String)     #force!
         return self.__compare('LIKE', po)
 
     def label(self, name):
@@ -1320,7 +1388,7 @@ class _CompareMixin(Comparator):
 
     def between(self, cleft, cright):
         """produce a BETWEEN clause, i.e. ``<column> BETWEEN <cleft> AND <cright>``"""
-        return _BinaryExpression(self, ClauseList(self._check_literal(cleft), self._check_literal(cright), operator='AND', group=False), 'BETWEEN')
+        return _BinaryExpression(self, ClauseList(self._check_literal(cleft), self._check_literal(cright), operator=operator.and_, group=False), 'BETWEEN')
 
     def op(self, operator):
         """produce a generic operator function.
@@ -1342,10 +1410,10 @@ class _CompareMixin(Comparator):
         return lambda other: self.__operate(operator, other)
 
     def _bind_param(self, obj):
-        return _BindParamClause('literal', obj, shortname=None, type=self.type, unique=True)
+        return _BindParamClause('literal', obj, shortname=None, type_=self.type, unique=True)
 
     def _check_literal(self, other):
-        if isinstance(other, Comparator):
+        if isinstance(other, Operators):
             return other.clause_element()
         elif _is_literal(other):
             return self._bind_param(other)
@@ -1764,7 +1832,7 @@ class _BindParamClause(ClauseElement, _CompareMixin):
 
     __visit_name__ = 'bindparam'
     
-    def __init__(self, key, value, shortname=None, type=None, unique=False):
+    def __init__(self, key, value, shortname=None, type_=None, unique=False):
         """Construct a _BindParamClause.
 
         key
@@ -1787,7 +1855,7 @@ class _BindParamClause(ClauseElement, _CompareMixin):
           execution may match either the key or the shortname of the
           corresponding ``_BindParamClause`` objects.
 
-        type
+        type_
           A ``TypeEngine`` object that will be used to pre-process the
           value corresponding to this ``_BindParamClause`` at
           execution time.
@@ -1803,8 +1871,20 @@ class _BindParamClause(ClauseElement, _CompareMixin):
         self.value = value
         self.shortname = shortname or key
         self.unique = unique
-        self.type = sqltypes.to_instance(type)
-
+        type_ = sqltypes.to_instance(type_)
+        if isinstance(type_, sqltypes.NullType) and type(value) in _BindParamClause.type_map:
+            self.type = sqltypes.to_instance(_BindParamClause.type_map[type(value)])
+        else:
+            self.type = type_
+    
+    # TODO: move to types module, obviously
+    type_map = {
+        str : sqltypes.String,
+        unicode : sqltypes.Unicode,
+        int : sqltypes.Integer,
+        float : sqltypes.Numeric
+    }
+    
     def _get_from_objects(self, **modifiers):
         return []
 
@@ -1822,7 +1902,7 @@ class _BindParamClause(ClauseElement, _CompareMixin):
         return isinstance(other, _BindParamClause) and other.type.__class__ == self.type.__class__
 
     def __repr__(self):
-        return "_BindParamClause(%s, %s, type=%s)" % (repr(self.key), repr(self.value), repr(self.type))
+        return "_BindParamClause(%s, %s, type_=%s)" % (repr(self.key), repr(self.value), repr(self.type))
 
 class _TypeClause(ClauseElement):
     """Handle a type keyword in a SQL statement.
@@ -1907,10 +1987,9 @@ class ClauseList(ClauseElement):
     
     def __init__(self, *clauses, **kwargs):
         self.clauses = []
-        self.operator = kwargs.pop('operator', ',')
+        self.operator = kwargs.pop('operator', ColumnOperators.comma_op)
         self.group = kwargs.pop('group', True)
         self.group_contents = kwargs.pop('group_contents', True)
-        self.negate_operator = kwargs.pop('negate', None)
         for c in clauses:
             if c is None: 
                 continue
@@ -1932,14 +2011,6 @@ class ClauseList(ClauseElement):
     def _copy_internals(self):
         self.clauses = [clause._clone() for clause in self.clauses]
 
-    def _negate(self):
-        if hasattr(self, 'negation_clause'):
-            return self.negation_clause
-        elif self.negate_operator is None:
-            return super(ClauseList, self)._negate()
-        else:
-            return ClauseList(operator=self.negate_operator, negate=self.operator, *(not_(c) for c in self.clauses))
-
     def get_children(self, **kwargs):
         return self.clauses
 
@@ -1950,7 +2021,7 @@ class ClauseList(ClauseElement):
         return f
 
     def self_group(self, against=None):
-        if self.group and self.operator != against and PRECEDENCE.get(self.operator, PRECEDENCE['_smallest']) <= PRECEDENCE.get(against, PRECEDENCE['_largest']):
+        if self.group and self.operator != against and PRECEDENCE.get(self.operator, PRECEDENCE[_smallest]) <= PRECEDENCE.get(against, PRECEDENCE[_largest]):
             return _Grouping(self)
         else:
             return self
@@ -1981,7 +2052,7 @@ class _CalculatedClause(ColumnElement):
     
     def __init__(self, name, *clauses, **kwargs):
         self.name = name
-        self.type = sqltypes.to_instance(kwargs.get('type', None))
+        self.type = sqltypes.to_instance(kwargs.get('type_', None))
         self._bind = kwargs.get('bind', None)
         self.group = kwargs.pop('group', True)
         self.clauses = ClauseList(operator=kwargs.get('operator', None), group_contents=kwargs.get('group_contents', True), *clauses)
@@ -2002,7 +2073,7 @@ class _CalculatedClause(ColumnElement):
         return self.clauses._get_from_objects(**modifiers)
 
     def _bind_param(self, obj):
-        return _BindParamClause(self.name, obj, type=self.type, unique=True)
+        return _BindParamClause(self.name, obj, type_=self.type, unique=True)
 
     def select(self):
         return select([self])
@@ -2024,10 +2095,8 @@ class _Function(_CalculatedClause, FromClause):
     """
 
     def __init__(self, name, *clauses, **kwargs):
-        self.type = sqltypes.to_instance(kwargs.get('type', None))
         self.packagenames = kwargs.get('packagenames', None) or []
-        kwargs['operator'] = ','
-        self._bind = kwargs.get('bind', None)
+        kwargs['operator'] = ColumnOperators.comma_op
         _CalculatedClause.__init__(self, name, **kwargs)
         for c in clauses:
             self.append(c)
@@ -2065,7 +2134,7 @@ class _Cast(ColumnElement):
 
     def _make_proxy(self, selectable, name=None):
         if name is not None:
-            co = _ColumnClause(name, selectable, type=self.type)
+            co = _ColumnClause(name, selectable, type_=self.type)
             co._distance = self._distance + 1
             co.orig_set = self.orig_set
             selectable.columns[name]= co
@@ -2075,12 +2144,12 @@ class _Cast(ColumnElement):
 
 
 class _UnaryExpression(ColumnElement):
-    def __init__(self, element, operator=None, modifier=None, type=None, negate=None):
+    def __init__(self, element, operator=None, modifier=None, type_=None, negate=None):
         self.operator = operator
         self.modifier = modifier
         
         self.element = _literal_as_text(element).self_group(against=self.operator or self.modifier)
-        self.type = sqltypes.to_instance(type)
+        self.type = sqltypes.to_instance(type_)
         self.negate = negate
         
     def _get_from_objects(self, **modifiers):
@@ -2103,12 +2172,12 @@ class _UnaryExpression(ColumnElement):
 
     def _negate(self):
         if self.negate is not None:
-            return _UnaryExpression(self.element, operator=self.negate, negate=self.operator, modifier=self.modifier, type=self.type)
+            return _UnaryExpression(self.element, operator=self.negate, negate=self.operator, modifier=self.modifier, type_=self.type)
         else:
             return super(_UnaryExpression, self)._negate()
     
     def self_group(self, against):
-        if self.operator and PRECEDENCE.get(self.operator, PRECEDENCE['_smallest']) <= PRECEDENCE.get(against, PRECEDENCE['_largest']):
+        if self.operator and PRECEDENCE.get(self.operator, PRECEDENCE[_smallest]) <= PRECEDENCE.get(against, PRECEDENCE[_largest]):
             return _Grouping(self)
         else:
             return self
@@ -2117,11 +2186,11 @@ class _UnaryExpression(ColumnElement):
 class _BinaryExpression(ColumnElement):
     """Represent an expression that is ``LEFT <operator> RIGHT``."""
     
-    def __init__(self, left, right, operator, type=None, negate=None):
+    def __init__(self, left, right, operator, type_=None, negate=None):
         self.left = _literal_as_text(left).self_group(against=operator)
         self.right = _literal_as_text(right).self_group(against=operator)
         self.operator = operator
-        self.type = sqltypes.to_instance(type)
+        self.type = sqltypes.to_instance(type_)
         self.negate = negate
 
     def _get_from_objects(self, **modifiers):
@@ -2142,7 +2211,7 @@ class _BinaryExpression(ColumnElement):
                 (
                     self.left.compare(other.left) and self.right.compare(other.right)
                     or (
-                        self.operator in ['=', '!=', '+', '*'] and
+                        self.operator in [operator.eq, operator.ne, operator.add, operator.mul] and
                         self.left.compare(other.right) and self.right.compare(other.left)
                     )
                 )
@@ -2150,14 +2219,14 @@ class _BinaryExpression(ColumnElement):
         
     def self_group(self, against=None):
         # use small/large defaults for comparison so that unknown operators are always parenthesized
-        if self.operator != against and (PRECEDENCE.get(self.operator, PRECEDENCE['_smallest']) <= PRECEDENCE.get(against, PRECEDENCE['_largest'])):
+        if self.operator != against and (PRECEDENCE.get(self.operator, PRECEDENCE[_smallest]) <= PRECEDENCE.get(against, PRECEDENCE[_largest])):
             return _Grouping(self)
         else:
             return self
     
     def _negate(self):
         if self.negate is not None:
-            return _BinaryExpression(self.left, self.right, self.negate, negate=self.operator, type=self.type)
+            return _BinaryExpression(self.left, self.right, self.negate, negate=self.operator, type_=self.type)
         else:
             return super(_BinaryExpression, self)._negate()
 
@@ -2167,7 +2236,7 @@ class _Exists(_UnaryExpression):
     def __init__(self, *args, **kwargs):
         kwargs['correlate'] = True
         s = select(*args, **kwargs).self_group()
-        _UnaryExpression.__init__(self, s, operator="EXISTS")
+        _UnaryExpression.__init__(self, s, operator=Operators.exists)
 
     def _hide_froms(self, **modifiers):
         return self._get_from_objects(**modifiers)
@@ -2208,7 +2277,7 @@ class Join(FromClause):
                     
         class BinaryVisitor(ClauseVisitor):
             def visit_binary(self, binary):
-                if binary.operator == '=':
+                if binary.operator == operator.eq:
                     add_equiv(binary.left, binary.right)
         BinaryVisitor().traverse(self.onclause)
         
@@ -2290,7 +2359,7 @@ class Join(FromClause):
             equivs = util.Set()
         class LocateEquivs(NoColumnVisitor):
             def visit_binary(self, binary):
-                if binary.operator == '=' and binary.left.name == binary.right.name:
+                if binary.operator == operator.eq and binary.left.name == binary.right.name:
                     equivs.add(binary.right)
                     equivs.add(binary.left)
         LocateEquivs().traverse(self.onclause)
@@ -2463,14 +2532,14 @@ class _Label(ColumnElement):
     
     """
     
-    def __init__(self, name, obj, type=None):
+    def __init__(self, name, obj, type_=None):
         while isinstance(obj, _Label):
             obj = obj.obj
         self.name = name or "{ANON %d %s}" % (id(self), getattr(obj, 'name', 'anon'))
 
-        self.obj = obj.self_group(against='AS')
+        self.obj = obj.self_group(against=Operators.as_)
         self.case_sensitive = getattr(obj, "case_sensitive", True)
-        self.type = sqltypes.to_instance(type or getattr(obj, 'type', None))
+        self.type = sqltypes.to_instance(type_ or getattr(obj, 'type', None))
 
     key = property(lambda s: s.name)
     _label = property(lambda s: s.name)
@@ -2528,11 +2597,11 @@ class _ColumnClause(ColumnElement):
     
     """
 
-    def __init__(self, text, selectable=None, type=None, _is_oid=False, case_sensitive=True, is_literal=False):
+    def __init__(self, text, selectable=None, type_=None, _is_oid=False, case_sensitive=True, is_literal=False):
         self.key = self.name = text
         self.encodedname = isinstance(self.name, unicode) and self.name.encode('ascii', 'backslashreplace') or self.name
         self.table = selectable
-        self.type = sqltypes.to_instance(type)
+        self.type = sqltypes.to_instance(type_)
         self._is_oid = _is_oid
         self._distance = 0
         self.__label = None
@@ -2586,13 +2655,13 @@ class _ColumnClause(ColumnElement):
             return []
 
     def _bind_param(self, obj):
-        return _BindParamClause(self._label, obj, shortname = self.name, type=self.type, unique=True)
+        return _BindParamClause(self._label, obj, shortname=self.name, type_=self.type, unique=True)
 
     def _make_proxy(self, selectable, name = None):
         # propigate the "is_literal" flag only if we are keeping our name,
         # otherwise its considered to be a label
         is_literal = self.is_literal and (name is None or name == self.name)
-        c = _ColumnClause(name or self.name, selectable=selectable, _is_oid=self._is_oid, type=self.type, is_literal=is_literal)
+        c = _ColumnClause(name or self.name, selectable=selectable, _is_oid=self._is_oid, type_=self.type, is_literal=is_literal)
         c.orig_set = self.orig_set
         c._distance = self._distance + 1
         if not self._is_oid:
@@ -3050,7 +3119,7 @@ class Select(_SelectBaseMixin, FromClause):
             column = literal_column(str(column))
 
         if isinstance(column, Select) and column.is_scalar:
-            column = column.self_group(against=',')
+            column = column.self_group(against=ColumnOperators.comma_op)
 
         self._raw_columns.append(column)
 
@@ -3191,7 +3260,7 @@ class _UpdateBase(ClauseElement):
         for key in parameters.keys():
             value = parameters[key]
             if isinstance(value, ClauseElement):
-                pass
+                parameters[key] = value.self_group()
             elif _is_literal(value):
                 if _is_literal(key):
                     col = self.table.c[key]
index ddaf990e7f5f737bc225e863d1891b1b3902533a..6e59ac16e550501d0f8e2ac6e6b6d00dc9f2927f 100644 (file)
@@ -192,7 +192,11 @@ class NullType(TypeEngine):
         return value
 NullTypeEngine = NullType
 
-class String(TypeEngine):
+class Concatenable(object):
+    """marks a type as supporting 'concatenation'"""
+    pass
+    
+class String(TypeEngine, Concatenable):
     def __init__(self, length=None, convert_unicode=False):
         self.length = length
         self.convert_unicode = convert_unicode
index 321493329ad73da6c8b4a6703ee3dfa6449643a8..17b243d256d99e340e164f92cf1124a4bf1731db 100644 (file)
@@ -149,9 +149,10 @@ class BindTest(testbase.PersistTest):
                     assert False
                 except exceptions.InvalidRequestError, e:
                     assert str(e) == "This Compiled object is not bound to any Engine or Connection."
-                
+
         finally:
-            bind.close()
+            if isinstance(bind, engine.Connection):
+                bind.close()
             metadata.drop_all(bind=testbase.db)
     
     def test_session(self):
@@ -165,7 +166,9 @@ class BindTest(testbase.PersistTest):
         mapper(Foo, table)
         metadata.create_all(bind=testbase.db)
         try:
-            for bind in (testbase.db, testbase.db.connect()):
+            for bind in (testbase.db, 
+                testbase.db.connect()
+                ):
                 for args in ({'bind':bind},):
                     sess = create_session(**args)
                     assert sess.bind is bind
@@ -173,6 +176,9 @@ class BindTest(testbase.PersistTest):
                     sess.save(f)
                     sess.flush()
                     assert sess.get(Foo, f.foo) is f
+
+                if isinstance(bind, engine.Connection):
+                    bind.close()
                     
             sess = create_session()
             f = Foo()
@@ -182,9 +188,11 @@ class BindTest(testbase.PersistTest):
                 assert False
             except exceptions.InvalidRequestError, e:
                 assert str(e).startswith("Could not locate any Engine or Connection bound to mapper")
+
                 
         finally:
-            bind.close()
+            if isinstance(bind, engine.Connection):
+                bind.close()
             metadata.drop_all(bind=testbase.db)
         
                
index f393b9f7d015e1e3bde329144b7a1d0f02cd639a..eb4f95619d059da85060aeeae74ec6b83a9d334f 100644 (file)
@@ -116,7 +116,6 @@ class CreateEngineTest(PersistTest):
         except TypeError:
             assert True
             
-        e = create_engine('sqlite://', echo=True)
         e = create_engine('mysql://', module=MockDBAPI(), connect_args={'use_unicode':True}, convert_unicode=True)
         
         e = create_engine('sqlite://', connect_args={'use_unicode':True}, convert_unicode=True)
index 141b3abc27e8df032e1f226be18e3d594be02b94..fe7d77985a9216128f4ff2ba28883eefc5aec8a6 100644 (file)
@@ -737,18 +737,18 @@ class MultiLevelTest(testbase.ORMTest):
     def define_tables(self, metadata):
         global table_Employee, table_Engineer, table_Manager
         table_Employee = Table( 'Employee', metadata,
-            Column( 'name', type= String(100), ),
-            Column( 'id', primary_key= True, type= Integer, ),
-            Column( 'atype', type= String(100), ),
+            Column( 'name', type_= String(100), ),
+            Column( 'id', primary_key= True, type_= Integer, ),
+            Column( 'atype', type_= String(100), ),
         )
 
         table_Engineer = Table( 'Engineer', metadata,
-            Column( 'machine', type= String(100), ),
+            Column( 'machine', type_= String(100), ),
             Column( 'id', Integer, ForeignKey( 'Employee.id', ), primary_key= True, ),
         )
 
         table_Manager = Table( 'Manager', metadata,
-            Column( 'duties', type= String(100), ),
+            Column( 'duties', type_= String(100), ),
             Column( 'id', Integer, ForeignKey( 'Engineer.id', ), primary_key= True, ),
         )
     def test_threelevels(self):
index 026f80808126efbec936eda9d5ed6846a5ff9940..df4187eb4dc0e5e1e35a59ca672df1e1ea2ff57d 100644 (file)
@@ -112,14 +112,14 @@ class OperatorTest(QueryTest):
                                 (operator.sub, '-'), (operator.div, '/'),
                                 ):
             for (lhs, rhs, res) in (
-                ('a', User.id, ':users_id %s users.id'),
-                ('a', literal('b'), ':literal %s :literal_1'),
-                (User.id, 'b', 'users.id %s :users_id'),
+                (5, User.id, ':users_id %s users.id'),
+                (5, literal(6), ':literal %s :literal_1'),
+                (User.id, 5, 'users.id %s :users_id'),
                 (User.id, literal('b'), 'users.id %s :literal'),
                 (User.id, User.id, 'users.id %s users.id'),
-                (literal('a'), 'b', ':literal %s :literal_1'),
-                (literal('a'), User.id, ':literal %s users.id'),
-                (literal('a'), literal('b'), ':literal %s :literal_1'),
+                (literal(5), 'b', ':literal %s :literal_1'),
+                (literal(5), User.id, ':literal %s users.id'),
+                (literal(5), literal(6), ':literal %s :literal_1'),
                 ):
                 self._test(py_op(lhs, rhs), res % sql_op)
 
@@ -503,7 +503,6 @@ class InstancesTest(QueryTest):
         l = q.add_column("count").from_statement(s).all()
         assert l == expected
 
-    @testbase.unsupported('mysql') # only because of "+" operator requiring "concat" in mysql (fix #475)
     def test_two_columns(self):
         sess = create_session()
         (user7, user8, user9, user10) = sess.query(User).all()
index 5a42317d7fa4f94856c506d531b3841369e052d3..bcf8849644f36366f7970d9714a64c8ac8ae556d 100644 (file)
@@ -28,9 +28,9 @@ class CaseTest(testbase.PersistTest):
     def testcase(self):
         inner = select([case([
                [info_table.c.pk < 3, 
-                        literal('lessthan3', type=String)],
+                        literal('lessthan3', type_=String)],
                [and_(info_table.c.pk >= 3, info_table.c.pk < 7), 
-                        literal('gt3', type=String)]]).label('x'),
+                        literal('gt3', type_=String)]]).label('x'),
                info_table.c.pk, info_table.c.info], 
                 from_obj=[info_table]).alias('q_inner')
 
@@ -67,9 +67,9 @@ class CaseTest(testbase.PersistTest):
 
         w_else = select([case([
                [info_table.c.pk < 3, 
-                        literal(3, type=Integer)],
+                        literal(3, type_=Integer)],
                [and_(info_table.c.pk >= 3, info_table.c.pk < 6), 
-                        literal(6, type=Integer)]],
+                        literal(6, type_=Integer)]],
                 else_ = 0).label('x'),
                info_table.c.pk, info_table.c.info], 
                 from_obj=[info_table]).alias('q_inner')
index 07363a402e67546810bc236199e74426f087f89e..a9dd2f5ad29e0a9e42af97e00f341c22530593b0 100644 (file)
@@ -25,7 +25,7 @@ class DefaultTest(PersistTest):
  
         # select "count(1)" returns different results on different DBs
         # also correct for "current_date" compatible as column default, value differences
-        currenttime = func.current_date(type=Date, bind=db);
+        currenttime = func.current_date(type_=Date, bind=db);
         if is_oracle:
             ts = db.func.trunc(func.sysdate(), literal_column("'DAY'")).scalar()
             f = select([func.count(1) + 5], bind=db).scalar()
@@ -230,7 +230,7 @@ class SequenceTest(PersistTest):
         )
         sometable = Table( 'Manager', metadata,
                Column( 'obj_id', Integer, Sequence('obj_id_seq'), ),
-               Column( 'name', type= String, ),
+               Column( 'name', String, ),
                Column( 'id', Integer, primary_key= True, ),
            )
         
index 772ffa793aeb81ae18be6e7374760a7e771b4d36..05b6d0419dbf2120ef56c58179683eb125c54f25 100644 (file)
@@ -281,6 +281,10 @@ class QueryTest(PersistTest):
         y = testbase.db.func.current_date().select().execute().scalar()
         z = testbase.db.func.current_date().scalar()
         assert x == y == z
+        
+        x = testbase.db.func.current_date(type_=Date)
+        assert isinstance(x.type, Date)
+        assert isinstance(x.execute().scalar(), datetime.date)
 
     def test_conn_functions(self):
         conn = testbase.db.connect()
@@ -351,7 +355,7 @@ class QueryTest(PersistTest):
         w = select(['*'], from_obj=[testbase.db.func.current_date()]).scalar()
         
         # construct a column-based FROM object out of a function, like in [ticket:172]
-        s = select([column('date', type=DateTime)], from_obj=[testbase.db.func.current_date()])
+        s = select([column('date', type_=DateTime)], from_obj=[testbase.db.func.current_date()])
         q = s.execute().fetchone()[s.c.date]
         r = s.alias('datequery').select().scalar()
         
index d5b00e1dab9c28d7d5310b38ae0287d11b71a442..3d5996df9daf748e56e8337a686476a89e58ce43 100644 (file)
@@ -11,21 +11,21 @@ import unittest, re, operator
 # so SQLAlchemy's SQL construction engine can be used with no database dependencies at all.
 
 table1 = table('mytable', 
-    column('myid'),
-    column('name'),
-    column('description'),
+    column('myid', Integer),
+    column('name', String),
+    column('description', String),
 )
 
 table2 = table(
     'myothertable', 
-    column('otherid'),
-    column('othername'),
+    column('otherid', Integer),
+    column('othername', String),
 )
 
 table3 = table(
     'thirdtable', 
-    column('userid'),
-    column('otherstuff'),
+    column('userid', Integer),
+    column('otherstuff', String),
 )
 
 metadata = MetaData()
@@ -273,14 +273,14 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
                                 (operator.sub, '-'), (operator.div, '/'),
                                 ):
             for (lhs, rhs, res) in (
-                ('a', table1.c.myid, ':mytable_myid %s mytable.myid'),
-                ('a', literal('b'), ':literal %s :literal_1'),
+                (5, table1.c.myid, ':mytable_myid %s mytable.myid'),
+                (5, literal(5), ':literal %s :literal_1'),
                 (table1.c.myid, 'b', 'mytable.myid %s :mytable_myid'),
-                (table1.c.myid, literal('b'), 'mytable.myid %s :literal'),
+                (table1.c.myid, literal(2.7), 'mytable.myid %s :literal'),
                 (table1.c.myid, table1.c.myid, 'mytable.myid %s mytable.myid'),
-                (literal('a'), 'b', ':literal %s :literal_1'),
-                (literal('a'), table1.c.myid, ':literal %s mytable.myid'),
-                (literal('a'), literal('b'), ':literal %s :literal_1'),
+                (literal(5), 8, ':literal %s :literal_1'),
+                (literal(6), table1.c.myid, ':literal %s mytable.myid'),
+                (literal(7), literal(5.5), ':literal %s :literal_1'),
                 ):
                 self.runtest(py_op(lhs, rhs), res % sql_op)
 
@@ -328,7 +328,7 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
         )
 
         self.runtest(
-         literal("a") + literal("b") * literal("c"), ":literal + :literal_1 * :literal_2"
+         literal("a") + literal("b") * literal("c"), ":literal || :literal_1 * :literal_2"
         )
 
         # test the op() function, also that its results are further usable in expressions
@@ -540,7 +540,7 @@ FROM mytable, myothertable WHERE foo.id = foofoo(lala) AND datetime(foo) = Today
 
     def testliteral(self):
         self.runtest(select([literal("foo") + literal("bar")], from_obj=[table1]), 
-            "SELECT :literal + :literal_1 FROM mytable")
+            "SELECT :literal || :literal_1 FROM mytable")
 
     def testcalculatedcolumns(self):
          value_tbl = table('values',
@@ -866,16 +866,16 @@ myothertable.othername != :myothertable_othername OR EXISTS (select yay from foo
         self.runtest(select([table1], table1.c.myid.in_('a', literal('b'))),
         "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:mytable_myid, :literal)")
 
-        self.runtest(select([table1], table1.c.myid.in_(literal('a') + 'a')),
+        self.runtest(select([table1], table1.c.myid.in_(literal(1) + 'a')),
         "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :literal + :literal_1")
 
         self.runtest(select([table1], table1.c.myid.in_(literal('a') +'a', 'b')),
-        "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:literal + :literal_1, :mytable_myid)")
+        "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:literal || :literal_1, :mytable_myid)")
 
         self.runtest(select([table1], table1.c.myid.in_(literal('a') + literal('a'), literal('b'))),
-        "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:literal + :literal_1, :literal_2)")
+        "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:literal || :literal_1, :literal_2)")
 
-        self.runtest(select([table1], table1.c.myid.in_('a', literal('b') +'b')),
+        self.runtest(select([table1], table1.c.myid.in_(1, literal(3) + 4)),
         "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:mytable_myid, :literal + :literal_1)")
 
         self.runtest(select([table1], table1.c.myid.in_(literal('a') < 'b')),
@@ -893,7 +893,7 @@ myothertable.othername != :myothertable_othername OR EXISTS (select yay from foo
         self.runtest(select([table1], table1.c.myid.in_(literal('a'), table1.c.myid +'a')),
         "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:literal, mytable.myid + :mytable_myid)")
 
-        self.runtest(select([table1], table1.c.myid.in_(literal('a'), 'a' + table1.c.myid)),
+        self.runtest(select([table1], table1.c.myid.in_(literal(1), 'a' + table1.c.myid)),
         "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:literal, :mytable_myid + mytable.myid)")
 
         self.runtest(select([table1], table1.c.myid.in_(1, 2, 3)),
@@ -1040,12 +1040,16 @@ class CRUDTest(SQLTest):
             values = {
             table1.c.name : table1.c.name + "lala",
             table1.c.myid : func.do_stuff(table1.c.myid, literal('hoho'))
-            }), "UPDATE mytable SET myid=do_stuff(mytable.myid, :literal_2), name=mytable.name + :mytable_name WHERE mytable.myid = hoho(:hoho) AND mytable.name = :literal + mytable.name + :literal_1")
+            }), "UPDATE mytable SET myid=do_stuff(mytable.myid, :literal_2), name=(mytable.name || :mytable_name) WHERE mytable.myid = hoho(:hoho) AND mytable.name = :literal || mytable.name || :literal_1")
         
     def testcorrelatedupdate(self):
         # test against a straight text subquery
-        u = update(table1, values = {table1.c.name : text("select name from mytable where id=mytable.id")})
+        u = update(table1, values = {table1.c.name : text("(select name from mytable where id=mytable.id)")})
         self.runtest(u, "UPDATE mytable SET name=(select name from mytable where id=mytable.id)")
+
+        mt = table1.alias()
+        u = update(table1, values = {table1.c.name : select([mt.c.name], mt.c.myid==table1.c.myid)})
+        self.runtest(u, "UPDATE mytable SET name=(SELECT mytable_1.name FROM mytable AS mytable_1 WHERE mytable_1.myid = mytable.myid)")
         
         # test against a regular constructed subquery
         s = select([table2], table2.c.otherid == table1.c.myid)