From: Mike Bayer Date: Tue, 24 Oct 2023 16:06:24 +0000 (-0400) Subject: accommodate NULL at the compiler level for literal_render X-Git-Tag: rel_2_0_23~12^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=1a23c8dee5665ebda75a1ea7d5e7ca355ea1f78b;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git accommodate NULL at the compiler level for literal_render Added compiler-level None/NULL handling for the "literal processors" of all datatypes that include literal processing, that is, where a value is rendered inline within a SQL statement rather than as a bound parameter, for all those types that do not feature explicit "null value" handling. Previously this behavior was undefined and inconsistent. Fixes: #10535 Change-Id: I746d19d6cec2aefa3244f5e5a6970950a698d96c --- diff --git a/doc/build/changelog/unreleased_20/10535.rst b/doc/build/changelog/unreleased_20/10535.rst new file mode 100644 index 0000000000..c8435bef1e --- /dev/null +++ b/doc/build/changelog/unreleased_20/10535.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: bug, sql + :tickets: 10535 + + Added compiler-level None/NULL handling for the "literal processors" of all + datatypes that include literal processing, that is, where a value is + rendered inline within a SQL statement rather than as a bound parameter, + for all those types that do not feature explicit "null value" handling. + Previously this behavior was undefined and inconsistent. diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 50ba43b364..687de04e4d 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -1537,28 +1537,22 @@ class MSUUid(sqltypes.Uuid): if self.native_uuid: def process(value): - if value is not None: - value = f"""'{str(value).replace("''", "'")}'""" - return value + return f"""'{str(value).replace("''", "'")}'""" return process else: if self.as_uuid: def process(value): - if value is not None: - value = f"""'{value.hex}'""" - return value + return f"""'{value.hex}'""" return process else: def process(value): - if value is not None: - value = f"""'{ + return f"""'{ value.replace("-", "").replace("'", "''") }'""" - return value return process diff --git a/lib/sqlalchemy/dialects/oracle/types.py b/lib/sqlalchemy/dialects/oracle/types.py index 62028c7673..c1f6d51916 100644 --- a/lib/sqlalchemy/dialects/oracle/types.py +++ b/lib/sqlalchemy/dialects/oracle/types.py @@ -116,38 +116,36 @@ class LONG(sqltypes.Text): class _OracleDateLiteralRender: def _literal_processor_datetime(self, dialect): def process(value): - if value is not None: - if getattr(value, "microsecond", None): - value = ( - f"""TO_TIMESTAMP""" - f"""('{value.isoformat().replace("T", " ")}', """ - """'YYYY-MM-DD HH24:MI:SS.FF')""" - ) - else: - value = ( - f"""TO_DATE""" - f"""('{value.isoformat().replace("T", " ")}', """ - """'YYYY-MM-DD HH24:MI:SS')""" - ) + if getattr(value, "microsecond", None): + value = ( + f"""TO_TIMESTAMP""" + f"""('{value.isoformat().replace("T", " ")}', """ + """'YYYY-MM-DD HH24:MI:SS.FF')""" + ) + else: + value = ( + f"""TO_DATE""" + f"""('{value.isoformat().replace("T", " ")}', """ + """'YYYY-MM-DD HH24:MI:SS')""" + ) return value return process def _literal_processor_date(self, dialect): def process(value): - if value is not None: - if getattr(value, "microsecond", None): - value = ( - f"""TO_TIMESTAMP""" - f"""('{value.isoformat().split("T")[0]}', """ - """'YYYY-MM-DD')""" - ) - else: - value = ( - f"""TO_DATE""" - f"""('{value.isoformat().split("T")[0]}', """ - """'YYYY-MM-DD')""" - ) + if getattr(value, "microsecond", None): + value = ( + f"""TO_TIMESTAMP""" + f"""('{value.isoformat().split("T")[0]}', """ + """'YYYY-MM-DD')""" + ) + else: + value = ( + f"""TO_DATE""" + f"""('{value.isoformat().split("T")[0]}', """ + """'YYYY-MM-DD')""" + ) return value return process diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 20f772f54a..cb6899c5e9 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -3766,6 +3766,12 @@ class SQLCompiler(Compiled): """ + if value is None and not type_.should_evaluate_none: + # issue #10535 - handle NULL in the compiler without placing + # this onto each type, except for "evaluate None" types + # (e.g. JSON) + return self.process(elements.Null._instance()) + processor = type_._cached_literal_processor(self.dialect) if processor: try: diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 343575f196..8b75a51af8 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -740,16 +740,12 @@ class _RenderISO8601NoT: if _portion is not None: def process(value): - if value is not None: - value = f"""'{value.isoformat().split("T")[_portion]}'""" - return value + return f"""'{value.isoformat().split("T")[_portion]}'""" else: def process(value): - if value is not None: - value = f"""'{value.isoformat().replace("T", " ")}'""" - return value + return f"""'{value.isoformat().replace("T", " ")}'""" return process @@ -2484,6 +2480,9 @@ class JSON(Indexable, TypeEngine[Any]): value = int_processor(value) elif string_processor and isinstance(value, str): value = string_processor(value) + else: + raise NotImplementedError() + return value return process @@ -3706,28 +3705,20 @@ class Uuid(Emulated, TypeEngine[_UUID_RETURN]): if not self.as_uuid: def process(value): - if value is not None: - value = ( - f"""'{value.replace("-", "").replace("'", "''")}'""" - ) - return value + return f"""'{value.replace("-", "").replace("'", "''")}'""" return process else: if character_based_uuid: def process(value): - if value is not None: - value = f"""'{value.hex}'""" - return value + return f"""'{value.hex}'""" return process else: def process(value): - if value is not None: - value = f"""'{str(value).replace("'", "''")}'""" - return value + return f"""'{str(value).replace("'", "''")}'""" return process diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py index 0a1419f253..be405667e1 100644 --- a/lib/sqlalchemy/testing/suite/test_types.py +++ b/lib/sqlalchemy/testing/suite/test_types.py @@ -82,6 +82,11 @@ class _LiteralRoundTripFixture: ) connection.execute(ins) + ins = t.insert().values( + x=literal(None, type_, literal_execute=True) + ) + connection.execute(ins) + if support_whereclause and self.supports_whereclause: if compare: stmt = t.select().where( @@ -108,7 +113,7 @@ class _LiteralRoundTripFixture: ) ) else: - stmt = t.select() + stmt = t.select().where(t.c.x.is_not(None)) rows = connection.execute(stmt).all() assert rows, "No rows returned" @@ -118,6 +123,10 @@ class _LiteralRoundTripFixture: value = filter_(value) assert value in output + stmt = t.select().where(t.c.x.is_(None)) + rows = connection.execute(stmt).all() + eq_(rows, [(None,)]) + return run