From: Robert Leftwich Date: Tue, 4 Apr 2006 00:28:33 +0000 (+0000) Subject: Added cast() to allow use of cast(tbl.c.col as Numeric(4,2)) in select and where... X-Git-Tag: rel_0_1_6~19 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=1c8f771344620e18e541e8196e841c4112068825;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Added cast() to allow use of cast(tbl.c.col as Numeric(4,2)) in select and where clauses. Unit tests for same. --- diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 816fac3781..532b060a71 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -13,7 +13,7 @@ from exceptions import * import string, re, random types = __import__('types') -__all__ = ['text', 'table', 'column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'union', 'union_all', 'null', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'literal', 'bindparam', 'exists'] +__all__ = ['text', 'table', 'column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'between_', 'cast', 'union', 'union_all', 'null', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'literal', 'bindparam', 'exists'] def desc(column): """returns a descending ORDER BY clause element, e.g.: @@ -132,6 +132,25 @@ def between_(ctest, cleft, cright): """ returns BETWEEN predicate clause (clausetest BETWEEN clauseleft AND clauseright) """ return BooleanExpression(ctest, and_(cleft, cright), 'BETWEEN') +def cast(clause, totype, **kwargs): + """ returns CAST function CAST(clause AS totype) + Use with a sqlalchemy.types.TypeEngine object, i.e + cast(table.c.unit_price * table.c.qty, Numeric(10,4)) + or + cast(table.c.timestamp, DATE) + """ + engine = kwargs.get('engine', None) + if engine is None: + engine = getattr(clause, 'engine', None) + if engine is not None: + totype_desc = engine.type_descriptor(totype) + # handle non-column clauses (e.g. cast(1234, TEXT) + if not hasattr(clause, 'label'): + clause = literal(clause) + return Function('CAST', clause.label(totype_desc.get_col_spec()), type=totype, **kwargs) + else: + raise InvalidRequestError("No engine available, cannot generate cast for " + str(clause) + " to type " + str(totype)) + def exists(*args, **params): params['correlate'] = True s = select(*args, **params) diff --git a/test/select.py b/test/select.py index 0fc406e4a6..adaf65ec1a 100644 --- a/test/select.py +++ b/test/select.py @@ -4,6 +4,7 @@ import sqlalchemy.ansisql as ansisql import sqlalchemy.databases.postgres as postgres import sqlalchemy.databases.oracle as oracle import sqlalchemy.databases.sqlite as sqlite +import sqlalchemy.databases.mysql as mysql db = ansisql.engine() #db = create_engine('mssql') @@ -532,6 +533,34 @@ FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid AND mytable self.runtest(table1.select(table1.c.name=='jack'), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :mytable_myid AND mytable.name = :mytable_name", params={'myid':'3', 'name':'fred'}) + def testcast(self): + tbl = table('casttest', + Column('id', Integer), + Column('v1', Float), + Column('v2', Float), + Column('ts', TIMESTAMP), + ) + + def check_results(engine, expected_results, literal): + self.assertEqual(len(expected_results), 5, 'Incorrect number of expected results') + self.assertEqual(str(cast(tbl.c.v1, Numeric, engine=engine)), 'CAST(casttest.v1 AS %s)' %expected_results[0]) + self.assertEqual(str(cast(tbl.c.v1, Numeric(12, 9), engine=engine)), 'CAST(casttest.v1 AS %s)' %expected_results[1]) + self.assertEqual(str(cast(tbl.c.ts, Date, engine=engine)), 'CAST(casttest.ts AS %s)' %expected_results[2]) + self.assertEqual(str(cast(1234, TEXT, engine=engine)), 'CAST(%s AS %s)' %(literal, expected_results[3])) + self.assertEqual(str(cast('test', String(20), engine=engine)), 'CAST(%s AS %s)' %(literal, expected_results[4])) + + # first test with Postgres engine + check_results(postgres.engine({}), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '%(literal)s') + + # then the Oracle engine + check_results(oracle.engine({}, use_ansi = False), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'CLOB', 'VARCHAR(20)'], ':literal') + + # then the sqlite engine + check_results(sqlite.engine({}), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '?') + + # and the MySQL engine + check_results(mysql.engine({}), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '%s') + class CRUDTest(SQLTest): def testinsert(self): # generic insert, will create bind params for all columns