]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- The extract() function, which was slightly improved in
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 29 Dec 2009 02:35:42 +0000 (02:35 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 29 Dec 2009 02:35:42 +0000 (02:35 +0000)
0.5.7, needed a lot more work to generate the correct
typecast (the typecasts appear to be necessary in PG's
EXTRACT quite a lot of the time).  The typecast is
now generated using a rule dictionary based
on PG's documentation for date/time/interval arithmetic.
It also accepts text() constructs again, which was broken
in 0.5.7. [ticket:1647]

CHANGES
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/sql/util.py
lib/sqlalchemy/types.py
test/dialect/test_postgres.py

diff --git a/CHANGES b/CHANGES
index 700d4b088a13861b8e7c95fdc6b66b2ae113255a..f5c386f45acd3b2fa69815555ec0088b5b942d4a 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -4,6 +4,18 @@
 CHANGES
 =======
 
+0.5.8
+=====
+- postgresql
+    - The extract() function, which was slightly improved in
+      0.5.7, needed a lot more work to generate the correct
+      typecast (the typecasts appear to be necessary in PG's
+      EXTRACT quite a lot of the time).  The typecast is 
+      now generated using a rule dictionary based 
+      on PG's documentation for date/time/interval arithmetic.
+      It also accepts text() constructs again, which was broken
+      in 0.5.7. [ticket:1647]
+      
 0.5.7
 =====
 - orm
index a46ed6723d4895761d3a8032e0481b97fe624326..6605eebd97e623a12d158b0978d594825ee66bcb 100644 (file)
@@ -93,7 +93,7 @@ import decimal, random, re, string
 
 from sqlalchemy import sql, schema, exc, util
 from sqlalchemy.engine import base, default
-from sqlalchemy.sql import compiler, expression
+from sqlalchemy.sql import compiler, expression, util as sql_util
 from sqlalchemy.sql import operators as sql_operators
 from sqlalchemy import types as sqltypes
 
@@ -821,8 +821,16 @@ class PGCompiler(compiler.DefaultCompiler):
 
     def visit_extract(self, extract, **kwargs):
         field = self.extract_map.get(extract.field, extract.field)
+        affinity = sql_util.determine_date_affinity(extract.expr)
+        
+        casts = {sqltypes.Date:'date', sqltypes.DateTime:'timestamp', sqltypes.Interval:'interval', sqltypes.Time:'time'}
+        cast = casts.get(affinity, None)
+        if isinstance(extract.expr, sql.ColumnElement) and cast is not None:
+            expr = extract.expr.op('::')(sql.literal_column(cast))
+        else:
+            expr = extract.expr
         return "EXTRACT(%s FROM %s)" % (
-            field, self.process(extract.expr.op('::')(sql.literal_column('timestamp'))))
+            field, self.process(expr))
 
 
 class PGSchemaGenerator(compiler.SchemaGenerator):
index dccd3d4627f325c2fdc22a56ce419c67296d4809..06cd78db1dd824860dca5476513faa0d2f2daa8a 100644 (file)
@@ -1,4 +1,4 @@
-from sqlalchemy import exc, schema, topological, util, sql
+from sqlalchemy import exc, schema, topological, util, sql, types as sqltypes
 from sqlalchemy.sql import expression, operators, visitors
 from itertools import chain
 
@@ -46,6 +46,90 @@ def find_join_source(clauses, join_to):
     else:
         return None, None
 
+_date_affinities = None
+def determine_date_affinity(expr):
+    """Given an expression, determine if it returns 'interval', 'date', or 'datetime'.
+    
+    the PG dialect uses this to generate the extract() function.
+    
+    It's less than ideal since it basically needs to duplicate PG's 
+    date arithmetic rules.   
+    
+    Rules are based on http://www.postgresql.org/docs/current/static/functions-datetime.html.
+    
+    Returns None if operators other than + or - are detected as well as types
+    outside of those above.
+    
+    """
+    
+    global _date_affinities
+    if _date_affinities is None:
+        Date, DateTime, Integer, \
+            Numeric, Interval, Time = \
+                                    sqltypes.Date, sqltypes.DateTime,\
+                                    sqltypes.Integer, sqltypes.Numeric,\
+                                    sqltypes.Interval, sqltypes.Time
+
+        _date_affinities = {
+            operators.add:{
+                (Date, Integer):Date,
+                (Date, Interval):DateTime,
+                (Date, Time):DateTime,
+                (Interval, Interval):Interval,
+                (DateTime, Interval):DateTime,
+                (Interval, Time):Time,
+            },
+            operators.sub:{
+                (Date, Integer):Date,
+                (Date, Interval):DateTime,
+                (Time, Time):Interval,
+                (Time, Interval):Time,
+                (DateTime, Interval):DateTime,
+                (Interval, Interval):Interval,
+                (DateTime, DateTime):Interval,
+            },
+            operators.mul:{
+                (Integer, Interval):Interval,
+                (Interval, Numeric):Interval,
+            },
+            operators.div: {
+                (Interval, Numeric):Interval
+            }
+        }
+    
+    if isinstance(expr, expression._BinaryExpression):
+        if expr.operator not in _date_affinities:
+            return None
+            
+        left_affin, right_affin = \
+            determine_date_affinity(expr.left), \
+            determine_date_affinity(expr.right)
+        
+        if operators.is_commutative(expr.operator):
+            key = tuple(sorted([left_affin, right_affin], key=lambda cls:cls.__name__))
+        else:
+            key = (left_affin, right_affin)
+        
+        lookup = _date_affinities[expr.operator]
+        return lookup.get(key, None)
+
+    # work around the fact that expressions put the wrong type
+    # on generated bind params when its "datetime + timedelta"
+    # and similar
+    if isinstance(expr, expression._BindParamClause):
+        type_ = sqltypes.type_map.get(type(expr.value), sqltypes.NullType)()
+    else:
+        type_ = expr.type
+
+    affinities = set([sqltypes.Date, sqltypes.DateTime, 
+                    sqltypes.Interval, sqltypes.Time, sqltypes.Integer])
+        
+    if type_ is not None and type_._type_affinity in affinities:
+        return type_._type_affinity
+    else:
+        return None
+    
+    
     
 def find_tables(clause, check_columns=False, 
                 include_aliases=False, include_joins=False, 
index 8e6accdbf54d6be2b6aea7e4e37e25c80d1c73a6..676e97f6e8adbd70784909c057c44eb9793e55e7 100644 (file)
@@ -839,7 +839,11 @@ class Interval(TypeDecorator):
         import sqlalchemy.databases.postgres as pg
         self.__supported = {pg.PGDialect:pg.PGInterval}
         del pg
-
+    
+    @property
+    def _type_affinity(self):
+        return Interval
+        
     def load_dialect_impl(self, dialect):
         if dialect.__class__ in self.__supported:
             return self.__supported[dialect.__class__]()
index f8eefec5c990eb91e8cb1ca751efd9fb8cb2794c..a53dc64f5bb98cffabb97fddddcd639d2ac085c4 100644 (file)
@@ -61,19 +61,66 @@ class CompileTest(TestBase, AssertsCompiledSQL):
         self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING length(mytable.name)", dialect=dialect)
 
     def test_extract(self):
-        t = table('t', column('col1'))
-
-        for field in 'year', 'month', 'day':
-            self.assert_compile(
-                select([extract(field, t.c.col1)]),
-                "SELECT EXTRACT(%s FROM t.col1 :: timestamp) AS anon_1 "
-                "FROM t" % field)
-
-        for field in 'year', 'month', 'day':
-            self.assert_compile(
-                select([extract(field, func.timestamp() - datetime.timedelta(days =5))]),
-                "SELECT EXTRACT(%s FROM (timestamp() - %%(timestamp_1)s) :: timestamp) AS anon_1"
-                 % field)
+
+        t = table('t', column('col1', DateTime), column('col2', Date), column('col3', Time))
+
+        for field in 'year', 'month', 'day', 'epoch', 'hour':
+            for expr, compiled_expr in [
+                ( t.c.col1, "t.col1 :: timestamp" ),
+                ( t.c.col2, "t.col2 :: date" ),
+                ( t.c.col3, "t.col3 :: time" ),
+                (func.current_timestamp() - datetime.timedelta(days=5),
+                    "(CURRENT_TIMESTAMP - %(current_timestamp_1)s) :: timestamp"
+                ),
+                (func.current_timestamp() + func.current_timestamp(), 
+                    "CURRENT_TIMESTAMP + CURRENT_TIMESTAMP" # invalid, no cast.
+                ),
+                (text("foo.date + foo.time"), 
+                    "foo.date + foo.time" # plain text.  no cast.
+                ),
+                (func.current_timestamp() + datetime.timedelta(days=5), 
+                    "(CURRENT_TIMESTAMP + %(current_timestamp_1)s) :: timestamp"
+                ),
+                (t.c.col2 + t.c.col3,
+                    "(t.col2 + t.col3) :: timestamp"
+                ),
+                # addition is commutative
+                (t.c.col2 + datetime.timedelta(days=5),
+                    "(t.col2 + %(col2_1)s) :: timestamp"
+                ),
+                (datetime.timedelta(days=5) + t.c.col2,
+                    "(%(col2_1)s + t.col2) :: timestamp"
+                ),
+                # subtraction is not
+                (t.c.col1 - datetime.timedelta(seconds=30),
+                    "(t.col1 - %(col1_1)s) :: timestamp"
+                ),
+                (datetime.timedelta(seconds=30) - t.c.col1,
+                    "%(col1_1)s - t.col1" # invalid - no cast.
+                ),
+                (func.coalesce(t.c.col1, func.current_timestamp()),
+                    "coalesce(t.col1, CURRENT_TIMESTAMP) :: timestamp"
+                ),
+                (t.c.col3 + datetime.timedelta(seconds=30),
+                    "(t.col3 + %(col3_1)s) :: time"
+                ),
+                (func.current_timestamp() - func.coalesce(t.c.col1, func.current_timestamp()),
+                    "(CURRENT_TIMESTAMP - coalesce(t.col1, CURRENT_TIMESTAMP)) :: interval",
+                ),
+                (3 * func.foobar(type_=Interval),
+                    "(%(foobar_1)s * foobar()) :: interval"
+                ),
+                (literal(datetime.timedelta(seconds=10)) - literal(datetime.timedelta(seconds=10)),
+                    "(%(param_1)s - %(param_2)s) :: interval"
+                ),
+            ]:
+                self.assert_compile(
+                    select([extract(field, expr)]).select_from(t),
+                    "SELECT EXTRACT(%s FROM %s) AS anon_1 FROM t" % (
+                        field, 
+                        compiled_expr
+                    )
+                )
 
 class ReturningTest(TestBase, AssertsExecutionResults):
     __only_on__ = 'postgres'