e.g. select([x* 5]) produces "SELECT x * 5 AS anon_1".
This allows the labelname to be present in the cursor.description
which can then be appropriately matched to result-column processing
rules. (we can't reliably use positional tracking for result-column
matches since text() expressions may represent multiple columns).
- operator overloading is now controlled by TypeEngine objects - the
one built-in operator overload so far is String types overloading
'+' to be the string concatenation operator.
User-defined types can also define their own operator overloading
by overriding the adapt_operator(self, op) method.
- untyped bind parameters on the right side of a binary expression
will be assigned the type of the left side of the operation, to better
enable the appropriate bind parameter processing to take effect
[ticket:819]
- Added contains operator (generates a "LIKE %<other>%" clause).
+ - anonymous column expressions are automatically labeled.
+ e.g. select([x* 5]) produces "SELECT x * 5 AS anon_1".
+ This allows the labelname to be present in the cursor.description
+ which can then be appropriately matched to result-column processing
+ rules. (we can't reliably use positional tracking for result-column
+ matches since text() expressions may represent multiple columns).
+
+ - operator overloading is now controlled by TypeEngine objects - the
+ one built-in operator overload so far is String types overloading
+ '+' to be the string concatenation operator.
+ User-defined types can also define their own operator overloading
+ by overriding the adapt_operator(self, op) method.
+
+ - untyped bind parameters on the right side of a binary expression
+ will be assigned the type of the left side of the operation, to better
+ enable the appropriate bind parameter processing to take effect
+ [ticket:819]
+
- Removed regular expression step from most statement compilations.
Also fixes [ticket:833]
return ""
def __visit_label(self, label):
+ # TODO: whats this method for ?
if self.select_stack:
self.typemap.setdefault(label.name.lower(), label.obj.type)
if self.strings[label.obj]:
kwargs['mssql_aliased'] = True
return super(MSSQLCompiler, self).visit_alias(alias, **kwargs)
- def visit_column(self, column):
+ def visit_column(self, column, **kwargs):
if column.table is not None and not self.isupdate and not self.isdelete:
# translate for schema-qualified table aliases
t = self._schema_aliased_table(column.table)
if t is not None:
return self.process(t.corresponding_column(column))
- return super(MSSQLCompiler, self).visit_column(column)
+ return super(MSSQLCompiler, self).visit_column(column, **kwargs)
def visit_binary(self, binary):
"""Move bind parameters to the right-hand side of an operator, where possible."""
def get_col_spec(self):
return "BOOLEAN"
-class PGArray(sqltypes.TypeEngine, sqltypes.Concatenable):
+class PGArray(sqltypes.Concatenable, sqltypes.TypeEngine):
def __init__(self, item_type):
if isinstance(item_type, type):
item_type = item_type()
for i, item in enumerate(metadata):
# sqlite possibly prepending table name to colnames so strip
colname = (item[0].split('.')[-1]).decode(self.dialect.encoding)
+
if self.context.typemap is not None:
type = self.context.typemap.get(colname.lower(), typemap.get(item[1], types.NULLTYPE))
else:
def visit_grouping(self, grouping, **kwargs):
return "(" + self.process(grouping.elem) + ")"
- def visit_label(self, label):
+ def visit_label(self, label, typemap=None, column_labels=None):
labelname = self._truncated_identifier("colident", label.name)
- if len(self.stack) == 1 and self.stack[-1].get('select'):
+ if typemap is not None:
self.typemap.setdefault(labelname.lower(), label.obj.type)
+
+ if column_labels is not None:
if isinstance(label.obj, sql._ColumnClause):
- self.column_labels[label.obj._label] = labelname
- self.column_labels[label.name] = labelname
+ column_labels[label.obj._label] = labelname
+ column_labels[label.name] = labelname
return " ".join([self.process(label.obj), self.operator_string(operators.as_), self.preparer.format_label(label, labelname)])
- def visit_column(self, column, **kwargs):
+ def visit_column(self, column, typemap=None, column_labels=None, **kwargs):
# there is actually somewhat of a ruleset when you would *not* necessarily
# want to truncate a column identifier, if its mapped to the name of a
# physical column. but thats very hard to identify at this point, and
else:
name = column.name
- if len(self.stack) == 1 and self.stack[-1].get('select'):
- # if we are within a visit to a Select, set up the "typemap"
- # for this column which is used to translate result set values
- self.typemap.setdefault(name.lower(), column.type)
+ if typemap is not None:
+ typemap.setdefault(name.lower(), column.type)
+ if column_labels is not None:
self.column_labels.setdefault(column._label, name.lower())
if column._is_oid:
def visit_calculatedclause(self, clause, **kwargs):
return self.process(clause.clause_expr)
- def visit_cast(self, cast, **kwargs):
- if self.stack and self.stack[-1].get('select'):
- # not sure if we want to set the typemap here...
- self.typemap.setdefault("CAST", cast.type)
+ def visit_cast(self, cast, typemap=None, **kwargs):
return "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause))
- def visit_function(self, func, **kwargs):
- if self.stack and self.stack[-1].get('select'):
- self.typemap.setdefault(func.name, func.type)
+ def visit_function(self, func, typemap=None, **kwargs):
+ if typemap is not None:
+ typemap.setdefault(func.name, func.type)
if not self.apply_function_parens(func):
return ".".join(func.packagenames + [func.name])
else:
s = s + " " + self.operator_string(unary.modifier)
return s
- def visit_binary(self, binary, **kwargs):
+ def visit_binary(self, binary, typemap=None, **kwargs):
op = self.operator_string(binary.operator)
if callable(op):
return op(self.process(binary.left), self.process(binary.right))
else:
return self.process(binary.left) + " " + op + " " + self.process(binary.right)
+
+ return ret
def operator_string(self, operator):
return self.operators.get(operator, str(operator))
column.table is not None and \
not isinstance(column.table, sql.Select):
return column.label(column.name)
+ elif not isinstance(column, (sql._UnaryExpression, sql._TextClause)) and not hasattr(column, 'name'):
+ return column.label(None)
else:
return None
if asfrom:
stack_entry['is_selected_from'] = stack_entry['is_subquery'] = True
+ column_clause_args = {}
elif self.stack and self.stack[-1].get('select'):
stack_entry['is_subquery'] = True
-
+ column_clause_args = {}
+ else:
+ column_clause_args = {'typemap':self.typemap, 'column_labels':self.column_labels}
+
if self.stack and self.stack[-1].get('from'):
existingfroms = self.stack[-1]['from']
else:
existingfroms = None
+
froms = select._get_display_froms(existingfroms)
correlate_froms = util.Set()
labelname = co._label
if labelname is not None:
l = co.label(labelname)
- inner_columns.add(self.process(l))
+ inner_columns.add(self.process(l, **column_clause_args))
else:
- inner_columns.add(self.process(co))
+ inner_columns.add(self.process(co, **column_clause_args))
else:
l = self.label_select_column(select, co)
if l is not None:
- inner_columns.add(self.process(l))
+ inner_columns.add(self.process(l, **column_clause_args))
else:
- inner_columns.add(self.process(co))
+ inner_columns.add(self.process(co, **column_clause_args))
collist = string.join(inner_columns.difference(util.Set([None])), ', ')
type_ = self._compare_type(obj)
- # TODO: generalize operator overloading like this out into the
- # types module
- if op == operators.add and isinstance(type_, (sqltypes.Concatenable)):
- op = operators.concat_op
- return _BinaryExpression(self.expression_element(), obj, op, type_=type_)
+ return _BinaryExpression(self.expression_element(), obj, type_.adapt_operator(op), type_=type_)
# a mapping of operators with the method they use, along with their negated
# operator for comparison operators
return self.__compare(operators.like_op, po)
def label(self, name):
- """Produce a column label, i.e. ``<columnname> AS <name>``"""
+ """Produce a column label, i.e. ``<columnname> AS <name>``.
+
+ if 'name' is None, an anonymous label name will be generated.
+ """
return _Label(name, self, self.type)
def desc(self):
return _BindParamClause('literal', obj, type_=self.type, unique=True)
def _check_literal(self, other):
- if isinstance(other, Operators):
+ if isinstance(other, _BindParamClause) and isinstance(other.type, sqltypes.NullType):
+ other.type = self.type
+ return other
+ elif isinstance(other, Operators):
return other.expression_element()
elif _is_literal(other):
return self._bind_param(other)
"""
return None
-
+
+ def adapt_operator(self, op):
+ """given an operator from the sqlalchemy.sql.operators package,
+ translate it to a new operator based on the semantics of this type.
+
+ By default, returns the operator unchanged."""
+ return op
+
def __repr__(self):
return "%s(%s)" % (self.__class__.__name__, ",".join(["%s=%s" % (k, getattr(self, k)) for k in inspect.getargspec(self.__init__)[0][1:]]))
class Concatenable(object):
"""marks a type as supporting 'concatenation'"""
- pass
+ def adapt_operator(self, op):
+ from sqlalchemy.sql import operators
+ if op == operators.add:
+ return operators.concat_op
+ else:
+ return op
-class String(TypeEngine, Concatenable):
+class String(Concatenable, TypeEngine):
def __init__(self, length=None, convert_unicode=False):
self.length = length
self.convert_unicode = convert_unicode
self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]), clone=True), "SELECT * FROM table1 AS t1alias, table2 AS t2alias WHERE t1alias.col1 = t2alias.col2")
self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t1), clone=True), "SELECT * FROM table2 AS t2alias WHERE t1alias.col1 = t2alias.col2")
self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t2), clone=True), "SELECT * FROM table1 AS t1alias WHERE t1alias.col1 = t2alias.col2")
+
def test_joins(self):
"""test that ClauseAdapter can target a Join object, replace it, and not dig into the sub-joins after
def test_scalar_select(self):
s = select([table1.c.myid], scalar=True, correlate=False)
- self.assert_compile(select([table1, s]), "SELECT mytable.myid, mytable.name, mytable.description, (SELECT mytable.myid FROM mytable) FROM mytable")
+ self.assert_compile(select([table1, s]), "SELECT mytable.myid, mytable.name, mytable.description, (SELECT mytable.myid FROM mytable) AS anon_1 FROM mytable")
s = select([table1.c.myid], scalar=True)
- self.assert_compile(select([table2, s]), "SELECT myothertable.otherid, myothertable.othername, (SELECT mytable.myid FROM mytable) FROM myothertable")
+ self.assert_compile(select([table2, s]), "SELECT myothertable.otherid, myothertable.othername, (SELECT mytable.myid FROM mytable) AS anon_1 FROM myothertable")
s = select([table1.c.myid]).correlate(None).as_scalar()
- self.assert_compile(select([table1, s]), "SELECT mytable.myid, mytable.name, mytable.description, (SELECT mytable.myid FROM mytable) FROM mytable")
+ self.assert_compile(select([table1, s]), "SELECT mytable.myid, mytable.name, mytable.description, (SELECT mytable.myid FROM mytable) AS anon_1 FROM mytable")
s = select([table1.c.myid]).as_scalar()
- self.assert_compile(select([table2, s]), "SELECT myothertable.otherid, myothertable.othername, (SELECT mytable.myid FROM mytable) FROM myothertable")
+ self.assert_compile(select([table2, s]), "SELECT myothertable.otherid, myothertable.othername, (SELECT mytable.myid FROM mytable) AS anon_1 FROM myothertable")
# test expressions against scalar selects
- self.assert_compile(select([s - literal(8)]), "SELECT (SELECT mytable.myid FROM mytable) - :literal")
- self.assert_compile(select([select([table1.c.name]).as_scalar() + literal('x')]), "SELECT (SELECT mytable.name FROM mytable) || :literal")
- self.assert_compile(select([s > literal(8)]), "SELECT (SELECT mytable.myid FROM mytable) > :literal")
+ self.assert_compile(select([s - literal(8)]), "SELECT (SELECT mytable.myid FROM mytable) - :literal AS anon_1")
+ self.assert_compile(select([select([table1.c.name]).as_scalar() + literal('x')]), "SELECT (SELECT mytable.name FROM mytable) || :literal AS anon_1")
+ self.assert_compile(select([s > literal(8)]), "SELECT (SELECT mytable.myid FROM mytable) > :literal AS anon_1")
self.assert_compile(select([select([table1.c.name]).label('foo')]), "SELECT (SELECT mytable.name FROM mytable) AS foo")
s1 = select([a1.c.otherid], table1.c.myid==a1.c.otherid, scalar=True)
j1 = table1.join(table2, table1.c.myid==table2.c.otherid)
s2 = select([table1, s1], from_obj=[j1])
- self.assert_compile(s2, "SELECT mytable.myid, mytable.name, mytable.description, (SELECT t2alias.otherid FROM myothertable AS t2alias WHERE mytable.myid = t2alias.otherid) FROM mytable JOIN myothertable ON mytable.myid = myothertable.otherid")
+ self.assert_compile(s2, "SELECT mytable.myid, mytable.name, mytable.description, (SELECT t2alias.otherid FROM myothertable AS t2alias WHERE mytable.myid = t2alias.otherid) AS anon_1 FROM mytable JOIN myothertable ON mytable.myid = myothertable.otherid")
def testlabelcomparison(self):
x = func.lala(table1.c.myid).label('foo')
def testliteral(self):
self.assert_compile(select([literal("foo") + literal("bar")], from_obj=[table1]),
- "SELECT :literal || :literal_1 FROM mytable")
+ "SELECT :literal || :literal_1 AS anon_1 FROM mytable")
def testcalculatedcolumns(self):
value_tbl = table('values',
self.assert_compile(
select([value_tbl.c.id, (value_tbl.c.val2 -
value_tbl.c.val1)/value_tbl.c.val1]),
- "SELECT values.id, (values.val2 - values.val1) / values.val1 FROM values"
+ "SELECT values.id, (values.val2 - values.val1) / values.val1 AS anon_1 FROM values"
)
self.assert_compile(
# coverage on other dialects.
sel = select([tbl, cast(tbl.c.v1, Numeric)]).compile(dialect=dialect)
if isinstance(dialect, type(mysql.dialect())):
- self.assertEqual(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS DECIMAL(10, 2)) \nFROM casttest")
+ self.assertEqual(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS DECIMAL(10, 2)) AS anon_1 \nFROM casttest")
else:
- self.assertEqual(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS NUMERIC(10, 2)) \nFROM casttest")
+ self.assertEqual(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS NUMERIC(10, 2)) AS anon_1 \nFROM casttest")
# first test with Postgres engine
check_results(postgres.dialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '%(literal)s')
import datetime, os
from sqlalchemy import *
from sqlalchemy import types
+from sqlalchemy.sql import operators
import sqlalchemy.engine.url as url
from sqlalchemy.databases import mssql, oracle, mysql, postgres, firebird
from testlib import *
# put a number less than the typical MySQL default BLOB size
return file(f).read(len)
+class ExpressionTest(AssertMixin):
+ def setUpAll(self):
+ global test_table, meta
+
+ class MyCustomType(types.TypeEngine):
+ def get_col_spec(self):
+ return "INT"
+ def bind_processor(self, dialect):
+ def process(value):
+ return value * 10
+ return process
+ def result_processor(self, dialect):
+ def process(value):
+ return value / 10
+ return process
+ def adapt_operator(self, op):
+ return {operators.add:operators.sub, operators.sub:operators.add}.get(op, op)
+
+ meta = MetaData(testbase.db)
+ test_table = Table('test', meta,
+ Column('id', Integer, primary_key=True),
+ Column('data', String(30)),
+ Column('timestamp', Date),
+ Column('value', MyCustomType))
+
+ meta.create_all()
+
+ test_table.insert().execute({'id':1, 'data':'somedata', 'timestamp':datetime.date(2007, 10, 15), 'value':25})
+
+ def tearDownAll(self):
+ meta.drop_all()
+
+ def test_control(self):
+ assert testbase.db.execute("select value from test").scalar() == 250
+
+ assert test_table.select().execute().fetchall() == [(1, 'somedata', datetime.date(2007, 10, 15), 25)]
+
+ def test_bind_adapt(self):
+ expr = test_table.c.timestamp == bindparam("thedate")
+ assert expr.right.type.__class__ == test_table.c.timestamp.type.__class__
+
+ assert testbase.db.execute(test_table.select().where(expr), {"thedate":datetime.date(2007, 10, 15)}).fetchall() == [(1, 'somedata', datetime.date(2007, 10, 15), 25)]
+
+ expr = test_table.c.value == bindparam("somevalue")
+ assert expr.right.type.__class__ == test_table.c.value.type.__class__
+ assert testbase.db.execute(test_table.select().where(expr), {"somevalue":25}).fetchall() == [(1, 'somedata', datetime.date(2007, 10, 15), 25)]
+
+
+ def test_operator_adapt(self):
+ """test type-based overloading of operators"""
+
+ # test string concatenation
+ expr = test_table.c.data + "somedata"
+ assert testbase.db.execute(select([expr])).scalar() == "somedatasomedata"
+ expr = test_table.c.id + 15
+ assert testbase.db.execute(select([expr])).scalar() == 16
+
+ # test custom operator conversion
+ expr = test_table.c.value + 40
+ assert expr.type.__class__ is test_table.c.value.type.__class__
+
+ # + operator converted to -
+ # value is calculated as: (250 - (40 * 10)) / 10 == -15
+ assert testbase.db.execute(select([expr.label('foo')])).scalar() == -15
+
+ # this one relies upon anonymous labeling to assemble result
+ # processing rules on the column.
+ assert testbase.db.execute(select([expr])).scalar() == -15
+
class DateTest(AssertMixin):
def setUpAll(self):
global users_with_date, insert_data