From: Mike Bayer Date: Thu, 31 Jan 2008 03:57:20 +0000 (+0000) Subject: - the startswith(), endswith(), and contains() operators X-Git-Tag: rel_0_4_3~59 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=a5b23bda66bc5ee52efeefd58b6e5e69c8f8d330;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - the startswith(), endswith(), and contains() operators now concatenate the wildcard operator with the given operand in SQL, i.e. "'%' || " in all cases, accept text('something') operands properly [ticket:962] - cast() accepts text('something') and other non-literal operands properly [ticket:962] --- diff --git a/CHANGES b/CHANGES index 01f04acd75..97f480eb9f 100644 --- a/CHANGES +++ b/CHANGES @@ -9,6 +9,14 @@ CHANGES to ILIKE on postgres, lower(x) LIKE lower(y) on all others. [ticket:727] + - the startswith(), endswith(), and contains() operators + now concatenate the wildcard operator with the given + operand in SQL, i.e. "'%' || " in all cases, + accept text('something') operands properly [ticket:962] + + - cast() accepts text('something') and other non-literal + operands properly [ticket:962] + - The '.c.' attribute on a selectable now gets an entry for every column expression in its columns clause. Previously, "unnamed" columns like functions and CASE diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index aff8654f25..6c0c4659ec 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -640,21 +640,26 @@ def column(text, type_=None): return _ColumnClause(text, type_=type_) def literal_column(text, type_=None): - """Return a textual column clause, as would be in the columns clause of a ``SELECT`` statement. + """Return a textual column expression, as would be in the columns + clause of a ``SELECT`` statement. - The object returned is an instance of [sqlalchemy.sql.expression#_ColumnClause], - which represents the "syntactical" portion of the schema-level - [sqlalchemy.schema#Column] object. + The object returned supports further expressions in the same way + as any other column object, including comparison, math and string + operations. The type_ parameter is important to determine proper + expression behavior (such as, '+' means string concatenation or + numerical addition based on the type). text - the name of the column. Quoting rules will not be applied to - the column. For textual column constructs that should be quoted - like any other column construct, use the - [sqlalchemy.sql.expression#column()] function. + the text of the expression; can be any SQL expression. Quoting rules + will not be applied. To specify a column-name expression which should + be subject to quoting rules, use the [sqlalchemy.sql.expression#column()] + function. - type + type_ an optional [sqlalchemy.types#TypeEngine] object which will - provide result-set translation for this column. + provide result-set translation and additional expression + semantics for this column. If left as None the type will be + NullType. """ return _ColumnClause(text, type_=type_, is_literal=True) @@ -1173,7 +1178,7 @@ class ColumnOperators(Operators): class _CompareMixin(ColumnOperators): """Defines comparison and math operations for ``ClauseElement`` instances.""" - def __compare(self, op, obj, negate=None): + def __compare(self, op, obj, negate=None, reverse=False): if obj is None or isinstance(obj, _Null): if op == operators.eq: return _BinaryExpression(self.expression_element(), null(), operators.is_, negate=operators.isnot) @@ -1183,14 +1188,21 @@ class _CompareMixin(ColumnOperators): raise exceptions.ArgumentError("Only '='/'!=' operators can be used with NULL") else: obj = self._check_literal(obj) - return _BinaryExpression(self.expression_element(), obj, op, type_=sqltypes.Boolean, negate=negate) + + if reverse: + return _BinaryExpression(obj, self.expression_element(), op, type_=sqltypes.Boolean, negate=negate) + else: + return _BinaryExpression(self.expression_element(), obj, op, type_=sqltypes.Boolean, negate=negate) - def __operate(self, op, obj): + def __operate(self, op, obj, reverse=False): obj = self._check_literal(obj) type_ = self._compare_type(obj) - return _BinaryExpression(self.expression_element(), obj, type_.adapt_operator(op), type_=type_) + if reverse: + return _BinaryExpression(obj, self.expression_element(), type_.adapt_operator(op), type_=type_) + else: + return _BinaryExpression(self.expression_element(), obj, type_.adapt_operator(op), type_=type_) # a mapping of operators with the method they use, along with their negated # operator for comparison operators @@ -1216,7 +1228,8 @@ class _CompareMixin(ColumnOperators): return o[0](self, op, other[0], *o[1:]) def reverse_operate(self, op, other): - return self._bind_param(other).operate(op, self) + o = _CompareMixin.operators[op] + return o[0](self, op, other, reverse=True, *o[1:]) def in_(self, *other): return self._in_impl(operators.in_op, operators.notin_op, *other) @@ -1251,29 +1264,18 @@ class _CompareMixin(ColumnOperators): def startswith(self, other): """Produce the clause ``LIKE '%'``""" - perc = isinstance(other, basestring) and '%' or literal('%', type_=sqltypes.String) - return self.__compare(operators.like_op, other + perc) + # use __radd__ to force string concat behavior + return self.__compare(operators.like_op, literal_column("'%'", type_=sqltypes.String).__radd__(self._check_literal(other))) def endswith(self, other): """Produce the clause ``LIKE '%'``""" - if isinstance(other, basestring): - po = '%' + other - else: - po = literal('%', type_=sqltypes.String) + other - po.type = sqltypes.to_instance(sqltypes.String) #force! - return self.__compare(operators.like_op, po) + return self.__compare(operators.like_op, literal_column("'%'", type_=sqltypes.String) + self._check_literal(other)) def contains(self, other): """Produce the clause ``LIKE '%%'``""" - if isinstance(other, basestring): - po = '%' + other + '%' - else: - perc = literal('%', type_=sqltypes.String) - po = perc + other + perc - po.type = sqltypes.to_instance(sqltypes.String) #force! - return self.__compare(operators.like_op, po) + return self.__compare(operators.like_op, literal_column("'%'", type_=sqltypes.String) + self._check_literal(other) + literal_column("'%'", type_=sqltypes.String)) def label(self, name): """Produce a column label, i.e. `` AS ``. @@ -2030,10 +2032,8 @@ class _Cast(ColumnElement): def __init__(self, clause, totype, **kwargs): ColumnElement.__init__(self) - if not hasattr(clause, 'label'): - clause = literal(clause) self.type = sqltypes.to_instance(totype) - self.clause = clause + self.clause = _literal_as_binds(clause, None) self.typeclause = _TypeClause(self.type) def _copy_internals(self, clone=_clone): diff --git a/test/sql/select.py b/test/sql/select.py index c34cec7c51..522f9a2ffd 100644 --- a/test/sql/select.py +++ b/test/sql/select.py @@ -451,29 +451,33 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A clause = (table1.c.myid == 12) & table1.c.myid.between(15, 20) & table1.c.myid.like('hoho') assert str(clause) == str(util.pickle.loads(util.pickle.dumps(clause))) - - - def testextracomparisonoperators(self): + def test_composed_string_comparators(self): self.assert_compile( - table1.select(table1.c.name.contains('jo')), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.name LIKE :mytable_name_1", - checkparams = {'mytable_name_1': u'%jo%'}, + table1.c.name.contains('jo'), "mytable.name LIKE '%' || :mytable_name_1 || '%'" , checkparams = {'mytable_name_1': u'jo'}, ) self.assert_compile( - table1.select(table1.c.name.endswith('hn')), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.name LIKE :mytable_name_1", - checkparams = {'mytable_name_1': u'%hn'}, + table1.c.name.contains('jo'), "mytable.name LIKE concat(concat('%', %s), '%')" , checkparams = {'mytable_name_1': u'jo'}, + dialect=mysql.dialect() + ) + self.assert_compile( + table1.c.name.endswith('hn'), "mytable.name LIKE '%' || :mytable_name_1", checkparams = {'mytable_name_1': u'hn'}, + ) + self.assert_compile( + table1.c.name.endswith('hn'), "mytable.name LIKE concat('%', %s)", + checkparams = {'mytable_name_1': u'hn'}, dialect=mysql.dialect() ) - - def testunicodestartswith(self): - string = u"hi \xf6 \xf5" self.assert_compile( - table1.select(table1.c.name.startswith(string)), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.name LIKE :mytable_name_1", - checkparams = {'mytable_name_1': u'hi \xf6 \xf5%'}, + table1.c.name.startswith(u"hi \xf6 \xf5"), "mytable.name LIKE :mytable_name_1 || '%'", + checkparams = {'mytable_name_1': u'hi \xf6 \xf5'}, ) + self.assert_compile(column('name').endswith(text("'foo'")), "name LIKE '%' || 'foo'" ) + self.assert_compile(column('name').endswith(literal_column("'foo'")), "name LIKE '%' || 'foo'" ) + self.assert_compile(column('name').startswith(text("'foo'")), "name LIKE 'foo' || '%'" ) + self.assert_compile(column('name').startswith(text("'foo'")), "name LIKE concat('foo', '%')", dialect=mysql.dialect()) + self.assert_compile(column('name').startswith(literal_column("'foo'")), "name LIKE 'foo' || '%'" ) + self.assert_compile(column('name').startswith(literal_column("'foo'")), "name LIKE concat('foo', '%')", dialect=mysql.dialect()) - def testmultiparam(self): + def test_multiple_col_binds(self): self.assert_compile( select(["*"], or_(table1.c.myid == 12, table1.c.myid=='asdf', table1.c.myid == 'foo')), "SELECT * FROM mytable WHERE mytable.myid = :mytable_myid_1 OR mytable.myid = :mytable_myid_2 OR mytable.myid = :mytable_myid_3" @@ -1067,14 +1071,14 @@ EXISTS (select yay from foo where boo = lar)", assert str(err) == "Bind parameter 'mytable_myid_1' conflicts with unique bind parameter of the same name" - def testbindascol(self): + def test_bind_as_col(self): t = table('foo', column('id')) s = select([t, literal('lala').label('hoho')]) self.assert_compile(s, "SELECT foo.id, :param_1 AS hoho FROM foo") assert [str(c) for c in s.c] == ["id", "hoho"] - def testin(self): + def test_in(self): self.assert_compile(select([table1], table1.c.myid.in_(['a'])), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:mytable_myid_1)") @@ -1179,7 +1183,7 @@ UNION SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE self.assert_compile(select([table1], table1.c.myid.in_()), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE (CASE WHEN (mytable.myid IS NULL) THEN NULL ELSE 0 END = 1)") - def testcast(self): + def test_cast(self): tbl = table('casttest', column('id', Integer), column('v1', Float), @@ -1215,7 +1219,11 @@ UNION SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE # then the MySQL engine check_results(mysql.dialect(), ['DECIMAL(10, 2)', 'DECIMAL(12, 9)', 'DATE', 'CHAR', 'CHAR(20)'], '%s') - def testdatebetween(self): + self.assert_compile(cast(text('NULL'), Integer), "CAST(NULL AS INTEGER)", dialect=sqlite.dialect()) + self.assert_compile(cast(null(), Integer), "CAST(NULL AS INTEGER)", dialect=sqlite.dialect()) + self.assert_compile(cast(literal_column('NULL'), Integer), "CAST(NULL AS INTEGER)", dialect=sqlite.dialect()) + + def test_date_between(self): import datetime table = Table('dt', metadata, Column('date', Date))