From 1ffed8432e282aa57ecde9f3e4ca778a1756ddc0 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 17 Jun 2006 00:53:33 +0000 Subject: [PATCH] cast converted into its own ClauseElement so that it can have an explicit compilation function in ANSICompiler MySQLCompiler then skips most CAST calls since it only seems to support the standard syntax for Date types; other types now a TODO for MySQL then, polymorphic_union() function now CASTs null()s to the type corresponding to the columns in the UNION, since postgres doesnt like mixing NULL with integer types (long road for that .....) --- CHANGES | 5 +++++ lib/sqlalchemy/ansisql.py | 8 +++++++- lib/sqlalchemy/databases/mysql.py | 10 +++++++++- lib/sqlalchemy/orm/util.py | 5 +++-- lib/sqlalchemy/sql.py | 25 ++++++++++++++++++------- test/sql/select.py | 8 ++++---- 6 files changed, 46 insertions(+), 15 deletions(-) diff --git a/CHANGES b/CHANGES index 9b80778efa..5bcbda4d73 100644 --- a/CHANGES +++ b/CHANGES @@ -40,6 +40,11 @@ function at the moment information better [ticket:202] - if an object fails to be constructed, is not added to the session [ticket:203] +- CAST function has been made into its own clause object with +its own compilation function in ansicompiler; allows MySQL +to silently ignore most CAST calls since MySQL +seems to only support the standard CAST syntax with Date types. +MySQL-compatible CAST support for strings, ints, etc. a TODO 0.2.2 - big improvements to polymorphic inheritance behavior, enabling it diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index cdd8604402..82abb95775 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -231,7 +231,13 @@ class ANSICompiler(sql.Compiled): self.strings[list] = "(" + string.join([self.get_str(c) for c in list.clauses], ' ') + ")" else: self.strings[list] = string.join([self.get_str(c) for c in list.clauses], ' ') - + + def visit_cast(self, cast): + if len(self.select_stack): + # not sure if we want to set the typemap here... + self.typemap.setdefault("CAST", cast.type) + self.strings[cast] = "CAST(%s AS %s)" % (self.strings[cast.clause],self.strings[cast.typeclause]) + def visit_function(self, func): if len(self.select_stack): self.typemap.setdefault(func.name, func.type) diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index aa05134d04..e32d6f120a 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -9,7 +9,6 @@ import sys, StringIO, string, types, re, datetime from sqlalchemy import sql,engine,schema,ansisql from sqlalchemy.engine import default import sqlalchemy.types as sqltypes -import sqlalchemy.databases.information_schema as ischema import sqlalchemy.exceptions as exceptions try: @@ -250,6 +249,15 @@ class MySQLDialect(ansisql.ANSIDialect): class MySQLCompiler(ansisql.ANSICompiler): + def visit_cast(self, cast): + """hey ho MySQL supports almost no types at all for CAST""" + if (isinstance(cast.type, sqltypes.Date) or isinstance(cast.type, sqltypes.Time) or isinstance(cast.type, sqltypes.DateTime)): + return super(MySQLCompiler, self).visit_cast(cast) + else: + # so just skip the CAST altogether for now. + # TODO: put whatever MySQL does for CAST here. + self.strings[cast] = self.strings[cast.clause] + def limit_clause(self, select): text = "" if select.limit is not None: diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 86799b3116..10f7b4e808 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -24,7 +24,7 @@ class CascadeOptions(object): def polymorphic_union(table_map, typecolname, aliasname='p_union'): colnames = util.Set() colnamemaps = {} - + types = {} for key in table_map.keys(): table = table_map[key] @@ -37,13 +37,14 @@ def polymorphic_union(table_map, typecolname, aliasname='p_union'): for c in table.c: colnames.add(c.name) m[c.name] = c + types[c.name] = c.type colnamemaps[table] = m def col(name, table): try: return colnamemaps[table][name] except KeyError: - return sql.null().label(name) + return sql.cast(sql.null(), types[name]).label(name) result = [] for type, table in table_map.iteritems(): diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 0cacea12d0..d978ee208e 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -153,12 +153,9 @@ def cast(clause, totype, **kwargs): or cast(table.c.timestamp, DATE) """ - # handle non-column clauses (e.g. cast(1234, TEXT) - if not hasattr(clause, 'label'): - clause = literal(clause) - totype = sqltypes.to_instance(totype) - return Function('CAST', CompoundClause("AS", clause, TypeClause(totype)), type=totype, **kwargs) - + return Cast(clause, totype, **kwargs) + + def exists(*args, **params): params['correlate'] = True s = select(*args, **params) @@ -320,6 +317,7 @@ class ClauseVisitor(object): def visit_clauselist(self, list):pass def visit_calculatedclause(self, calcclause):pass def visit_function(self, func):pass + def visit_cast(self, cast):pass def visit_label(self, label):pass def visit_typeclause(self, typeclause):pass @@ -974,7 +972,20 @@ class Function(CalculatedClause): c.accept_visitor(visitor) visitor.visit_function(self) - +class Cast(ColumnElement): + def __init__(self, clause, totype, **kwargs): + if not hasattr(clause, 'label'): + clause = literal(clause) + self.type = sqltypes.to_instance(totype) + self.clause = clause + self.typeclause = TypeClause(self.type) + def accept_visitor(self, visitor): + self.clause.accept_visitor(visitor) + self.typeclause.accept_visitor(visitor) + visitor.visit_cast(self) + def _get_from_objects(self): + return self.clause._get_from_objects() + class FunctionGenerator(object): """generates Function objects based on getattr calls""" def __init__(self, engine=None): diff --git a/test/sql/select.py b/test/sql/select.py index 290e324cf2..d78f36b1ad 100644 --- a/test/sql/select.py +++ b/test/sql/select.py @@ -556,14 +556,14 @@ FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid AND mytable check_results(postgres.dialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '%(literal)s') # then the Oracle engine -# check_results(oracle.OracleDialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'CLOB', 'VARCHAR(20)'], ':literal') + check_results(oracle.dialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'CLOB', 'VARCHAR(20)'], ':literal') # then the sqlite engine check_results(sqlite.dialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '?') - # and the MySQL engine - check_results(mysql.dialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '%s') - + # MySQL seems to only support DATE types for cast + self.assertEqual(str(cast(tbl.c.ts, Date).compile(dialect=mysql.dialect())), 'CAST(casttest.ts AS DATE)') + self.assertEqual(str(cast(tbl.c.ts, Numeric).compile(dialect=mysql.dialect())), 'casttest.ts') def testdatebetween(self): import datetime -- 2.47.2