From: Mike Bayer Date: Sat, 18 Mar 2006 00:32:49 +0000 (+0000) Subject: got oracle parenthesized rules for funcs back, fixed copy_container on function X-Git-Tag: rel_0_1_5~54 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=48be2fbb7c8f03e3931bda839fd7507c3cb12ae2;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git got oracle parenthesized rules for funcs back, fixed copy_container on function --- diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index b05872182b..c0aebcbaed 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -220,10 +220,13 @@ class ANSICompiler(sql.Compiled): else: self.strings[list] = string.join([self.get_str(c) for c in list.clauses], ', ') + def apply_function_parens(self, func): + return func.name.upper() not in ANSI_FUNCS or len(func.clauses) > 0 + def visit_function(self, func): if len(self.select_stack): self.typemap.setdefault(func.name, func.type) - if func.name.upper() in ANSI_FUNCS and not len(func.clauses): + if not self.apply_function_parens(func): self.strings[func] = ".".join(func.packagenames + [func.name]) else: self.strings[func] = ".".join(func.packagenames + [func.name]) + "(" + string.join([self.get_str(c) for c in func.clauses], ', ') + ")" diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index 003f0aadbc..546bd1462e 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -221,6 +221,9 @@ class OracleCompiler(ansisql.ANSICompiler): gives Oracle a chance to tack on a "FROM DUAL" to the string output. """ return " FROM DUAL" + def apply_function_parens(self, func): + return len(func.clauses) > 0 + def visit_join(self, join): if self._use_ansi: return ansisql.ANSICompiler.visit_join(self, join) diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 24263184d1..87ce3b9652 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -794,7 +794,7 @@ class Function(ClauseList, ColumnElement): def __init__(self, name, *clauses, **kwargs): self.name = name self.type = kwargs.get('type', sqltypes.NULLTYPE) - self.packagenames = kwargs.get('packagenames') + self.packagenames = kwargs.get('packagenames', None) or [] self._engine = kwargs.get('engine', None) if self._engine is not None: self.type = self._engine.type_descriptor(self.type) @@ -813,7 +813,7 @@ class Function(ClauseList, ColumnElement): data.setdefault(self, self) def copy_container(self): clauses = [clause.copy_container() for clause in self.clauses] - return Function(self.name, type=self.type, *clauses) + return Function(self.name, type=self.type, packagenames=self.packagenames, *clauses) def accept_visitor(self, visitor): for c in self.clauses: c.accept_visitor(visitor)