From 598ae1d120a35ee9e19cf7f953c7dd51047954d6 Mon Sep 17 00:00:00 2001 From: indivar Date: Thu, 28 Sep 2023 13:43:57 -0400 Subject: [PATCH] add literal_processors for interval, PG and Oracle Implemented "literal value processing" for the :class:`.Interval` datatype for both the PostgreSQL and Oracle dialects, allowing literal rendering of interval values. Pull request courtesy Indivar Mishra. Fixes: #9737 Closes: #10383 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/10383 Pull-request-sha: bf3a73dfd9d329779e12037ae62dc1032e76d0f6 Change-Id: Ic1a1c505f23eeb681059303799d5fc8821aadacf --- doc/build/changelog/unreleased_20/9737.rst | 7 ++ lib/sqlalchemy/dialects/oracle/types.py | 24 ++++++ lib/sqlalchemy/dialects/postgresql/types.py | 12 +++ lib/sqlalchemy/testing/requirements.py | 8 ++ lib/sqlalchemy/testing/suite/test_types.py | 90 +++++++++++++++++++++ test/dialect/oracle/test_types.py | 45 +++++++++++ test/dialect/postgresql/test_types.py | 39 ++++++++- test/requirements.py | 8 ++ 8 files changed, 232 insertions(+), 1 deletion(-) create mode 100644 doc/build/changelog/unreleased_20/9737.rst diff --git a/doc/build/changelog/unreleased_20/9737.rst b/doc/build/changelog/unreleased_20/9737.rst new file mode 100644 index 0000000000..806ee05706 --- /dev/null +++ b/doc/build/changelog/unreleased_20/9737.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: usecase, sql + :tickets: 9737 + + Implemented "literal value processing" for the :class:`.Interval` datatype + for both the PostgreSQL and Oracle dialects, allowing literal rendering of + interval values. Pull request courtesy Indivar Mishra. diff --git a/lib/sqlalchemy/dialects/oracle/types.py b/lib/sqlalchemy/dialects/oracle/types.py index c1f6d51916..4e616d1ceb 100644 --- a/lib/sqlalchemy/dialects/oracle/types.py +++ b/lib/sqlalchemy/dialects/oracle/types.py @@ -4,12 +4,23 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors +from __future__ import annotations + +import datetime as dt +from typing import no_type_check +from typing import Optional +from typing import Type +from typing import TYPE_CHECKING from ... import exc from ...sql import sqltypes from ...types import NVARCHAR from ...types import VARCHAR +if TYPE_CHECKING: + from ...engine.interfaces import Dialect + from ...sql.type_api import _LiteralProcessorType + class RAW(sqltypes._Binary): __visit_name__ = "RAW" @@ -212,6 +223,19 @@ class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval): day_precision=self.day_precision, ) + @property + def python_type(self) -> Type[dt.timedelta]: + return dt.timedelta + + @no_type_check + def literal_processor( + self, dialect: Dialect + ) -> Optional[_LiteralProcessorType[dt.timedelta]]: + def process(value: dt.timedelta) -> str: + return f"NUMTODSINTERVAL({value.total_seconds()}, 'SECOND')" + + return process + class TIMESTAMP(sqltypes.TIMESTAMP): """Oracle implementation of ``TIMESTAMP``, which supports additional diff --git a/lib/sqlalchemy/dialects/postgresql/types.py b/lib/sqlalchemy/dialects/postgresql/types.py index 2f49ff12a4..61116aa43d 100644 --- a/lib/sqlalchemy/dialects/postgresql/types.py +++ b/lib/sqlalchemy/dialects/postgresql/types.py @@ -7,6 +7,7 @@ from __future__ import annotations import datetime as dt from typing import Any +from typing import no_type_check from typing import Optional from typing import overload from typing import Type @@ -18,7 +19,9 @@ from ...sql import type_api from ...util.typing import Literal if TYPE_CHECKING: + from ...engine.interfaces import Dialect from ...sql.operators import OperatorType + from ...sql.type_api import _LiteralProcessorType from ...sql.type_api import TypeEngine _DECIMAL_TYPES = (1231, 1700) @@ -247,6 +250,15 @@ class INTERVAL(type_api.NativeForEmulated, sqltypes._AbstractInterval): def python_type(self) -> Type[dt.timedelta]: return dt.timedelta + @no_type_check + def literal_processor( + self, dialect: Dialect + ) -> Optional[_LiteralProcessorType[dt.timedelta]]: + def process(value: dt.timedelta) -> str: + return f"make_interval(secs=>{value.total_seconds()})" + + return process + PGInterval = INTERVAL diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index d13c548baf..5d1f3fb166 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -845,6 +845,14 @@ class SuiteRequirements(Requirements): """Target driver can create tables with a name like 'some " table'""" return exclusions.open() + @property + def datetime_interval(self): + """target dialect supports rendering of a datetime.timedelta as a + literal string, e.g. via the TypeEngine.literal_processor() method. + + """ + return exclusions.closed() + @property def datetime_literals(self): """target dialect supports rendering of a date, time, or datetime as a diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py index be405667e1..4c7e45ac07 100644 --- a/lib/sqlalchemy/testing/suite/test_types.py +++ b/lib/sqlalchemy/testing/suite/test_types.py @@ -28,6 +28,7 @@ from ... import Date from ... import DateTime from ... import Float from ... import Integer +from ... import Interval from ... import JSON from ... import literal from ... import literal_column @@ -461,6 +462,94 @@ class StringTest(_LiteralRoundTripFixture, fixtures.TestBase): ) +class IntervalTest(_LiteralRoundTripFixture, fixtures.TestBase): + __requires__ = ("datetime_interval",) + __backend__ = True + compare = None + datatype = Interval + data = datetime.timedelta(days=1, seconds=4) + + def test_literal(self, literal_round_trip): + literal_round_trip(self.datatype, [self.data], [self.data]) + + def test_select_direct_literal_interval(self, connection): + row = connection.execute(select(literal(self.data))).first() + eq_(row, (self.data,)) + + def test_arithmetic_operation_literal_interval(self, connection): + now = datetime.datetime.now().replace(microsecond=0) + # Able to subtract + row = connection.execute( + select(literal(now) - literal(self.data)) + ).scalar() + eq_(row, now - self.data) + + # Able to Add + row = connection.execute( + select(literal(now) + literal(self.data)) + ).scalar() + eq_(row, now + self.data) + + @testing.fixture + def arithmetic_table_fixture(cls, metadata, connection): + class Decorated(TypeDecorator): + impl = cls.datatype + cache_ok = True + + it = Table( + "interval_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("interval_data", cls.datatype), + Column("date_data", DateTime), + Column("decorated_interval_data", Decorated), + ) + it.create(connection) + return it + + def test_arithmetic_operation_table_interval_and_literal_interval( + self, connection, arithmetic_table_fixture + ): + interval_table = arithmetic_table_fixture + data = datetime.timedelta(days=2, seconds=5) + connection.execute( + interval_table.insert(), {"id": 1, "interval_data": data} + ) + # Subtraction Operation + row = connection.execute( + select(interval_table.c.interval_data - literal(self.data)) + ).scalar() + eq_(row, datetime.timedelta(days=1, seconds=1)) + + # Addition Operation + row = connection.execute( + select(interval_table.c.interval_data + literal(self.data)) + ).scalar() + eq_(row, datetime.timedelta(days=3, seconds=9)) + + def test_arithmetic_operation_table_date_and_literal_interval( + self, connection, arithmetic_table_fixture + ): + interval_table = arithmetic_table_fixture + now = datetime.datetime.now().replace(microsecond=0) + connection.execute( + interval_table.insert(), {"id": 1, "date_data": now} + ) + # Subtraction Operation + row = connection.execute( + select(interval_table.c.date_data - literal(self.data)) + ).scalar() + eq_(row, (now - self.data)) + + # Addition Operation + row = connection.execute( + select(interval_table.c.date_data + literal(self.data)) + ).scalar() + eq_(row, (now + self.data)) + + class _DateFixture(_LiteralRoundTripFixture, fixtures.TestBase): compare = None @@ -1949,6 +2038,7 @@ __all__ = ( "TextTest", "NumericTest", "IntegerTest", + "IntervalTest", "CastTypeDecoratorTest", "DateTimeHistoricTest", "DateTimeCoercedToDateTimeTest", diff --git a/test/dialect/oracle/test_types.py b/test/dialect/oracle/test_types.py index a970adc4ba..65d4ba826f 100644 --- a/test/dialect/oracle/test_types.py +++ b/test/dialect/oracle/test_types.py @@ -208,6 +208,33 @@ class DialectTypesTest(fixtures.TestBase, AssertsCompiledSQL): def test_float_type_compile(self, type_, sql_text): self.assert_compile(type_, sql_text) + @testing.combinations( + ( + text("select :parameter from dual").bindparams( + parameter=datetime.timedelta(days=2) + ), + "select NUMTODSINTERVAL(172800.0, 'SECOND') from dual", + ), + ( + text("SELECT :parameter from dual").bindparams( + parameter=datetime.timedelta(days=1, minutes=3, seconds=4) + ), + "SELECT NUMTODSINTERVAL(86584.0, 'SECOND') from dual", + ), + ( + text("select :parameter - :parameter2 from dual").bindparams( + parameter=datetime.timedelta(days=1, minutes=3, seconds=4), + parameter2=datetime.timedelta(days=0, minutes=1, seconds=4), + ), + ( + "select NUMTODSINTERVAL(86584.0, 'SECOND') - " + "NUMTODSINTERVAL(64.0, 'SECOND') from dual" + ), + ), + ) + def test_interval_literal_processor(self, type_, expected): + self.assert_compile(type_, expected, literal_binds=True) + class TypesTest(fixtures.TestBase): __only_on__ = "oracle" @@ -323,6 +350,24 @@ class TypesTest(fixtures.TestBase): datetime.timedelta(days=35, seconds=5743), ) + def test_interval_literal_processor(self, connection): + stmt = text("select :parameter - :parameter2 from dual") + result = connection.execute( + stmt.bindparams( + bindparam( + "parameter", + datetime.timedelta(days=1, minutes=3, seconds=4), + literal_execute=True, + ), + bindparam( + "parameter2", + datetime.timedelta(days=0, minutes=1, seconds=4), + literal_execute=True, + ), + ) + ).one() + eq_(result[0], datetime.timedelta(days=1, seconds=120)) + def test_no_decimal_float_precision(self): with expect_raises_message( exc.ArgumentError, diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 54a7c69cab..95bbb16636 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -3151,7 +3151,9 @@ class HashableFlagORMTest(fixtures.TestBase): ) -class TimestampTest(fixtures.TestBase, AssertsExecutionResults): +class TimestampTest( + fixtures.TestBase, AssertsCompiledSQL, AssertsExecutionResults +): __only_on__ = "postgresql" __backend__ = True @@ -3181,6 +3183,41 @@ class TimestampTest(fixtures.TestBase, AssertsExecutionResults): expr = column("bar", postgresql.INTERVAL) == datetime.timedelta(days=1) eq_(expr.right.type._type_affinity, types.Interval) + def test_interval_literal_processor(self, connection): + stmt = text("select :parameter - :parameter2") + result = connection.execute( + stmt.bindparams( + bindparam( + "parameter", + datetime.timedelta(days=1, minutes=3, seconds=4), + literal_execute=True, + ), + bindparam( + "parameter2", + datetime.timedelta(days=0, minutes=1, seconds=4), + literal_execute=True, + ), + ) + ).one() + eq_(result[0], datetime.timedelta(days=1, seconds=120)) + + @testing.combinations( + ( + text("select :parameter").bindparams( + parameter=datetime.timedelta(days=2) + ), + ("select make_interval(secs=>172800.0)"), + ), + ( + text("select :parameter").bindparams( + parameter=datetime.timedelta(days=730, seconds=2323213392), + ), + ("select make_interval(secs=>2386285392.0)"), + ), + ) + def test_interval_literal_processor_compiled(self, type_, expected): + self.assert_compile(type_, expected, literal_binds=True) + class SpecialTypesCompileTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = "postgresql" diff --git a/test/requirements.py b/test/requirements.py index 88798d6cd7..4a0b365c2b 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -1202,6 +1202,14 @@ class DefaultRequirements(SuiteRequirements): def json_array_indexes(self): return self.json_type + @property + def datetime_interval(self): + """target dialect supports rendering of a datetime.timedelta as a + literal string, e.g. via the TypeEngine.literal_processor() method. + Added for Oracle and Postgresql as of now. + """ + return only_on(["oracle", "postgresql"]) + @property def datetime_literals(self): """target dialect supports rendering of a date, time, or datetime as a -- 2.39.5