]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
expanded and integrated qvx's patch for dotted function names
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 16 Mar 2006 23:38:14 +0000 (23:38 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 16 Mar 2006 23:38:14 +0000 (23:38 +0000)
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/engine.py
lib/sqlalchemy/sql.py
test/select.py

index 6a5ef8ef0d43b7ec9d9e99d395c4ca21b6a2027e..b05872182bc6e23f1e298af57ec14272e8848055 100644 (file)
@@ -224,9 +224,9 @@ class ANSICompiler(sql.Compiled):
         if len(self.select_stack):
             self.typemap.setdefault(func.name, func.type)
         if func.name.upper() in ANSI_FUNCS and not len(func.clauses):
-            self.strings[func] = func.name
+            self.strings[func] = ".".join(func.packagenames + [func.name])
         else:
-            self.strings[func] = func.name + "(" + string.join([self.get_str(c) for c in func.clauses], ', ') + ")"
+            self.strings[func] = ".".join(func.packagenames + [func.name]) + "(" + string.join([self.get_str(c) for c in func.clauses], ', ') + ")"
         
     def visit_compound_select(self, cs):
         text = string.join([self.get_str(c) for c in cs.selects], " " + cs.keyword + " ")
index 44ffadfeb3b9b30a395935e9d817eb8d43c2b57d..bee0e6ff6a867c535145317461fa1d3a5da4c227 100644 (file)
@@ -280,10 +280,7 @@ class SQLEngine(schema.SchemaEngine):
         return typeobj
 
     def _func(self):
-        class FunctionGateway(object):
-            def __getattr__(s, name):
-                return lambda *c, **kwargs: sql.Function(name, engine=self, *c, **kwargs)
-        return FunctionGateway()
+        return sql.FunctionGenerator(self)
     func = property(_func)
     
     def text(self, text, *args, **kwargs):
index 0df297743755a7db40d8a0d34ff76fdaf68ca13f..24263184d1cc6017f585594225780451ebece27e 100644 (file)
@@ -218,7 +218,7 @@ class FunctionGateway(object):
     """returns a callable based on an attribute name, which then returns a Function 
     object with that name."""
     def __getattr__(self, name):
-        return lambda *c, **kwargs: Function(name, *c, **kwargs)
+        return getattr(FunctionGenerator(), name)
 func = FunctionGateway()
 
 def _compound_clause(keyword, *clauses):
@@ -794,6 +794,7 @@ class Function(ClauseList, ColumnElement):
     def __init__(self, name, *clauses, **kwargs):
         self.name = name
         self.type = kwargs.get('type', sqltypes.NULLTYPE)
+        self.packagenames = kwargs.get('packagenames')
         self._engine = kwargs.get('engine', None)
         if self._engine is not None:
             self.type = self._engine.type_descriptor(self.type)
@@ -827,6 +828,17 @@ class Function(ClauseList, ColumnElement):
         return select([self]).execute()
     def _compare_type(self, obj):
         return self.type
+
+class FunctionGenerator(object):
+    """generates Function objects based on getattr calls"""
+    def __init__(self, engine=None):
+        self.__engine = engine
+        self.__names = []
+    def __getattr__(self, name):
+        self.__names.append(name)
+        return self
+    def __call__(self, *c, **kwargs):
+        return Function(self.__names[-1], packagenames=self.__names[0:-1], engine=self.__engine, *c, **kwargs)     
                 
 class BinaryClause(ClauseElement):
     """represents two clauses with an operator in between"""
index cec8939d3b8639b8e62d7a11f2e59dabf1bce1f4..610e05cb40caa1060dbf6264c60a8519a4e05a84 100644 (file)
@@ -343,12 +343,27 @@ FROM mytable, myothertable WHERE foo.id = foofoo(lala) AND datetime(foo) = Today
             "SELECT :literal + :literal_1 FROM mytable")
 
     def testfunction(self):
+        """tests the generation of functions using the func keyword"""
+        # test an expression with a function
         self.runtest(func.lala(3, 4, literal("five"), table1.c.myid) * table2.c.otherid, 
             "lala(:lala, :lala_1, :literal, mytable.myid) * myothertable.otherid")
 
+        # test it in a SELECT
         self.runtest(select([func.count(table1.c.myid)]), 
             "SELECT count(mytable.myid) FROM mytable")
 
+        # test a "dotted" function name
+        self.runtest(select([func.foo.bar.lala(table1.c.myid)]), 
+            "SELECT foo.bar.lala(mytable.myid) FROM mytable")
+
+        # test the bind parameter name with a "dotted" function name is only the name
+        # (limits the length of the bind param name)
+        self.runtest(select([func.foo.bar.lala(12)]), 
+            "SELECT foo.bar.lala(:lala)")
+
+        # test a dotted func off the engine itself
+        self.runtest(db.func.lala.hoho(7), "lala.hoho(:hoho)")
+        
     def testjoin(self):
         self.runtest(
             join(table2, table1, table1.c.myid == table2.c.otherid).select(),