]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Removed the visit_function stuff in mssql dialect. Added some tests for the function...
authorMichael Trier <mtrier@gmail.com>
Sat, 11 Oct 2008 16:14:20 +0000 (16:14 +0000)
committerMichael Trier <mtrier@gmail.com>
Sat, 11 Oct 2008 16:14:20 +0000 (16:14 +0000)
lib/sqlalchemy/databases/mssql.py
test/dialect/mssql.py
test/sql/defaults.py

index 8bf8144cfe2a10b65d18463a3efa44c0748c3708..4c5ad1fd11026c61902d7f2eea70220b69eaffae 100644 (file)
@@ -882,10 +882,12 @@ class MSSQLCompiler(compiler.DefaultCompiler):
     functions = compiler.DefaultCompiler.functions.copy()
     functions.update (
         {
-            sql_functions.now: 'CURRENT_TIMESTAMP'
+            sql_functions.now: 'CURRENT_TIMESTAMP',
+            sql_functions.current_date: 'GETDATE()',
+            'length': lambda x: "LEN(%s)" % x
         }
     )
-    
+
     def __init__(self, *args, **kwargs):
         super(MSSQLCompiler, self).__init__(*args, **kwargs)
         self.tablealiases = {}
@@ -984,14 +986,6 @@ class MSSQLCompiler(compiler.DefaultCompiler):
         else:
             return super(MSSQLCompiler, self).label_select_column(select, column, asfrom)
 
-    # TODO: update this to use generic functions
-    function_rewrites =  {'current_date': 'getdate',
-                          'length':     'len',
-                          }
-    def visit_function(self, func, **kwargs):
-        func.name = self.function_rewrites.get(func.name, func.name)
-        return super(MSSQLCompiler, self).visit_function(func, **kwargs)
-
     def for_update_clause(self, select):
         # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use
         return ''
index 26f8892bd8494236dcdfa0be54b91f7ca6cf5e97..02c583d5dfe64680d1a96e00918fa912e02b1bd6 100755 (executable)
@@ -105,6 +105,10 @@ class CompileTest(TestBase, AssertsCompiledSQL):
         t = Table('sometable', m, Column('col1', Integer), Column('col2', Integer))
         self.assert_compile(select([func.max(t.c.col1)]), "SELECT max(sometable.col1) AS max_1 FROM sometable")
 
+    def test_function_overrides(self):
+        self.assert_compile(func.current_date(), "GETDATE()")
+        self.assert_compile(func.length(3), "LEN(:length_1)")
+
 class ReflectionTest(TestBase):
     __only_on__ = 'mssql'
 
index dfd626b72ba518cf9841c7575cb0ab71585e73f1..d69174248dfcc05b6d0957eda818593916b22ce1 100644 (file)
@@ -33,14 +33,13 @@ class DefaultTest(testing.TestBase):
                 # since its a "branched" connection
                 conn.close()
 
-        use_function_defaults = testing.against('postgres', 'oracle')
+        use_function_defaults = testing.against('postgres', 'mssql', 'maxdb')
         is_oracle = testing.against('oracle')
 
         # select "count(1)" returns different results on different DBs also
         # correct for "current_date" compatible as column default, value
         # differences
         currenttime = func.current_date(type_=sa.Date, bind=db)
-
         if is_oracle:
             ts = db.scalar(sa.select([func.trunc(func.sysdate(), sa.literal_column("'DAY'"), type_=sa.Date).label('today')]))
             assert isinstance(ts, datetime.date) and not isinstance(ts, datetime.datetime)
@@ -56,11 +55,13 @@ class DefaultTest(testing.TestBase):
             f = sa.select([func.length('abcdef')], bind=db).scalar()
             f2 = sa.select([func.length('abcdefghijk')], bind=db).scalar()
             def1 = currenttime
+            deftype = sa.Date
             if testing.against('maxdb'):
                 def2 = sa.text("curdate")
+            elif testing.against('mssql'):
+                def2 = sa.text("getdate()")
             else:
                 def2 = sa.text("current_date")
-            deftype = sa.Date
             ts = db.func.current_date().scalar()
         else:
             f = len('abcdef')