From: Mike Bayer Date: Wed, 5 Jan 2022 17:20:46 +0000 (-0500) Subject: implement second-level type resolution for literals X-Git-Tag: rel_1_4_30~10^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=e215db01d48c418e190936e6b36ea49c6eb22072;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git implement second-level type resolution for literals Added additional rule to the system that determines ``TypeEngine`` implementations from Python literals to apply a second level of adjustment to the type, so that a Python datetime with or without tzinfo can set the ``timezone=True`` parameter on the returned :class:`.DateTime` object, as well as :class:`.Time`. This helps with some round-trip scenarios on type-sensitive PostgreSQL dialects such as asyncpg, psycopg3 (2.0 only). Improved support for asyncpg handling of TIME WITH TIMEZONE, which was not fully implemented. Fixes: #7537 Change-Id: Icdb07db85af5f7f39f1c1ef855fe27609770094b (cherry picked from commit 3b2e28bcb5ba32446a92b62b6862b7c11dabb592) --- diff --git a/doc/build/changelog/unreleased_14/7537.rst b/doc/build/changelog/unreleased_14/7537.rst new file mode 100644 index 0000000000..d48cf30a07 --- /dev/null +++ b/doc/build/changelog/unreleased_14/7537.rst @@ -0,0 +1,17 @@ +.. change:: + :tags: bug, sql, postgresql + :tickets: 7537 + + Added additional rule to the system that determines ``TypeEngine`` + implementations from Python literals to apply a second level of adjustment + to the type, so that a Python datetime with or without tzinfo can set the + ``timezone=True`` parameter on the returned :class:`.DateTime` object, as + well as :class:`.Time`. This helps with some round-trip scenarios on + type-sensitive PostgreSQL dialects such as asyncpg, psycopg3 (2.0 only). + +.. change:: + :tags: bug, postgresql, asyncpg + :tickets: 7537 + + Improved support for asyncpg handling of TIME WITH TIMEZONE, which + was not fully implemented. diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index fedc0b495b..f32192b3c8 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -136,7 +136,10 @@ except ImportError: class AsyncpgTime(sqltypes.Time): def get_dbapi_type(self, dbapi): - return dbapi.TIME + if self.timezone: + return dbapi.TIME_W_TZ + else: + return dbapi.TIME class AsyncpgDate(sqltypes.Date): @@ -818,6 +821,7 @@ class AsyncAdapt_asyncpg_dbapi: TIMESTAMP = util.symbol("TIMESTAMP") TIMESTAMP_W_TZ = util.symbol("TIMESTAMP_W_TZ") TIME = util.symbol("TIME") + TIME_W_TZ = util.symbol("TIME_W_TZ") DATE = util.symbol("DATE") INTERVAL = util.symbol("INTERVAL") NUMBER = util.symbol("NUMBER") @@ -843,6 +847,7 @@ _pg_types = { AsyncAdapt_asyncpg_dbapi.TIMESTAMP_W_TZ: "timestamp with time zone", AsyncAdapt_asyncpg_dbapi.DATE: "date", AsyncAdapt_asyncpg_dbapi.TIME: "time", + AsyncAdapt_asyncpg_dbapi.TIME_W_TZ: "time with time zone", AsyncAdapt_asyncpg_dbapi.INTERVAL: "interval", AsyncAdapt_asyncpg_dbapi.NUMBER: "numeric", AsyncAdapt_asyncpg_dbapi.FLOAT: "float", diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 3f3801ab00..c80b10fcc3 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -867,6 +867,13 @@ class DateTime(_LookupExpressionAdapter, TypeEngine): def get_dbapi_type(self, dbapi): return dbapi.DATETIME + def _resolve_for_literal(self, value): + with_timezone = value.tzinfo is not None + if with_timezone and not self.timezone: + return DATETIME_TIMEZONE + else: + return self + @property def python_type(self): return dt.datetime @@ -937,6 +944,13 @@ class Time(_LookupExpressionAdapter, TypeEngine): def python_type(self): return dt.time + def _resolve_for_literal(self, value): + with_timezone = value.tzinfo is not None + if with_timezone and not self.timezone: + return TIME_TIMEZONE + else: + return self + @util.memoized_property def _expression_adaptations(self): # Based on https://www.postgresql.org/docs/current/\ @@ -3254,6 +3268,8 @@ STRINGTYPE = String() INTEGERTYPE = Integer() MATCHTYPE = MatchType() TABLEVALUE = TableValueType() +DATETIME_TIMEZONE = DateTime(timezone=True) +TIME_TIMEZONE = Time(timezone=True) _type_map = { int: Integer(), @@ -3296,7 +3312,7 @@ def _resolve_value_to_type(value): ) return NULLTYPE else: - return _result_type + return _result_type._resolve_for_literal(value) # back-assign to type_api diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index 49f6cfe204..ecf68e62dd 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -545,6 +545,17 @@ class TypeEngine(Traversible): """ return Variant(self, {dialect_name: to_instance(type_)}) + def _resolve_for_literal(self, value): + """adjust this type given a literal Python value that will be + stored in a bound parameter. + + Used exclusively by _resolve_value_to_type(). + + .. versionadded:: 1.4.30 or 2.0 + + """ + return self + @util.memoized_property def _type_affinity(self): """Return a rudimental 'affinity' value expressing the general class diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index a0f262a760..1c8858ec14 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -753,6 +753,29 @@ class SuiteRequirements(Requirements): return exclusions.open() + @property + def datetime_timezone(self): + """target dialect supports representation of Python + datetime.datetime() with tzinfo with DateTime(timezone=True).""" + + return exclusions.closed() + + @property + def time_timezone(self): + """target dialect supports representation of Python + datetime.time() with tzinfo with Time(timezone=True).""" + + return exclusions.closed() + + @property + def datetime_implicit_bound(self): + """target dialect when given a datetime object will bind it such + that the database server knows the object is a datetime, and not + a plain string. + + """ + return exclusions.open() + @property def datetime_microseconds(self): """target dialect supports representation of Python @@ -767,6 +790,16 @@ class SuiteRequirements(Requirements): if TIMESTAMP is used.""" return exclusions.closed() + @property + def timestamp_microseconds_implicit_bound(self): + """target dialect when given a datetime object which also includes + a microseconds portion when using the TIMESTAMP data type + will bind it such that the database server knows + the object is a datetime with microseconds, and not a plain string. + + """ + return self.timestamp_microseconds + @property def datetime_historic(self): """target dialect supports representation of Python diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py index d62b608095..2fdea5e48e 100644 --- a/lib/sqlalchemy/testing/suite/test_types.py +++ b/lib/sqlalchemy/testing/suite/test_types.py @@ -41,6 +41,7 @@ from ... import UnicodeText from ... import util from ...orm import declarative_base from ...orm import Session +from ...util import compat from ...util import u @@ -308,6 +309,11 @@ class _DateFixture(_LiteralRoundTripFixture, fixtures.TestBase): Column("decorated_date_data", Decorated), ) + @testing.requires.datetime_implicit_bound + def test_select_direct(self, connection): + result = connection.scalar(select(literal(self.data))) + eq_(result, self.data) + def test_round_trip(self, connection): date_table = self.tables.date_table @@ -382,6 +388,15 @@ class DateTimeTest(_DateFixture, fixtures.TablesTest): data = datetime.datetime(2012, 10, 15, 12, 57, 18) +class DateTimeTZTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("datetime_timezone",) + __backend__ = True + datatype = DateTime(timezone=True) + data = datetime.datetime( + 2012, 10, 15, 12, 57, 18, tzinfo=compat.timezone.utc + ) + + class DateTimeMicrosecondsTest(_DateFixture, fixtures.TablesTest): __requires__ = ("datetime_microseconds",) __backend__ = True @@ -395,6 +410,11 @@ class TimestampMicrosecondsTest(_DateFixture, fixtures.TablesTest): datatype = TIMESTAMP data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396) + @testing.requires.timestamp_microseconds_implicit_bound + def test_select_direct(self, connection): + result = connection.scalar(select(literal(self.data))) + eq_(result, self.data) + class TimeTest(_DateFixture, fixtures.TablesTest): __requires__ = ("time",) @@ -403,6 +423,13 @@ class TimeTest(_DateFixture, fixtures.TablesTest): data = datetime.time(12, 57, 18) +class TimeTZTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("time_timezone",) + __backend__ = True + datatype = Time(timezone=True) + data = datetime.time(12, 57, 18, tzinfo=compat.timezone.utc) + + class TimeMicrosecondsTest(_DateFixture, fixtures.TablesTest): __requires__ = ("time_microseconds",) __backend__ = True @@ -1424,6 +1451,7 @@ __all__ = ( "JSONLegacyStringCastIndexTest", "DateTest", "DateTimeTest", + "DateTimeTZTest", "TextTest", "NumericTest", "IntegerTest", @@ -1433,6 +1461,7 @@ __all__ = ( "TimeMicrosecondsTest", "TimestampMicrosecondsTest", "TimeTest", + "TimeTZTest", "DateTimeMicrosecondsTest", "DateHistoricTest", "StringTest", diff --git a/test/requirements.py b/test/requirements.py index 006c523a69..bf83b83b48 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -1126,6 +1126,27 @@ class DefaultRequirements(SuiteRequirements): return exclusions.open() + @property + def datetime_implicit_bound(self): + """target dialect when given a datetime object will bind it such + that the database server knows the object is a datetime, and not + a plain string. + + """ + # pg8000 works in main / 2.0, support in 1.4 is not fully + # present. + return exclusions.skip_if("postgresql+pg8000") + exclusions.fails_on( + ["mysql", "mariadb"] + ) + + @property + def datetime_timezone(self): + return exclusions.only_on("postgresql") + + @property + def time_timezone(self): + return exclusions.only_on("postgresql") + exclusions.skip_if("+pg8000") + @property def datetime_microseconds(self): """target dialect supports representation of Python @@ -1143,6 +1164,10 @@ class DefaultRequirements(SuiteRequirements): return only_on(["oracle"]) + @property + def timestamp_microseconds_implicit_bound(self): + return self.timestamp_microseconds + exclusions.fails_on(["oracle"]) + @property def datetime_historic(self): """target dialect supports representation of Python