]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Added generic func.random (non-standard SQL)
authorJason Kirtland <jek@discorporate.us>
Tue, 25 Mar 2008 16:51:29 +0000 (16:51 +0000)
committerJason Kirtland <jek@discorporate.us>
Tue, 25 Mar 2008 16:51:29 +0000 (16:51 +0000)
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/functions.py
test/sql/functions.py

index 314cb8dac8cbf5580e2b7a070beb8ffdd6f98af6..f2544e9b5d71060c30762c91f8027686d7c4c587 100644 (file)
@@ -158,6 +158,7 @@ from array import array as _array
 
 from sqlalchemy import exceptions, logging, schema, sql, util
 from sqlalchemy.sql import operators as sql_operators
+from sqlalchemy.sql import functions as sql_functions
 from sqlalchemy.sql import compiler
 
 from sqlalchemy.engine import base as engine_base, default
@@ -1898,6 +1899,11 @@ class MySQLCompiler(compiler.DefaultCompiler):
         sql_operators.concat_op: lambda x, y: "concat(%s, %s)" % (x, y),
         sql_operators.mod: '%%'
     })
+    functions = compiler.DefaultCompiler.functions.copy()
+    functions.update ({
+        sql_functions.random: 'rand%(expr)s'
+        })
+
 
     def visit_typeclause(self, typeclause):
         type_ = typeclause.type.dialect_impl(self.dialect)
index 6a048a78096cc99cb6f14fc0dd0f4f910b8f1c27..ee3ecc1f1b0c48fddbf97ad0b21523443f34693a 100644 (file)
@@ -101,6 +101,7 @@ FUNCTIONS = {
     functions.current_user: 'CURRENT_USER',
     functions.localtime: 'LOCALTIME',
     functions.localtimestamp: 'LOCALTIMESTAMP',
+    functions.random: 'random%(expr)s',
     functions.sysdate: 'sysdate',
     functions.session_user :'SESSION_USER',
     functions.user: 'USER'
index be1d8eb611d8c9d49d6004bcf3faf14ea76f0140..66954168c501ff19b9677c5a057a2b5985f5ddb7 100644 (file)
@@ -1,16 +1,18 @@
 from sqlalchemy import types as sqltypes
-from sqlalchemy.sql.expression import _Function, _literal_as_binds, ClauseList, _FigureVisitName
+from sqlalchemy.sql.expression import _Function, _literal_as_binds, \
+                                      ClauseList, _FigureVisitName
 from sqlalchemy.sql import operators
 
+
 class _GenericMeta(_FigureVisitName):
     def __init__(cls, clsname, bases, dict):
         cls.__visit_name__ = 'function'
         type.__init__(cls, clsname, bases, dict)
-       
+
     def __call__(self, *args, **kwargs):
         args = [_literal_as_binds(c) for c in args]
         return type.__call__(self, *args, **kwargs)
-        
+
 class GenericFunction(_Function):
     __metaclass__ = _GenericMeta
 
@@ -20,16 +22,21 @@ class GenericFunction(_Function):
         self.name = self.__class__.__name__
         self._bind = kwargs.get('bind', None)
         if group:
-            self.clause_expr = ClauseList(operator=operators.comma_op, group_contents=True, *args).self_group()
+            self.clause_expr = ClauseList(
+                operator=operators.comma_op,
+                group_contents=True, *args).self_group()
         else:
-            self.clause_expr = ClauseList(operator=operators.comma_op, group_contents=True, *args)
-        self.type = sqltypes.to_instance(type_ or getattr(self, '__return_type__', None))
-        
+            self.clause_expr = ClauseList(
+                operator=operators.comma_op,
+                group_contents=True, *args)
+        self.type = sqltypes.to_instance(
+            type_ or getattr(self, '__return_type__', None))
+
 class AnsiFunction(GenericFunction):
     def __init__(self, **kwargs):
         GenericFunction.__init__(self, **kwargs)
 
-        
+
 class coalesce(GenericFunction):
     def __init__(self, *args, **kwargs):
         kwargs.setdefault('type_', _type_from_args(args))
@@ -37,7 +44,7 @@ class coalesce(GenericFunction):
 
 class now(GenericFunction):
     __return_type__ = sqltypes.DateTime
-    
+
 class concat(GenericFunction):
     __return_type__ = sqltypes.String
     def __init__(self, *args, **kwargs):
@@ -48,10 +55,15 @@ class char_length(GenericFunction):
 
     def __init__(self, arg, **kwargs):
         GenericFunction.__init__(self, args=[arg], **kwargs)
-        
+
+class random(GenericFunction):
+    def __init__(self, *args, **kwargs):
+        kwargs.setdefault('type_', None)
+        GenericFunction.__init__(self, args=args, **kwargs)
+
 class current_date(AnsiFunction):
     __return_type__ = sqltypes.Date
-        
+
 class current_time(AnsiFunction):
     __return_type__ = sqltypes.Time
 
index f6b1e67f182ec26e6e913f65a4a101545caa8644..e5c59b091c5f9ed45d22ea629054b46fc9aaa1e4 100644 (file)
@@ -35,7 +35,7 @@ class CompileTest(TestBase, AssertsCompiledSQL):
 
     def test_generic_now(self):
         assert isinstance(func.now().type, sqltypes.DateTime)
-        
+
         for ret, dialect in [
             ('CURRENT_TIMESTAMP', sqlite.dialect()),
             ('now()', postgres.dialect()),
@@ -43,7 +43,19 @@ class CompileTest(TestBase, AssertsCompiledSQL):
             ('CURRENT_TIMESTAMP', oracle.dialect())
         ]:
             self.assert_compile(func.now(), ret, dialect=dialect)
-            
+
+    def test_generic_random(self):
+        assert func.random().type == sqltypes.NULLTYPE
+        assert isinstance(func.random(type_=Integer).type, Integer)
+
+        for ret, dialect in [
+            ('random()', sqlite.dialect()),
+            ('random()', postgres.dialect()),
+            ('rand()', mysql.dialect()),
+            ('random()', oracle.dialect())
+        ]:
+            self.assert_compile(func.random(), ret, dialect=dialect)
+
     def test_constructor(self):
         try:
             func.current_timestamp('somearg')
@@ -79,7 +91,7 @@ class CompileTest(TestBase, AssertsCompiledSQL):
             'myothertable',
             column('otherid', Integer),
         )
-        
+
         # test an expression with a function
         self.assert_compile(func.lala(3, 4, literal("five"), table1.c.myid) * table2.c.otherid,
             "lala(:lala_1, :lala_2, :param_1, mytable.myid) * myothertable.otherid")