]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Added cast() to allow use of cast(tbl.c.col as Numeric(4,2)) in select and where...
authorRobert Leftwich <rtl@pobox.com>
Tue, 4 Apr 2006 00:28:33 +0000 (00:28 +0000)
committerRobert Leftwich <rtl@pobox.com>
Tue, 4 Apr 2006 00:28:33 +0000 (00:28 +0000)
lib/sqlalchemy/sql.py
test/select.py

index 816fac37818b24139dc24752b7f59262e616a89f..532b060a71108e06bfc099d65d21032a80199211 100644 (file)
@@ -13,7 +13,7 @@ from exceptions import *
 import string, re, random
 types = __import__('types')
 
-__all__ = ['text', 'table', 'column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'union', 'union_all', 'null', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'literal', 'bindparam', 'exists']
+__all__ = ['text', 'table', 'column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'between_', 'cast', 'union', 'union_all', 'null', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'literal', 'bindparam', 'exists']
 
 def desc(column):
     """returns a descending ORDER BY clause element, e.g.:
@@ -132,6 +132,25 @@ def between_(ctest, cleft, cright):
     """ returns BETWEEN predicate clause (clausetest BETWEEN clauseleft AND clauseright) """
     return BooleanExpression(ctest, and_(cleft, cright), 'BETWEEN')
         
+def cast(clause, totype, **kwargs):
+    """ returns CAST function CAST(clause AS totype) 
+        Use with a sqlalchemy.types.TypeEngine object, i.e
+        cast(table.c.unit_price * table.c.qty, Numeric(10,4))
+         or
+        cast(table.c.timestamp, DATE)
+    """
+    engine = kwargs.get('engine', None)
+    if engine is None:
+        engine = getattr(clause, 'engine', None)
+    if engine is not None:
+        totype_desc = engine.type_descriptor(totype)
+        # handle non-column clauses (e.g. cast(1234, TEXT)
+        if not hasattr(clause, 'label'):
+            clause = literal(clause)
+        return Function('CAST', clause.label(totype_desc.get_col_spec()), type=totype, **kwargs)
+    else:
+        raise InvalidRequestError("No engine available, cannot generate cast for " + str(clause) + " to type " + str(totype))
+        
 def exists(*args, **params):
     params['correlate'] = True
     s = select(*args, **params)
index 0fc406e4a64b13b5d03dfbb392ee9a641b22bed4..adaf65ec1a8ff51708de959e69d2afad1a2e62d2 100644 (file)
@@ -4,6 +4,7 @@ import sqlalchemy.ansisql as ansisql
 import sqlalchemy.databases.postgres as postgres
 import sqlalchemy.databases.oracle as oracle
 import sqlalchemy.databases.sqlite as sqlite
+import sqlalchemy.databases.mysql as mysql
 
 db = ansisql.engine()
 #db = create_engine('mssql')
@@ -532,6 +533,34 @@ FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid AND mytable
 
         self.runtest(table1.select(table1.c.name=='jack'), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :mytable_myid AND mytable.name = :mytable_name", params={'myid':'3', 'name':'fred'})
         
+    def testcast(self):
+        tbl = table('casttest',
+                    Column('id', Integer),
+                    Column('v1', Float),
+                    Column('v2', Float),
+                    Column('ts', TIMESTAMP),
+                    )
+        
+        def check_results(engine, expected_results, literal):
+            self.assertEqual(len(expected_results), 5, 'Incorrect number of expected results')
+            self.assertEqual(str(cast(tbl.c.v1, Numeric, engine=engine)), 'CAST(casttest.v1 AS %s)' %expected_results[0])
+            self.assertEqual(str(cast(tbl.c.v1, Numeric(12, 9), engine=engine)), 'CAST(casttest.v1 AS %s)' %expected_results[1])
+            self.assertEqual(str(cast(tbl.c.ts, Date, engine=engine)), 'CAST(casttest.ts AS %s)' %expected_results[2])
+            self.assertEqual(str(cast(1234, TEXT, engine=engine)), 'CAST(%s AS %s)' %(literal, expected_results[3]))
+            self.assertEqual(str(cast('test', String(20), engine=engine)), 'CAST(%s AS %s)' %(literal, expected_results[4]))
+            
+        # first test with Postgres engine
+        check_results(postgres.engine({}), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '%(literal)s')
+
+        # then the Oracle engine
+        check_results(oracle.engine({}, use_ansi = False), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'CLOB', 'VARCHAR(20)'], ':literal')
+
+        # then the sqlite engine
+        check_results(sqlite.engine({}), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '?')
+
+        # and the MySQL engine
+        check_results(mysql.engine({}), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '%s')
+
 class CRUDTest(SQLTest):
     def testinsert(self):
         # generic insert, will create bind params for all columns