]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
implement second-level type resolution for literals
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 5 Jan 2022 17:20:46 +0000 (12:20 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 11 Jan 2022 14:25:19 +0000 (09:25 -0500)
Added additional rule to the system that determines ``TypeEngine``
implementations from Python literals to apply a second level of adjustment
to the type, so that a Python datetime with or without tzinfo can set the
``timezone=True`` parameter on the returned :class:`.DateTime` object, as
well as :class:`.Time`. This helps with some round-trip scenarios on
type-sensitive PostgreSQL dialects such as asyncpg, psycopg3 (2.0 only).

Improved support for asyncpg handling of TIME WITH TIMEZONE, which
was not fully implemented.

Fixes: #7537
Change-Id: Icdb07db85af5f7f39f1c1ef855fe27609770094b
(cherry picked from commit 3b2e28bcb5ba32446a92b62b6862b7c11dabb592)

doc/build/changelog/unreleased_14/7537.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/asyncpg.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/sql/type_api.py
lib/sqlalchemy/testing/requirements.py
lib/sqlalchemy/testing/suite/test_types.py
test/requirements.py

diff --git a/doc/build/changelog/unreleased_14/7537.rst b/doc/build/changelog/unreleased_14/7537.rst
new file mode 100644 (file)
index 0000000..d48cf30
--- /dev/null
@@ -0,0 +1,17 @@
+.. change::
+    :tags: bug, sql, postgresql
+    :tickets: 7537
+
+    Added additional rule to the system that determines ``TypeEngine``
+    implementations from Python literals to apply a second level of adjustment
+    to the type, so that a Python datetime with or without tzinfo can set the
+    ``timezone=True`` parameter on the returned :class:`.DateTime` object, as
+    well as :class:`.Time`. This helps with some round-trip scenarios on
+    type-sensitive PostgreSQL dialects such as asyncpg, psycopg3 (2.0 only).
+
+.. change::
+    :tags: bug, postgresql, asyncpg
+    :tickets: 7537
+
+    Improved support for asyncpg handling of TIME WITH TIMEZONE, which
+    was not fully implemented.
index fedc0b495b4f4321c595824d733b3b3f91697643..f32192b3c8bcd9ac9371261e404611f775dbe6d9 100644 (file)
@@ -136,7 +136,10 @@ except ImportError:
 
 class AsyncpgTime(sqltypes.Time):
     def get_dbapi_type(self, dbapi):
-        return dbapi.TIME
+        if self.timezone:
+            return dbapi.TIME_W_TZ
+        else:
+            return dbapi.TIME
 
 
 class AsyncpgDate(sqltypes.Date):
@@ -818,6 +821,7 @@ class AsyncAdapt_asyncpg_dbapi:
     TIMESTAMP = util.symbol("TIMESTAMP")
     TIMESTAMP_W_TZ = util.symbol("TIMESTAMP_W_TZ")
     TIME = util.symbol("TIME")
+    TIME_W_TZ = util.symbol("TIME_W_TZ")
     DATE = util.symbol("DATE")
     INTERVAL = util.symbol("INTERVAL")
     NUMBER = util.symbol("NUMBER")
@@ -843,6 +847,7 @@ _pg_types = {
     AsyncAdapt_asyncpg_dbapi.TIMESTAMP_W_TZ: "timestamp with time zone",
     AsyncAdapt_asyncpg_dbapi.DATE: "date",
     AsyncAdapt_asyncpg_dbapi.TIME: "time",
+    AsyncAdapt_asyncpg_dbapi.TIME_W_TZ: "time with time zone",
     AsyncAdapt_asyncpg_dbapi.INTERVAL: "interval",
     AsyncAdapt_asyncpg_dbapi.NUMBER: "numeric",
     AsyncAdapt_asyncpg_dbapi.FLOAT: "float",
index 3f3801ab009d159eee99741d50bc8bd2b931ddbd..c80b10fcc34db6fda7bb855d3843e76b9b10f1f1 100644 (file)
@@ -867,6 +867,13 @@ class DateTime(_LookupExpressionAdapter, TypeEngine):
     def get_dbapi_type(self, dbapi):
         return dbapi.DATETIME
 
+    def _resolve_for_literal(self, value):
+        with_timezone = value.tzinfo is not None
+        if with_timezone and not self.timezone:
+            return DATETIME_TIMEZONE
+        else:
+            return self
+
     @property
     def python_type(self):
         return dt.datetime
@@ -937,6 +944,13 @@ class Time(_LookupExpressionAdapter, TypeEngine):
     def python_type(self):
         return dt.time
 
+    def _resolve_for_literal(self, value):
+        with_timezone = value.tzinfo is not None
+        if with_timezone and not self.timezone:
+            return TIME_TIMEZONE
+        else:
+            return self
+
     @util.memoized_property
     def _expression_adaptations(self):
         # Based on https://www.postgresql.org/docs/current/\
@@ -3254,6 +3268,8 @@ STRINGTYPE = String()
 INTEGERTYPE = Integer()
 MATCHTYPE = MatchType()
 TABLEVALUE = TableValueType()
+DATETIME_TIMEZONE = DateTime(timezone=True)
+TIME_TIMEZONE = Time(timezone=True)
 
 _type_map = {
     int: Integer(),
@@ -3296,7 +3312,7 @@ def _resolve_value_to_type(value):
             )
         return NULLTYPE
     else:
-        return _result_type
+        return _result_type._resolve_for_literal(value)
 
 
 # back-assign to type_api
index 49f6cfe204a83bb70f9e9210937f747070802652..ecf68e62dd4b1aae8b0fb0862bcb3ae8807dd3bc 100644 (file)
@@ -545,6 +545,17 @@ class TypeEngine(Traversible):
         """
         return Variant(self, {dialect_name: to_instance(type_)})
 
+    def _resolve_for_literal(self, value):
+        """adjust this type given a literal Python value that will be
+        stored in a bound parameter.
+
+        Used exclusively by _resolve_value_to_type().
+
+        .. versionadded:: 1.4.30 or 2.0
+
+        """
+        return self
+
     @util.memoized_property
     def _type_affinity(self):
         """Return a rudimental 'affinity' value expressing the general class
index a0f262a760adab6cae0c53db2b1d5579ad8c7bd0..1c8858ec141a64d757818e805e1d76f1b07ee4d6 100644 (file)
@@ -753,6 +753,29 @@ class SuiteRequirements(Requirements):
 
         return exclusions.open()
 
+    @property
+    def datetime_timezone(self):
+        """target dialect supports representation of Python
+        datetime.datetime() with tzinfo with DateTime(timezone=True)."""
+
+        return exclusions.closed()
+
+    @property
+    def time_timezone(self):
+        """target dialect supports representation of Python
+        datetime.time() with tzinfo with Time(timezone=True)."""
+
+        return exclusions.closed()
+
+    @property
+    def datetime_implicit_bound(self):
+        """target dialect when given a datetime object will bind it such
+        that the database server knows the object is a datetime, and not
+        a plain string.
+
+        """
+        return exclusions.open()
+
     @property
     def datetime_microseconds(self):
         """target dialect supports representation of Python
@@ -767,6 +790,16 @@ class SuiteRequirements(Requirements):
         if TIMESTAMP is used."""
         return exclusions.closed()
 
+    @property
+    def timestamp_microseconds_implicit_bound(self):
+        """target dialect when given a datetime object which also includes
+        a microseconds portion when using the TIMESTAMP data type
+        will bind it such that the database server knows
+        the object is a datetime with microseconds, and not a plain string.
+
+        """
+        return self.timestamp_microseconds
+
     @property
     def datetime_historic(self):
         """target dialect supports representation of Python
index d62b608095aa5cbe785663992809ded2225b19bf..2fdea5e48e74203cbd81adf66dac6c2f6569f697 100644 (file)
@@ -41,6 +41,7 @@ from ... import UnicodeText
 from ... import util
 from ...orm import declarative_base
 from ...orm import Session
+from ...util import compat
 from ...util import u
 
 
@@ -308,6 +309,11 @@ class _DateFixture(_LiteralRoundTripFixture, fixtures.TestBase):
             Column("decorated_date_data", Decorated),
         )
 
+    @testing.requires.datetime_implicit_bound
+    def test_select_direct(self, connection):
+        result = connection.scalar(select(literal(self.data)))
+        eq_(result, self.data)
+
     def test_round_trip(self, connection):
         date_table = self.tables.date_table
 
@@ -382,6 +388,15 @@ class DateTimeTest(_DateFixture, fixtures.TablesTest):
     data = datetime.datetime(2012, 10, 15, 12, 57, 18)
 
 
+class DateTimeTZTest(_DateFixture, fixtures.TablesTest):
+    __requires__ = ("datetime_timezone",)
+    __backend__ = True
+    datatype = DateTime(timezone=True)
+    data = datetime.datetime(
+        2012, 10, 15, 12, 57, 18, tzinfo=compat.timezone.utc
+    )
+
+
 class DateTimeMicrosecondsTest(_DateFixture, fixtures.TablesTest):
     __requires__ = ("datetime_microseconds",)
     __backend__ = True
@@ -395,6 +410,11 @@ class TimestampMicrosecondsTest(_DateFixture, fixtures.TablesTest):
     datatype = TIMESTAMP
     data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396)
 
+    @testing.requires.timestamp_microseconds_implicit_bound
+    def test_select_direct(self, connection):
+        result = connection.scalar(select(literal(self.data)))
+        eq_(result, self.data)
+
 
 class TimeTest(_DateFixture, fixtures.TablesTest):
     __requires__ = ("time",)
@@ -403,6 +423,13 @@ class TimeTest(_DateFixture, fixtures.TablesTest):
     data = datetime.time(12, 57, 18)
 
 
+class TimeTZTest(_DateFixture, fixtures.TablesTest):
+    __requires__ = ("time_timezone",)
+    __backend__ = True
+    datatype = Time(timezone=True)
+    data = datetime.time(12, 57, 18, tzinfo=compat.timezone.utc)
+
+
 class TimeMicrosecondsTest(_DateFixture, fixtures.TablesTest):
     __requires__ = ("time_microseconds",)
     __backend__ = True
@@ -1424,6 +1451,7 @@ __all__ = (
     "JSONLegacyStringCastIndexTest",
     "DateTest",
     "DateTimeTest",
+    "DateTimeTZTest",
     "TextTest",
     "NumericTest",
     "IntegerTest",
@@ -1433,6 +1461,7 @@ __all__ = (
     "TimeMicrosecondsTest",
     "TimestampMicrosecondsTest",
     "TimeTest",
+    "TimeTZTest",
     "DateTimeMicrosecondsTest",
     "DateHistoricTest",
     "StringTest",
index 006c523a69da07ff58766978193dc32ffafb9a4f..bf83b83b48bc55eaae569c2d50833dcdbeeacd47 100644 (file)
@@ -1126,6 +1126,27 @@ class DefaultRequirements(SuiteRequirements):
 
         return exclusions.open()
 
+    @property
+    def datetime_implicit_bound(self):
+        """target dialect when given a datetime object will bind it such
+        that the database server knows the object is a datetime, and not
+        a plain string.
+
+        """
+        # pg8000 works in main / 2.0, support in 1.4 is not fully
+        # present.
+        return exclusions.skip_if("postgresql+pg8000") + exclusions.fails_on(
+            ["mysql", "mariadb"]
+        )
+
+    @property
+    def datetime_timezone(self):
+        return exclusions.only_on("postgresql")
+
+    @property
+    def time_timezone(self):
+        return exclusions.only_on("postgresql") + exclusions.skip_if("+pg8000")
+
     @property
     def datetime_microseconds(self):
         """target dialect supports representation of Python
@@ -1143,6 +1164,10 @@ class DefaultRequirements(SuiteRequirements):
 
         return only_on(["oracle"])
 
+    @property
+    def timestamp_microseconds_implicit_bound(self):
+        return self.timestamp_microseconds + exclusions.fails_on(["oracle"])
+
     @property
     def datetime_historic(self):
         """target dialect supports representation of Python