From: Mike Bayer Date: Tue, 29 Dec 2009 02:35:42 +0000 (+0000) Subject: - The extract() function, which was slightly improved in X-Git-Tag: rel_0_5_8~12 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=4acf0f69f394c151b2e6c2d399632888cdef6fd9;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - The extract() function, which was slightly improved in 0.5.7, needed a lot more work to generate the correct typecast (the typecasts appear to be necessary in PG's EXTRACT quite a lot of the time). The typecast is now generated using a rule dictionary based on PG's documentation for date/time/interval arithmetic. It also accepts text() constructs again, which was broken in 0.5.7. [ticket:1647] --- diff --git a/CHANGES b/CHANGES index 700d4b088a..f5c386f45a 100644 --- a/CHANGES +++ b/CHANGES @@ -4,6 +4,18 @@ CHANGES ======= +0.5.8 +===== +- postgresql + - The extract() function, which was slightly improved in + 0.5.7, needed a lot more work to generate the correct + typecast (the typecasts appear to be necessary in PG's + EXTRACT quite a lot of the time). The typecast is + now generated using a rule dictionary based + on PG's documentation for date/time/interval arithmetic. + It also accepts text() constructs again, which was broken + in 0.5.7. [ticket:1647] + 0.5.7 ===== - orm diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index a46ed6723d..6605eebd97 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -93,7 +93,7 @@ import decimal, random, re, string from sqlalchemy import sql, schema, exc, util from sqlalchemy.engine import base, default -from sqlalchemy.sql import compiler, expression +from sqlalchemy.sql import compiler, expression, util as sql_util from sqlalchemy.sql import operators as sql_operators from sqlalchemy import types as sqltypes @@ -821,8 +821,16 @@ class PGCompiler(compiler.DefaultCompiler): def visit_extract(self, extract, **kwargs): field = self.extract_map.get(extract.field, extract.field) + affinity = sql_util.determine_date_affinity(extract.expr) + + casts = {sqltypes.Date:'date', sqltypes.DateTime:'timestamp', sqltypes.Interval:'interval', sqltypes.Time:'time'} + cast = casts.get(affinity, None) + if isinstance(extract.expr, sql.ColumnElement) and cast is not None: + expr = extract.expr.op('::')(sql.literal_column(cast)) + else: + expr = extract.expr return "EXTRACT(%s FROM %s)" % ( - field, self.process(extract.expr.op('::')(sql.literal_column('timestamp')))) + field, self.process(expr)) class PGSchemaGenerator(compiler.SchemaGenerator): diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index dccd3d4627..06cd78db1d 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -1,4 +1,4 @@ -from sqlalchemy import exc, schema, topological, util, sql +from sqlalchemy import exc, schema, topological, util, sql, types as sqltypes from sqlalchemy.sql import expression, operators, visitors from itertools import chain @@ -46,6 +46,90 @@ def find_join_source(clauses, join_to): else: return None, None +_date_affinities = None +def determine_date_affinity(expr): + """Given an expression, determine if it returns 'interval', 'date', or 'datetime'. + + the PG dialect uses this to generate the extract() function. + + It's less than ideal since it basically needs to duplicate PG's + date arithmetic rules. + + Rules are based on http://www.postgresql.org/docs/current/static/functions-datetime.html. + + Returns None if operators other than + or - are detected as well as types + outside of those above. + + """ + + global _date_affinities + if _date_affinities is None: + Date, DateTime, Integer, \ + Numeric, Interval, Time = \ + sqltypes.Date, sqltypes.DateTime,\ + sqltypes.Integer, sqltypes.Numeric,\ + sqltypes.Interval, sqltypes.Time + + _date_affinities = { + operators.add:{ + (Date, Integer):Date, + (Date, Interval):DateTime, + (Date, Time):DateTime, + (Interval, Interval):Interval, + (DateTime, Interval):DateTime, + (Interval, Time):Time, + }, + operators.sub:{ + (Date, Integer):Date, + (Date, Interval):DateTime, + (Time, Time):Interval, + (Time, Interval):Time, + (DateTime, Interval):DateTime, + (Interval, Interval):Interval, + (DateTime, DateTime):Interval, + }, + operators.mul:{ + (Integer, Interval):Interval, + (Interval, Numeric):Interval, + }, + operators.div: { + (Interval, Numeric):Interval + } + } + + if isinstance(expr, expression._BinaryExpression): + if expr.operator not in _date_affinities: + return None + + left_affin, right_affin = \ + determine_date_affinity(expr.left), \ + determine_date_affinity(expr.right) + + if operators.is_commutative(expr.operator): + key = tuple(sorted([left_affin, right_affin], key=lambda cls:cls.__name__)) + else: + key = (left_affin, right_affin) + + lookup = _date_affinities[expr.operator] + return lookup.get(key, None) + + # work around the fact that expressions put the wrong type + # on generated bind params when its "datetime + timedelta" + # and similar + if isinstance(expr, expression._BindParamClause): + type_ = sqltypes.type_map.get(type(expr.value), sqltypes.NullType)() + else: + type_ = expr.type + + affinities = set([sqltypes.Date, sqltypes.DateTime, + sqltypes.Interval, sqltypes.Time, sqltypes.Integer]) + + if type_ is not None and type_._type_affinity in affinities: + return type_._type_affinity + else: + return None + + def find_tables(clause, check_columns=False, include_aliases=False, include_joins=False, diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index 8e6accdbf5..676e97f6e8 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -839,7 +839,11 @@ class Interval(TypeDecorator): import sqlalchemy.databases.postgres as pg self.__supported = {pg.PGDialect:pg.PGInterval} del pg - + + @property + def _type_affinity(self): + return Interval + def load_dialect_impl(self, dialect): if dialect.__class__ in self.__supported: return self.__supported[dialect.__class__]() diff --git a/test/dialect/test_postgres.py b/test/dialect/test_postgres.py index f8eefec5c9..a53dc64f5b 100644 --- a/test/dialect/test_postgres.py +++ b/test/dialect/test_postgres.py @@ -61,19 +61,66 @@ class CompileTest(TestBase, AssertsCompiledSQL): self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING length(mytable.name)", dialect=dialect) def test_extract(self): - t = table('t', column('col1')) - - for field in 'year', 'month', 'day': - self.assert_compile( - select([extract(field, t.c.col1)]), - "SELECT EXTRACT(%s FROM t.col1 :: timestamp) AS anon_1 " - "FROM t" % field) - - for field in 'year', 'month', 'day': - self.assert_compile( - select([extract(field, func.timestamp() - datetime.timedelta(days =5))]), - "SELECT EXTRACT(%s FROM (timestamp() - %%(timestamp_1)s) :: timestamp) AS anon_1" - % field) + + t = table('t', column('col1', DateTime), column('col2', Date), column('col3', Time)) + + for field in 'year', 'month', 'day', 'epoch', 'hour': + for expr, compiled_expr in [ + ( t.c.col1, "t.col1 :: timestamp" ), + ( t.c.col2, "t.col2 :: date" ), + ( t.c.col3, "t.col3 :: time" ), + (func.current_timestamp() - datetime.timedelta(days=5), + "(CURRENT_TIMESTAMP - %(current_timestamp_1)s) :: timestamp" + ), + (func.current_timestamp() + func.current_timestamp(), + "CURRENT_TIMESTAMP + CURRENT_TIMESTAMP" # invalid, no cast. + ), + (text("foo.date + foo.time"), + "foo.date + foo.time" # plain text. no cast. + ), + (func.current_timestamp() + datetime.timedelta(days=5), + "(CURRENT_TIMESTAMP + %(current_timestamp_1)s) :: timestamp" + ), + (t.c.col2 + t.c.col3, + "(t.col2 + t.col3) :: timestamp" + ), + # addition is commutative + (t.c.col2 + datetime.timedelta(days=5), + "(t.col2 + %(col2_1)s) :: timestamp" + ), + (datetime.timedelta(days=5) + t.c.col2, + "(%(col2_1)s + t.col2) :: timestamp" + ), + # subtraction is not + (t.c.col1 - datetime.timedelta(seconds=30), + "(t.col1 - %(col1_1)s) :: timestamp" + ), + (datetime.timedelta(seconds=30) - t.c.col1, + "%(col1_1)s - t.col1" # invalid - no cast. + ), + (func.coalesce(t.c.col1, func.current_timestamp()), + "coalesce(t.col1, CURRENT_TIMESTAMP) :: timestamp" + ), + (t.c.col3 + datetime.timedelta(seconds=30), + "(t.col3 + %(col3_1)s) :: time" + ), + (func.current_timestamp() - func.coalesce(t.c.col1, func.current_timestamp()), + "(CURRENT_TIMESTAMP - coalesce(t.col1, CURRENT_TIMESTAMP)) :: interval", + ), + (3 * func.foobar(type_=Interval), + "(%(foobar_1)s * foobar()) :: interval" + ), + (literal(datetime.timedelta(seconds=10)) - literal(datetime.timedelta(seconds=10)), + "(%(param_1)s - %(param_2)s) :: interval" + ), + ]: + self.assert_compile( + select([extract(field, expr)]).select_from(t), + "SELECT EXTRACT(%s FROM %s) AS anon_1 FROM t" % ( + field, + compiled_expr + ) + ) class ReturningTest(TestBase, AssertsExecutionResults): __only_on__ = 'postgres'