]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
cast converted into its own ClauseElement so that it can have an explicit compilation rel_0_2_3
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 17 Jun 2006 00:53:33 +0000 (00:53 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 17 Jun 2006 00:53:33 +0000 (00:53 +0000)
function in ANSICompiler
MySQLCompiler then skips most CAST calls since it only seems to support the standard syntax for Date
types; other types now a TODO for MySQL
then, polymorphic_union() function now CASTs null()s to the type corresponding to the columns in the UNION,
since postgres doesnt like mixing NULL with integer types
(long road for that .....)

CHANGES
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/sql.py
test/sql/select.py

diff --git a/CHANGES b/CHANGES
index 9b80778efacfb9971bafd01e85683f431bafc92c..5bcbda4d737cb209e9cf13b7e0886601efeea586 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -40,6 +40,11 @@ function at the moment
 information better [ticket:202]
 - if an object fails to be constructed, is not added to the 
 session [ticket:203]
+- CAST function has been made into its own clause object with
+its own compilation function in ansicompiler; allows MySQL
+to silently ignore most CAST calls since MySQL
+seems to only support the standard CAST syntax with Date types.  
+MySQL-compatible CAST support for strings, ints, etc. a TODO
 
 0.2.2
 - big improvements to polymorphic inheritance behavior, enabling it
index cdd86044021448a492a8afc8e01799528bbb4b41..82abb95775ebf285b4d001bb13dbd062290cec95 100644 (file)
@@ -231,7 +231,13 @@ class ANSICompiler(sql.Compiled):
             self.strings[list] = "(" + string.join([self.get_str(c) for c in list.clauses], ' ') + ")"
         else:
             self.strings[list] = string.join([self.get_str(c) for c in list.clauses], ' ')
-       
+      
+    def visit_cast(self, cast):
+        if len(self.select_stack):
+            # not sure if we want to set the typemap here...
+            self.typemap.setdefault("CAST", cast.type)
+        self.strings[cast] = "CAST(%s AS %s)" % (self.strings[cast.clause],self.strings[cast.typeclause])
+         
     def visit_function(self, func):
         if len(self.select_stack):
             self.typemap.setdefault(func.name, func.type)
index aa05134d0416dcbad644daad820709eb038b42da..e32d6f120a470bbf2096af26fed35656ee0edde0 100644 (file)
@@ -9,7 +9,6 @@ import sys, StringIO, string, types, re, datetime
 from sqlalchemy import sql,engine,schema,ansisql
 from sqlalchemy.engine import default
 import sqlalchemy.types as sqltypes
-import sqlalchemy.databases.information_schema as ischema
 import sqlalchemy.exceptions as exceptions
 
 try:
@@ -250,6 +249,15 @@ class MySQLDialect(ansisql.ANSIDialect):
 
 class MySQLCompiler(ansisql.ANSICompiler):
 
+    def visit_cast(self, cast):
+        """hey ho MySQL supports almost no types at all for CAST"""
+        if (isinstance(cast.type, sqltypes.Date) or isinstance(cast.type, sqltypes.Time) or isinstance(cast.type, sqltypes.DateTime)):
+            return super(MySQLCompiler, self).visit_cast(cast)
+        else:
+            # so just skip the CAST altogether for now.
+            # TODO: put whatever MySQL does for CAST here.
+            self.strings[cast] = self.strings[cast.clause]
+
     def limit_clause(self, select):
         text = ""
         if select.limit is not None:
index 86799b31166e25d67dab1a4091cc14ee9698b5f5..10f7b4e808d542aa2cd7ea0e6e2ee70ccd23771a 100644 (file)
@@ -24,7 +24,7 @@ class CascadeOptions(object):
 def polymorphic_union(table_map, typecolname, aliasname='p_union'):
     colnames = util.Set()
     colnamemaps = {}
-    
+    types = {}
     for key in table_map.keys():
         table = table_map[key]
 
@@ -37,13 +37,14 @@ def polymorphic_union(table_map, typecolname, aliasname='p_union'):
         for c in table.c:
             colnames.add(c.name)
             m[c.name] = c
+            types[c.name] = c.type
         colnamemaps[table] = m
         
     def col(name, table):
         try:
             return colnamemaps[table][name]
         except KeyError:
-            return sql.null().label(name)
+            return sql.cast(sql.null(), types[name]).label(name)
 
     result = []
     for type, table in table_map.iteritems():
index 0cacea12d0cf115af8d0ff1106a0c33b7cbfa6fd..d978ee208ee8156c9954e17217a9894dc0160617 100644 (file)
@@ -153,12 +153,9 @@ def cast(clause, totype, **kwargs):
          or
         cast(table.c.timestamp, DATE)
     """
-    # handle non-column clauses (e.g. cast(1234, TEXT)
-    if not hasattr(clause, 'label'):
-        clause = literal(clause)
-    totype = sqltypes.to_instance(totype)
-    return Function('CAST', CompoundClause("AS", clause, TypeClause(totype)), type=totype, **kwargs)
-        
+    return Cast(clause, totype, **kwargs)
+
+
 def exists(*args, **params):
     params['correlate'] = True
     s = select(*args, **params)
@@ -320,6 +317,7 @@ class ClauseVisitor(object):
     def visit_clauselist(self, list):pass
     def visit_calculatedclause(self, calcclause):pass
     def visit_function(self, func):pass
+    def visit_cast(self, cast):pass
     def visit_label(self, label):pass
     def visit_typeclause(self, typeclause):pass
             
@@ -974,7 +972,20 @@ class Function(CalculatedClause):
             c.accept_visitor(visitor)
         visitor.visit_function(self)
 
-
+class Cast(ColumnElement):
+    def __init__(self, clause, totype, **kwargs):
+        if not hasattr(clause, 'label'):
+            clause = literal(clause)
+        self.type = sqltypes.to_instance(totype)
+        self.clause = clause
+        self.typeclause = TypeClause(self.type)
+    def accept_visitor(self, visitor):
+        self.clause.accept_visitor(visitor)
+        self.typeclause.accept_visitor(visitor)
+        visitor.visit_cast(self)
+    def _get_from_objects(self):
+        return self.clause._get_from_objects()
+        
 class FunctionGenerator(object):
     """generates Function objects based on getattr calls"""
     def __init__(self, engine=None):
index 290e324cf2cd75013348b05ac2d88c4bb3630269..d78f36b1add2aea346e98b9cd9db2a029400af9d 100644 (file)
@@ -556,14 +556,14 @@ FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid AND mytable
         check_results(postgres.dialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '%(literal)s')
 
         # then the Oracle engine
-#        check_results(oracle.OracleDialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'CLOB', 'VARCHAR(20)'], ':literal')
+        check_results(oracle.dialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'CLOB', 'VARCHAR(20)'], ':literal')
 
         # then the sqlite engine
         check_results(sqlite.dialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '?')
 
-        # and the MySQL engine
-        check_results(mysql.dialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '%s')
-
+        # MySQL seems to only support DATE types for cast
+        self.assertEqual(str(cast(tbl.c.ts, Date).compile(dialect=mysql.dialect())), 'CAST(casttest.ts AS DATE)')
+        self.assertEqual(str(cast(tbl.c.ts, Numeric).compile(dialect=mysql.dialect())), 'casttest.ts')
 
     def testdatebetween(self):
         import datetime