]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
The behavior of :func:`.extract` has been simplified on the
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 29 Jun 2013 03:53:27 +0000 (23:53 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 29 Jun 2013 03:54:11 +0000 (23:54 -0400)
Postgresql dialect to no longer inject a hardcoded ``::timestamp``
or similar cast into the given expression, as this interfered
with types such as timezone-aware datetimes, but also
does not appear to be at all necessary with modern versions
of psycopg2.  Also in 0.8.2.
[ticket:2740]

Conflicts:
doc/build/changelog/changelog_09.rst

doc/build/changelog/changelog_08.rst
lib/sqlalchemy/dialects/postgresql/base.py
test/dialect/postgresql/test_compiler.py
test/dialect/postgresql/test_query.py

index 79f2ce738116bc2d58067ae3d48733ef8ff5ef48..442b7a2d9c4e4e44e82a0d55f5c33ee592bd85d6 100644 (file)
@@ -6,6 +6,18 @@
 .. changelog::
     :version: 0.8.2
 
+    .. change::
+        :tags: bug, postgresql
+        :tickets: 2740
+
+        The behavior of :func:`.extract` has been simplified on the
+        Postgresql dialect to no longer inject a hardcoded ``::timestamp``
+        or similar cast into the given expression, as this interfered
+        with types such as timezone-aware datetimes, but also
+        does not appear to be at all necessary with modern versions
+        of psycopg2.
+
+
     .. change::
         :tags: bug, firebird
         :tickets: 2757
index f9bd49f0cf15e0cfda7c3f7e2f037d52bc5875a0..2af5d92b5f627e1cb57cc8e17f8a431515ceda72 100644 (file)
@@ -1028,28 +1028,6 @@ class PGCompiler(compiler.SQLCompiler):
 
         return 'RETURNING ' + ', '.join(columns)
 
-    def visit_extract(self, extract, **kwargs):
-        field = self.extract_map.get(extract.field, extract.field)
-        if extract.expr.type:
-            affinity = extract.expr.type._type_affinity
-        else:
-            affinity = None
-
-        casts = {
-                    sqltypes.Date: 'date',
-                    sqltypes.DateTime: 'timestamp',
-                    sqltypes.Interval: 'interval',
-                    sqltypes.Time: 'time'
-                }
-        cast = casts.get(affinity, None)
-        if isinstance(extract.expr, sql.ColumnElement) and cast is not None:
-            expr = extract.expr.op('::', precedence=100)(
-                                        sql.literal_column(cast))
-        else:
-            expr = extract.expr
-        return "EXTRACT(%s FROM %s)" % (
-            field, self.process(expr))
-
 
     def visit_substring_func(self, func, **kw):
         s = self.process(func.clauses.clauses[0], **kw)
index a79c0e7de1ab92a44401346b3275f379bf1df63d..11661b11f5a54bfaea45b1a19363b9b3c1a272de 100644 (file)
@@ -229,61 +229,6 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
 
 
 
-    def test_extract(self):
-        t = table('t', column('col1', DateTime), column('col2', Date),
-                  column('col3', Time), column('col4',
-                  postgresql.INTERVAL))
-        for field in 'year', 'month', 'day', 'epoch', 'hour':
-            for expr, compiled_expr in [  # invalid, no cast. plain
-                                          # text.  no cast. addition is
-                                          # commutative subtraction is
-                                          # not invalid - no cast. dont
-                                          # crack up on entirely
-                                          # unsupported types
-                (t.c.col1, 't.col1 :: timestamp'),
-                (t.c.col2, 't.col2 :: date'),
-                (t.c.col3, 't.col3 :: time'),
-                (func.current_timestamp() - datetime.timedelta(days=5),
-                 '(CURRENT_TIMESTAMP - %(current_timestamp_1)s) :: '
-                 'timestamp'),
-                (func.current_timestamp() + func.current_timestamp(),
-                 'CURRENT_TIMESTAMP + CURRENT_TIMESTAMP'),
-                (text('foo.date + foo.time'), 'foo.date + foo.time'),
-                (func.current_timestamp() + datetime.timedelta(days=5),
-                 '(CURRENT_TIMESTAMP + %(current_timestamp_1)s) :: '
-                 'timestamp'),
-                (t.c.col2 + t.c.col3, '(t.col2 + t.col3) :: timestamp'
-                 ),
-                (t.c.col2 + datetime.timedelta(days=5),
-                 '(t.col2 + %(col2_1)s) :: timestamp'),
-                (datetime.timedelta(days=5) + t.c.col2,
-                 '(%(col2_1)s + t.col2) :: timestamp'),
-                (t.c.col1 + t.c.col4, '(t.col1 + t.col4) :: timestamp'
-                 ),
-                (t.c.col1 - datetime.timedelta(seconds=30),
-                 '(t.col1 - %(col1_1)s) :: timestamp'),
-                (datetime.timedelta(seconds=30) - t.c.col1,
-                 '%(col1_1)s - t.col1'),
-                (func.coalesce(t.c.col1, func.current_timestamp()),
-                 'coalesce(t.col1, CURRENT_TIMESTAMP) :: timestamp'),
-                (t.c.col3 + datetime.timedelta(seconds=30),
-                 '(t.col3 + %(col3_1)s) :: time'),
-                (func.current_timestamp() - func.coalesce(t.c.col1,
-                 func.current_timestamp()),
-                 '(CURRENT_TIMESTAMP - coalesce(t.col1, '
-                 'CURRENT_TIMESTAMP)) :: interval'),
-                (3 * func.foobar(type_=Interval),
-                 '(%(foobar_1)s * foobar()) :: interval'),
-                (literal(datetime.timedelta(seconds=10))
-                 - literal(datetime.timedelta(seconds=10)),
-                 '(%(param_1)s - %(param_2)s) :: interval'),
-                (t.c.col3 + 'some string', 't.col3 + %(col3_1)s'),
-                ]:
-                self.assert_compile(select([extract(field,
-                                    expr)]).select_from(t),
-                                    'SELECT EXTRACT(%s FROM %s) AS '
-                                    'anon_1 FROM t' % (field,
-                                    compiled_expr))
 
     def test_reserved_words(self):
         table = Table("pg_table", MetaData(),
index a7bcbf3daf664e6f8a3caaf434a763875d106682..9335c5bbc12b3322bc063246ad2fe3c1f573fb63 100644 (file)
@@ -12,7 +12,8 @@ from sqlalchemy import Table, Column, select, MetaData, text, Integer, \
             SmallInteger, Enum, REAL, update, insert, Index, delete, \
             and_, Date, TypeDecorator, Time, Unicode, Interval, or_, Text
 from sqlalchemy import exc
-import logging
+from sqlalchemy.dialects import postgresql
+import datetime
 
 class InsertTest(fixtures.TestBase, AssertsExecutionResults):
 
@@ -721,3 +722,162 @@ class TupleTest(fixtures.TestBase):
                 ).scalar(),
                 exp
             )
+
+
+
+class ExtractTest(fixtures.TablesTest):
+    """The rationale behind this test is that for many years we've had a system
+    of embedding type casts into the expressions rendered by visit_extract()
+    on the postgreql platform.  The reason for this cast is not clear.
+    So here we try to produce a wide range of cases to ensure that these casts
+    are not needed; see [ticket:2740].
+
+    """
+    __only_on__ = 'postgresql'
+
+    run_inserts = 'once'
+    run_deletes = None
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table('t', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('dtme', DateTime),
+            Column('dt', Date),
+            Column('tm', Time),
+            Column('intv', postgresql.INTERVAL),
+            Column('dttz', DateTime(timezone=True))
+            )
+
+    @classmethod
+    def insert_data(cls):
+        # TODO: why does setting hours to anything
+        # not affect the TZ in the DB col ?
+        class TZ(datetime.tzinfo):
+            def utcoffset(self, dt):
+                return datetime.timedelta(hours=4)
+
+        testing.db.execute(
+            cls.tables.t.insert(),
+            {
+                'dtme': datetime.datetime(2012, 5, 10, 12, 15, 25),
+                'dt': datetime.date(2012, 5, 10),
+                'tm': datetime.time(12, 15, 25),
+                'intv': datetime.timedelta(seconds=570),
+                'dttz': datetime.datetime(2012, 5, 10, 12, 15, 25, tzinfo=TZ())
+            },
+        )
+
+    def _test(self, expr, field="all", overrides=None):
+        t = self.tables.t
+
+        if field == "all":
+            fields = {"year": 2012, "month": 5, "day": 10,
+                                "epoch": 1336652125.0,
+                                "hour": 12, "minute": 15}
+        elif field == "time":
+            fields = {"hour": 12, "minute": 15, "second": 25}
+        elif field == 'date':
+            fields = {"year": 2012, "month": 5, "day": 10}
+        elif field == 'all+tz':
+            fields = {"year": 2012, "month": 5, "day": 10,
+                                "epoch": 1336637725.0,
+                                "hour": 4,
+                                # can't figure out how to get a specific
+                                # tz into the DB
+                                #"timezone": -14400
+                                }
+        else:
+            fields = field
+
+        if overrides:
+            fields.update(overrides)
+
+        for field in fields:
+            result = testing.db.scalar(
+                        select([extract(field, expr)]).select_from(t))
+            eq_(result, fields[field])
+
+    def test_one(self):
+        t = self.tables.t
+        self._test(t.c.dtme, "all")
+
+    def test_two(self):
+        t = self.tables.t
+        self._test(t.c.dtme + t.c.intv,
+                overrides={"epoch": 1336652695.0, "minute": 24})
+
+    def test_three(self):
+        t = self.tables.t
+
+        actual_ts = testing.db.scalar(func.current_timestamp()) - \
+                        datetime.timedelta(days=5)
+        self._test(func.current_timestamp() - datetime.timedelta(days=5),
+                {"hour": actual_ts.hour, "year": actual_ts.year,
+                "month": actual_ts.month}
+            )
+
+    def test_four(self):
+        t = self.tables.t
+        self._test(datetime.timedelta(days=5) + t.c.dt,
+                overrides={"day": 15, "epoch": 1337040000.0, "hour": 0,
+                            "minute": 0}
+            )
+
+    def test_five(self):
+        t = self.tables.t
+        self._test(func.coalesce(t.c.dtme, func.current_timestamp()),
+                    overrides={"epoch": 1336666525.0})
+
+    def test_six(self):
+        t = self.tables.t
+        self._test(t.c.tm + datetime.timedelta(seconds=30), "time",
+                    overrides={"second": 55})
+
+    def test_seven(self):
+        self._test(literal(datetime.timedelta(seconds=10))
+                 - literal(datetime.timedelta(seconds=10)), "all",
+                 overrides={"hour": 0, "minute": 0, "month": 0,
+                        "year": 0, "day": 0, "epoch": 0})
+
+    def test_eight(self):
+        t = self.tables.t
+        self._test(t.c.tm + datetime.timedelta(seconds=30),
+                {"hour": 12, "minute": 15, "second": 55})
+
+    def test_nine(self):
+        self._test(text("t.dt + t.tm"))
+
+    def test_ten(self):
+        t = self.tables.t
+        self._test(t.c.dt + t.c.tm)
+
+    def test_eleven(self):
+        self._test(func.current_timestamp() - func.current_timestamp(),
+                {"year": 0, "month": 0, "day": 0, "hour": 0}
+            )
+
+    def test_twelve(self):
+        t = self.tables.t
+        actual_ts = testing.db.scalar(
+                    func.current_timestamp()).replace(tzinfo=None) - \
+                        datetime.datetime(2012, 5, 10, 12, 15, 25)
+
+        self._test(func.current_timestamp() - func.coalesce(t.c.dtme,
+                 func.current_timestamp()),
+                    {"day": actual_ts.days}
+                )
+
+    def test_thirteen(self):
+        t = self.tables.t
+        self._test(t.c.dttz, "all+tz")
+
+    def test_fourteen(self):
+        t = self.tables.t
+        self._test(t.c.tm, "time")
+
+    def test_fifteen(self):
+        t = self.tables.t
+        self._test(datetime.timedelta(days=5) + t.c.dtme,
+                overrides={"day": 15, "epoch": 1337084125.0}
+            )