]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
add literal_processors for interval, PG and Oracle
authorindivar <indimishra@gmail.com>
Thu, 28 Sep 2023 17:43:57 +0000 (13:43 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 24 Oct 2023 16:10:31 +0000 (12:10 -0400)
Implemented "literal value processing" for the :class:`.Interval` datatype
for both the PostgreSQL and Oracle dialects, allowing literal rendering of
interval values.  Pull request courtesy Indivar Mishra.

Fixes: #9737
Closes: #10383
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/10383
Pull-request-sha: bf3a73dfd9d329779e12037ae62dc1032e76d0f6

Change-Id: Ic1a1c505f23eeb681059303799d5fc8821aadacf

doc/build/changelog/unreleased_20/9737.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/oracle/types.py
lib/sqlalchemy/dialects/postgresql/types.py
lib/sqlalchemy/testing/requirements.py
lib/sqlalchemy/testing/suite/test_types.py
test/dialect/oracle/test_types.py
test/dialect/postgresql/test_types.py
test/requirements.py

diff --git a/doc/build/changelog/unreleased_20/9737.rst b/doc/build/changelog/unreleased_20/9737.rst
new file mode 100644 (file)
index 0000000..806ee05
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: usecase, sql
+    :tickets: 9737
+
+    Implemented "literal value processing" for the :class:`.Interval` datatype
+    for both the PostgreSQL and Oracle dialects, allowing literal rendering of
+    interval values.  Pull request courtesy Indivar Mishra.
index c1f6d51916d4192cce054cb753e838d0d368cc14..4e616d1ceb1d7ad3255dbbd2e9c2515019867145 100644 (file)
@@ -4,12 +4,23 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
 # mypy: ignore-errors
+from __future__ import annotations
+
+import datetime as dt
+from typing import no_type_check
+from typing import Optional
+from typing import Type
+from typing import TYPE_CHECKING
 
 from ... import exc
 from ...sql import sqltypes
 from ...types import NVARCHAR
 from ...types import VARCHAR
 
+if TYPE_CHECKING:
+    from ...engine.interfaces import Dialect
+    from ...sql.type_api import _LiteralProcessorType
+
 
 class RAW(sqltypes._Binary):
     __visit_name__ = "RAW"
@@ -212,6 +223,19 @@ class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval):
             day_precision=self.day_precision,
         )
 
+    @property
+    def python_type(self) -> Type[dt.timedelta]:
+        return dt.timedelta
+
+    @no_type_check
+    def literal_processor(
+        self, dialect: Dialect
+    ) -> Optional[_LiteralProcessorType[dt.timedelta]]:
+        def process(value: dt.timedelta) -> str:
+            return f"NUMTODSINTERVAL({value.total_seconds()}, 'SECOND')"
+
+        return process
+
 
 class TIMESTAMP(sqltypes.TIMESTAMP):
     """Oracle implementation of ``TIMESTAMP``, which supports additional
index 2f49ff12a459723e1738014e1e699e8e3d844c7c..61116aa43d6d33c8a53c0e5726144b62d8a31d80 100644 (file)
@@ -7,6 +7,7 @@ from __future__ import annotations
 
 import datetime as dt
 from typing import Any
+from typing import no_type_check
 from typing import Optional
 from typing import overload
 from typing import Type
@@ -18,7 +19,9 @@ from ...sql import type_api
 from ...util.typing import Literal
 
 if TYPE_CHECKING:
+    from ...engine.interfaces import Dialect
     from ...sql.operators import OperatorType
+    from ...sql.type_api import _LiteralProcessorType
     from ...sql.type_api import TypeEngine
 
 _DECIMAL_TYPES = (1231, 1700)
@@ -247,6 +250,15 @@ class INTERVAL(type_api.NativeForEmulated, sqltypes._AbstractInterval):
     def python_type(self) -> Type[dt.timedelta]:
         return dt.timedelta
 
+    @no_type_check
+    def literal_processor(
+        self, dialect: Dialect
+    ) -> Optional[_LiteralProcessorType[dt.timedelta]]:
+        def process(value: dt.timedelta) -> str:
+            return f"make_interval(secs=>{value.total_seconds()})"
+
+        return process
+
 
 PGInterval = INTERVAL
 
index d13c548baf4940f05c3816e48282e93e7e742994..5d1f3fb1663958a96e974d280c765d571e71a508 100644 (file)
@@ -845,6 +845,14 @@ class SuiteRequirements(Requirements):
         """Target driver can create tables with a name like 'some " table'"""
         return exclusions.open()
 
+    @property
+    def datetime_interval(self):
+        """target dialect supports rendering of a datetime.timedelta as a
+        literal string, e.g. via the TypeEngine.literal_processor() method.
+
+        """
+        return exclusions.closed()
+
     @property
     def datetime_literals(self):
         """target dialect supports rendering of a date, time, or datetime as a
index be405667e11a4fd2336fa9acb1dbbe3fd2fb8e7b..4c7e45ac07723cfb23564d815aaa2aaf8d04f36b 100644 (file)
@@ -28,6 +28,7 @@ from ... import Date
 from ... import DateTime
 from ... import Float
 from ... import Integer
+from ... import Interval
 from ... import JSON
 from ... import literal
 from ... import literal_column
@@ -461,6 +462,94 @@ class StringTest(_LiteralRoundTripFixture, fixtures.TestBase):
         )
 
 
+class IntervalTest(_LiteralRoundTripFixture, fixtures.TestBase):
+    __requires__ = ("datetime_interval",)
+    __backend__ = True
+    compare = None
+    datatype = Interval
+    data = datetime.timedelta(days=1, seconds=4)
+
+    def test_literal(self, literal_round_trip):
+        literal_round_trip(self.datatype, [self.data], [self.data])
+
+    def test_select_direct_literal_interval(self, connection):
+        row = connection.execute(select(literal(self.data))).first()
+        eq_(row, (self.data,))
+
+    def test_arithmetic_operation_literal_interval(self, connection):
+        now = datetime.datetime.now().replace(microsecond=0)
+        # Able to subtract
+        row = connection.execute(
+            select(literal(now) - literal(self.data))
+        ).scalar()
+        eq_(row, now - self.data)
+
+        # Able to Add
+        row = connection.execute(
+            select(literal(now) + literal(self.data))
+        ).scalar()
+        eq_(row, now + self.data)
+
+    @testing.fixture
+    def arithmetic_table_fixture(cls, metadata, connection):
+        class Decorated(TypeDecorator):
+            impl = cls.datatype
+            cache_ok = True
+
+        it = Table(
+            "interval_table",
+            metadata,
+            Column(
+                "id", Integer, primary_key=True, test_needs_autoincrement=True
+            ),
+            Column("interval_data", cls.datatype),
+            Column("date_data", DateTime),
+            Column("decorated_interval_data", Decorated),
+        )
+        it.create(connection)
+        return it
+
+    def test_arithmetic_operation_table_interval_and_literal_interval(
+        self, connection, arithmetic_table_fixture
+    ):
+        interval_table = arithmetic_table_fixture
+        data = datetime.timedelta(days=2, seconds=5)
+        connection.execute(
+            interval_table.insert(), {"id": 1, "interval_data": data}
+        )
+        # Subtraction Operation
+        row = connection.execute(
+            select(interval_table.c.interval_data - literal(self.data))
+        ).scalar()
+        eq_(row, datetime.timedelta(days=1, seconds=1))
+
+        # Addition Operation
+        row = connection.execute(
+            select(interval_table.c.interval_data + literal(self.data))
+        ).scalar()
+        eq_(row, datetime.timedelta(days=3, seconds=9))
+
+    def test_arithmetic_operation_table_date_and_literal_interval(
+        self, connection, arithmetic_table_fixture
+    ):
+        interval_table = arithmetic_table_fixture
+        now = datetime.datetime.now().replace(microsecond=0)
+        connection.execute(
+            interval_table.insert(), {"id": 1, "date_data": now}
+        )
+        # Subtraction Operation
+        row = connection.execute(
+            select(interval_table.c.date_data - literal(self.data))
+        ).scalar()
+        eq_(row, (now - self.data))
+
+        # Addition Operation
+        row = connection.execute(
+            select(interval_table.c.date_data + literal(self.data))
+        ).scalar()
+        eq_(row, (now + self.data))
+
+
 class _DateFixture(_LiteralRoundTripFixture, fixtures.TestBase):
     compare = None
 
@@ -1949,6 +2038,7 @@ __all__ = (
     "TextTest",
     "NumericTest",
     "IntegerTest",
+    "IntervalTest",
     "CastTypeDecoratorTest",
     "DateTimeHistoricTest",
     "DateTimeCoercedToDateTimeTest",
index a970adc4bacddace0f897faa74c0d2ff46965ebe..65d4ba826ff4a586f0698749e5cd52b1b2445094 100644 (file)
@@ -208,6 +208,33 @@ class DialectTypesTest(fixtures.TestBase, AssertsCompiledSQL):
     def test_float_type_compile(self, type_, sql_text):
         self.assert_compile(type_, sql_text)
 
+    @testing.combinations(
+        (
+            text("select :parameter from dual").bindparams(
+                parameter=datetime.timedelta(days=2)
+            ),
+            "select NUMTODSINTERVAL(172800.0, 'SECOND') from dual",
+        ),
+        (
+            text("SELECT :parameter from dual").bindparams(
+                parameter=datetime.timedelta(days=1, minutes=3, seconds=4)
+            ),
+            "SELECT NUMTODSINTERVAL(86584.0, 'SECOND') from dual",
+        ),
+        (
+            text("select :parameter - :parameter2 from dual").bindparams(
+                parameter=datetime.timedelta(days=1, minutes=3, seconds=4),
+                parameter2=datetime.timedelta(days=0, minutes=1, seconds=4),
+            ),
+            (
+                "select NUMTODSINTERVAL(86584.0, 'SECOND') - "
+                "NUMTODSINTERVAL(64.0, 'SECOND') from dual"
+            ),
+        ),
+    )
+    def test_interval_literal_processor(self, type_, expected):
+        self.assert_compile(type_, expected, literal_binds=True)
+
 
 class TypesTest(fixtures.TestBase):
     __only_on__ = "oracle"
@@ -323,6 +350,24 @@ class TypesTest(fixtures.TestBase):
             datetime.timedelta(days=35, seconds=5743),
         )
 
+    def test_interval_literal_processor(self, connection):
+        stmt = text("select :parameter - :parameter2 from dual")
+        result = connection.execute(
+            stmt.bindparams(
+                bindparam(
+                    "parameter",
+                    datetime.timedelta(days=1, minutes=3, seconds=4),
+                    literal_execute=True,
+                ),
+                bindparam(
+                    "parameter2",
+                    datetime.timedelta(days=0, minutes=1, seconds=4),
+                    literal_execute=True,
+                ),
+            )
+        ).one()
+        eq_(result[0], datetime.timedelta(days=1, seconds=120))
+
     def test_no_decimal_float_precision(self):
         with expect_raises_message(
             exc.ArgumentError,
index 54a7c69cab6f9d16344f62e2de870a16f977122e..95bbb16636ec1318489960348ee2a9e33b2f35c8 100644 (file)
@@ -3151,7 +3151,9 @@ class HashableFlagORMTest(fixtures.TestBase):
         )
 
 
-class TimestampTest(fixtures.TestBase, AssertsExecutionResults):
+class TimestampTest(
+    fixtures.TestBase, AssertsCompiledSQL, AssertsExecutionResults
+):
     __only_on__ = "postgresql"
     __backend__ = True
 
@@ -3181,6 +3183,41 @@ class TimestampTest(fixtures.TestBase, AssertsExecutionResults):
         expr = column("bar", postgresql.INTERVAL) == datetime.timedelta(days=1)
         eq_(expr.right.type._type_affinity, types.Interval)
 
+    def test_interval_literal_processor(self, connection):
+        stmt = text("select :parameter - :parameter2")
+        result = connection.execute(
+            stmt.bindparams(
+                bindparam(
+                    "parameter",
+                    datetime.timedelta(days=1, minutes=3, seconds=4),
+                    literal_execute=True,
+                ),
+                bindparam(
+                    "parameter2",
+                    datetime.timedelta(days=0, minutes=1, seconds=4),
+                    literal_execute=True,
+                ),
+            )
+        ).one()
+        eq_(result[0], datetime.timedelta(days=1, seconds=120))
+
+    @testing.combinations(
+        (
+            text("select :parameter").bindparams(
+                parameter=datetime.timedelta(days=2)
+            ),
+            ("select make_interval(secs=>172800.0)"),
+        ),
+        (
+            text("select :parameter").bindparams(
+                parameter=datetime.timedelta(days=730, seconds=2323213392),
+            ),
+            ("select make_interval(secs=>2386285392.0)"),
+        ),
+    )
+    def test_interval_literal_processor_compiled(self, type_, expected):
+        self.assert_compile(type_, expected, literal_binds=True)
+
 
 class SpecialTypesCompileTest(fixtures.TestBase, AssertsCompiledSQL):
     __dialect__ = "postgresql"
index 88798d6cd7b6019fa88a2552a5dc4501126eb52d..4a0b365c2b546bd5bf5d699a3ef9de278dd092a7 100644 (file)
@@ -1202,6 +1202,14 @@ class DefaultRequirements(SuiteRequirements):
     def json_array_indexes(self):
         return self.json_type
 
+    @property
+    def datetime_interval(self):
+        """target dialect supports rendering of a datetime.timedelta as a
+        literal string, e.g. via the TypeEngine.literal_processor() method.
+        Added for Oracle and Postgresql as of now.
+        """
+        return only_on(["oracle", "postgresql"])
+
     @property
     def datetime_literals(self):
         """target dialect supports rendering of a date, time, or datetime as a