From 92a5df77538069efd9f8cfc14cf83807ce43c288 Mon Sep 17 00:00:00 2001 From: Jason Kirtland Date: Tue, 25 Mar 2008 16:51:29 +0000 Subject: [PATCH] - Added generic func.random (non-standard SQL) --- lib/sqlalchemy/databases/mysql.py | 6 ++++++ lib/sqlalchemy/sql/compiler.py | 1 + lib/sqlalchemy/sql/functions.py | 34 +++++++++++++++++++++---------- test/sql/functions.py | 18 +++++++++++++--- 4 files changed, 45 insertions(+), 14 deletions(-) diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 314cb8dac8..f2544e9b5d 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -158,6 +158,7 @@ from array import array as _array from sqlalchemy import exceptions, logging, schema, sql, util from sqlalchemy.sql import operators as sql_operators +from sqlalchemy.sql import functions as sql_functions from sqlalchemy.sql import compiler from sqlalchemy.engine import base as engine_base, default @@ -1898,6 +1899,11 @@ class MySQLCompiler(compiler.DefaultCompiler): sql_operators.concat_op: lambda x, y: "concat(%s, %s)" % (x, y), sql_operators.mod: '%%' }) + functions = compiler.DefaultCompiler.functions.copy() + functions.update ({ + sql_functions.random: 'rand%(expr)s' + }) + def visit_typeclause(self, typeclause): type_ = typeclause.type.dialect_impl(self.dialect) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 6a048a7809..ee3ecc1f1b 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -101,6 +101,7 @@ FUNCTIONS = { functions.current_user: 'CURRENT_USER', functions.localtime: 'LOCALTIME', functions.localtimestamp: 'LOCALTIMESTAMP', + functions.random: 'random%(expr)s', functions.sysdate: 'sysdate', functions.session_user :'SESSION_USER', functions.user: 'USER' diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index be1d8eb611..66954168c5 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -1,16 +1,18 @@ from sqlalchemy import types as sqltypes -from sqlalchemy.sql.expression import _Function, _literal_as_binds, ClauseList, _FigureVisitName +from sqlalchemy.sql.expression import _Function, _literal_as_binds, \ + ClauseList, _FigureVisitName from sqlalchemy.sql import operators + class _GenericMeta(_FigureVisitName): def __init__(cls, clsname, bases, dict): cls.__visit_name__ = 'function' type.__init__(cls, clsname, bases, dict) - + def __call__(self, *args, **kwargs): args = [_literal_as_binds(c) for c in args] return type.__call__(self, *args, **kwargs) - + class GenericFunction(_Function): __metaclass__ = _GenericMeta @@ -20,16 +22,21 @@ class GenericFunction(_Function): self.name = self.__class__.__name__ self._bind = kwargs.get('bind', None) if group: - self.clause_expr = ClauseList(operator=operators.comma_op, group_contents=True, *args).self_group() + self.clause_expr = ClauseList( + operator=operators.comma_op, + group_contents=True, *args).self_group() else: - self.clause_expr = ClauseList(operator=operators.comma_op, group_contents=True, *args) - self.type = sqltypes.to_instance(type_ or getattr(self, '__return_type__', None)) - + self.clause_expr = ClauseList( + operator=operators.comma_op, + group_contents=True, *args) + self.type = sqltypes.to_instance( + type_ or getattr(self, '__return_type__', None)) + class AnsiFunction(GenericFunction): def __init__(self, **kwargs): GenericFunction.__init__(self, **kwargs) - + class coalesce(GenericFunction): def __init__(self, *args, **kwargs): kwargs.setdefault('type_', _type_from_args(args)) @@ -37,7 +44,7 @@ class coalesce(GenericFunction): class now(GenericFunction): __return_type__ = sqltypes.DateTime - + class concat(GenericFunction): __return_type__ = sqltypes.String def __init__(self, *args, **kwargs): @@ -48,10 +55,15 @@ class char_length(GenericFunction): def __init__(self, arg, **kwargs): GenericFunction.__init__(self, args=[arg], **kwargs) - + +class random(GenericFunction): + def __init__(self, *args, **kwargs): + kwargs.setdefault('type_', None) + GenericFunction.__init__(self, args=args, **kwargs) + class current_date(AnsiFunction): __return_type__ = sqltypes.Date - + class current_time(AnsiFunction): __return_type__ = sqltypes.Time diff --git a/test/sql/functions.py b/test/sql/functions.py index f6b1e67f18..e5c59b091c 100644 --- a/test/sql/functions.py +++ b/test/sql/functions.py @@ -35,7 +35,7 @@ class CompileTest(TestBase, AssertsCompiledSQL): def test_generic_now(self): assert isinstance(func.now().type, sqltypes.DateTime) - + for ret, dialect in [ ('CURRENT_TIMESTAMP', sqlite.dialect()), ('now()', postgres.dialect()), @@ -43,7 +43,19 @@ class CompileTest(TestBase, AssertsCompiledSQL): ('CURRENT_TIMESTAMP', oracle.dialect()) ]: self.assert_compile(func.now(), ret, dialect=dialect) - + + def test_generic_random(self): + assert func.random().type == sqltypes.NULLTYPE + assert isinstance(func.random(type_=Integer).type, Integer) + + for ret, dialect in [ + ('random()', sqlite.dialect()), + ('random()', postgres.dialect()), + ('rand()', mysql.dialect()), + ('random()', oracle.dialect()) + ]: + self.assert_compile(func.random(), ret, dialect=dialect) + def test_constructor(self): try: func.current_timestamp('somearg') @@ -79,7 +91,7 @@ class CompileTest(TestBase, AssertsCompiledSQL): 'myothertable', column('otherid', Integer), ) - + # test an expression with a function self.assert_compile(func.lala(3, 4, literal("five"), table1.c.myid) * table2.c.otherid, "lala(:lala_1, :lala_2, :param_1, mytable.myid) * myothertable.otherid") -- 2.47.3