From 591bba69232275b74f6473bee2209310f8832bc0 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 22 Oct 2005 05:00:30 +0000 Subject: [PATCH] --- lib/sqlalchemy/sql.py | 84 ++++++++++++++++++++++++++----------------- test/testbase.py | 19 +++++++--- 2 files changed, 66 insertions(+), 37 deletions(-) diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 6322b95227..3ab1e5ec29 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -259,8 +259,43 @@ class ClauseElement(object): def result(self, **params): """the same as execute(), except a RowProxy object is returned instead of a DBAPI cursor.""" raise NotImplementedError() + +class CompareMixin(object): + def __lt__(self, other): + return self._compare('<', other) -class ColumnClause(ClauseElement): + def __le__(self, other): + return self._compare('<=', other) + + def __eq__(self, other): + return self._compare('=', other) + + def __ne__(self, other): + return self._compare('!=', other) + + def __gt__(self, other): + return self._compare('>', other) + + def __ge__(self, other): + return self._compare('>=', other) + + def like(self, other): + return self._compare('LIKE', other) + + def in_(self, *other): + if _is_literal(other[0]): + return self._compare('IN', CompoundClause(',', other)) + else: + return self._compare('IN', union(*other)) + + def startswith(self, other): + return self._compare('LIKE', str(other) + "%") + + def endswith(self, other): + return self._compare('LIKE', "%" + str(other)) + + +class ColumnClause(ClauseElement, CompareMixin): """represents a textual column clause in a SQL statement.""" def __init__(self, text, selectable): @@ -284,6 +319,19 @@ class ColumnClause(ClauseElement): def _get_from_objects(self): return [] + def _compare(self, operator, obj): + if _is_literal(obj): + if obj is None: + if operator != '=': + raise "Only '=' operator can be used with NULL" + return BinaryClause(self, null(), 'IS') + elif self.table.name is None: + obj = BindParamClause(self.text, obj, shortname=self.text, type=self.type) + else: + obj = BindParamClause(self.table.name + "_" + self.text, obj, shortname = self.text, type=self.type) + + return BinaryClause(self, obj, operator) + def _make_proxy(self, selectable, name = None): c = ColumnClause(self.text or name, selectable) selectable.columns[c.key] = c @@ -525,7 +573,8 @@ class Alias(Selectable): return select([self], whereclauses, **params) -class ColumnSelectable(Selectable): + +class ColumnSelectable(Selectable, CompareMixin): """Selectable implementation that gets attached to a schema.Column object.""" def __init__(self, column): @@ -562,39 +611,8 @@ class ColumnSelectable(Selectable): return BinaryClause(self.column, obj, operator) - def __lt__(self, other): - return self._compare('<', other) - - def __le__(self, other): - return self._compare('<=', other) - def __eq__(self, other): - return self._compare('=', other) - - def __ne__(self, other): - return self._compare('!=', other) - - def __gt__(self, other): - return self._compare('>', other) - def __ge__(self, other): - return self._compare('>=', other) - - def like(self, other): - return self._compare('LIKE', other) - - def in_(self, *other): - if _is_literal(other[0]): - return self._compare('IN', CompoundClause(',', other)) - else: - return self._compare('IN', union(*other)) - - def startswith(self, other): - return self._compare('LIKE', str(other) + "%") - - def endswith(self, other): - return self._compare('LIKE', "%" + str(other)) - class TableImpl(Selectable): """attached to a schema.Table to provide it with a Selectable interface as well as other functions diff --git a/test/testbase.py b/test/testbase.py index a578521f05..c48c48ef24 100644 --- a/test/testbase.py +++ b/test/testbase.py @@ -71,7 +71,18 @@ class EngineAssert(object): self.unittest.assert_(statement == query and params == parameters, "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters))) return self.realexec(statement, parameters, **kwargs) -def runTests(suite): - runner = unittest.TextTestRunner(verbosity = 2, descriptions =1) - runner.run(suite) - +def runTests(*modules): + for m in modules: + if m.__dict__.has_key('startUp'): + m.startUp() + s = suite(m) + runner = unittest.TextTestRunner(verbosity = 2, descriptions =1) + runner.run(s) + if m.__dict__.has_key('tearDown'): + m.tearDown() + +def suite(modules): + alltests = unittest.TestSuite() + for module in map(__import__, modules): + alltests.addTest(unittest.findTestCases(module)) + return alltests -- 2.47.2