]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
accommodate NULL at the compiler level for literal_render
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 24 Oct 2023 16:06:24 +0000 (12:06 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 24 Oct 2023 16:06:24 +0000 (12:06 -0400)
Added compiler-level None/NULL handling for the "literal processors" of all
datatypes that include literal processing, that is, where a value is
rendered inline within a SQL statement rather than as a bound parameter,
for all those types that do not feature explicit "null value" handling.
Previously this behavior was undefined and inconsistent.

Fixes: #10535
Change-Id: I746d19d6cec2aefa3244f5e5a6970950a698d96c

doc/build/changelog/unreleased_20/10535.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/oracle/types.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/testing/suite/test_types.py

diff --git a/doc/build/changelog/unreleased_20/10535.rst b/doc/build/changelog/unreleased_20/10535.rst
new file mode 100644 (file)
index 0000000..c8435be
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 10535
+
+    Added compiler-level None/NULL handling for the "literal processors" of all
+    datatypes that include literal processing, that is, where a value is
+    rendered inline within a SQL statement rather than as a bound parameter,
+    for all those types that do not feature explicit "null value" handling.
+    Previously this behavior was undefined and inconsistent.
index 50ba43b364bf98d4f3f7928465a623519157a95e..687de04e4d3da4cd0d2d8d567d2beece7f94b9d1 100644 (file)
@@ -1537,28 +1537,22 @@ class MSUUid(sqltypes.Uuid):
         if self.native_uuid:
 
             def process(value):
-                if value is not None:
-                    value = f"""'{str(value).replace("''", "'")}'"""
-                return value
+                return f"""'{str(value).replace("''", "'")}'"""
 
             return process
         else:
             if self.as_uuid:
 
                 def process(value):
-                    if value is not None:
-                        value = f"""'{value.hex}'"""
-                    return value
+                    return f"""'{value.hex}'"""
 
                 return process
             else:
 
                 def process(value):
-                    if value is not None:
-                        value = f"""'{
+                    return f"""'{
                             value.replace("-", "").replace("'", "''")
                         }'"""
-                    return value
 
                 return process
 
index 62028c767386df9458285c2973b006673c86fbb0..c1f6d51916d4192cce054cb753e838d0d368cc14 100644 (file)
@@ -116,38 +116,36 @@ class LONG(sqltypes.Text):
 class _OracleDateLiteralRender:
     def _literal_processor_datetime(self, dialect):
         def process(value):
-            if value is not None:
-                if getattr(value, "microsecond", None):
-                    value = (
-                        f"""TO_TIMESTAMP"""
-                        f"""('{value.isoformat().replace("T", " ")}', """
-                        """'YYYY-MM-DD HH24:MI:SS.FF')"""
-                    )
-                else:
-                    value = (
-                        f"""TO_DATE"""
-                        f"""('{value.isoformat().replace("T", " ")}', """
-                        """'YYYY-MM-DD HH24:MI:SS')"""
-                    )
+            if getattr(value, "microsecond", None):
+                value = (
+                    f"""TO_TIMESTAMP"""
+                    f"""('{value.isoformat().replace("T", " ")}', """
+                    """'YYYY-MM-DD HH24:MI:SS.FF')"""
+                )
+            else:
+                value = (
+                    f"""TO_DATE"""
+                    f"""('{value.isoformat().replace("T", " ")}', """
+                    """'YYYY-MM-DD HH24:MI:SS')"""
+                )
             return value
 
         return process
 
     def _literal_processor_date(self, dialect):
         def process(value):
-            if value is not None:
-                if getattr(value, "microsecond", None):
-                    value = (
-                        f"""TO_TIMESTAMP"""
-                        f"""('{value.isoformat().split("T")[0]}', """
-                        """'YYYY-MM-DD')"""
-                    )
-                else:
-                    value = (
-                        f"""TO_DATE"""
-                        f"""('{value.isoformat().split("T")[0]}', """
-                        """'YYYY-MM-DD')"""
-                    )
+            if getattr(value, "microsecond", None):
+                value = (
+                    f"""TO_TIMESTAMP"""
+                    f"""('{value.isoformat().split("T")[0]}', """
+                    """'YYYY-MM-DD')"""
+                )
+            else:
+                value = (
+                    f"""TO_DATE"""
+                    f"""('{value.isoformat().split("T")[0]}', """
+                    """'YYYY-MM-DD')"""
+                )
             return value
 
         return process
index 20f772f54a3c0f1fa2bd02d0978c015084a4f0b4..cb6899c5e9a8761a321a54a232c39a61f293cc3b 100644 (file)
@@ -3766,6 +3766,12 @@ class SQLCompiler(Compiled):
 
         """
 
+        if value is None and not type_.should_evaluate_none:
+            # issue #10535 - handle NULL in the compiler without placing
+            # this onto each type, except for "evaluate None" types
+            # (e.g. JSON)
+            return self.process(elements.Null._instance())
+
         processor = type_._cached_literal_processor(self.dialect)
         if processor:
             try:
index 343575f196e763886ee11611c9daa66b42d1a2ad..8b75a51af83c6b4eecde3df130ade5554b45c3c8 100644 (file)
@@ -740,16 +740,12 @@ class _RenderISO8601NoT:
         if _portion is not None:
 
             def process(value):
-                if value is not None:
-                    value = f"""'{value.isoformat().split("T")[_portion]}'"""
-                return value
+                return f"""'{value.isoformat().split("T")[_portion]}'"""
 
         else:
 
             def process(value):
-                if value is not None:
-                    value = f"""'{value.isoformat().replace("T", " ")}'"""
-                return value
+                return f"""'{value.isoformat().replace("T", " ")}'"""
 
         return process
 
@@ -2484,6 +2480,9 @@ class JSON(Indexable, TypeEngine[Any]):
                     value = int_processor(value)
                 elif string_processor and isinstance(value, str):
                     value = string_processor(value)
+                else:
+                    raise NotImplementedError()
+
                 return value
 
             return process
@@ -3706,28 +3705,20 @@ class Uuid(Emulated, TypeEngine[_UUID_RETURN]):
         if not self.as_uuid:
 
             def process(value):
-                if value is not None:
-                    value = (
-                        f"""'{value.replace("-", "").replace("'", "''")}'"""
-                    )
-                return value
+                return f"""'{value.replace("-", "").replace("'", "''")}'"""
 
             return process
         else:
             if character_based_uuid:
 
                 def process(value):
-                    if value is not None:
-                        value = f"""'{value.hex}'"""
-                    return value
+                    return f"""'{value.hex}'"""
 
                 return process
             else:
 
                 def process(value):
-                    if value is not None:
-                        value = f"""'{str(value).replace("'", "''")}'"""
-                    return value
+                    return f"""'{str(value).replace("'", "''")}'"""
 
                 return process
 
index 0a1419f2535a7caf8d45bedcb9ab5ddb457822e5..be405667e11a4fd2336fa9acb1dbbe3fd2fb8e7b 100644 (file)
@@ -82,6 +82,11 @@ class _LiteralRoundTripFixture:
                 )
                 connection.execute(ins)
 
+            ins = t.insert().values(
+                x=literal(None, type_, literal_execute=True)
+            )
+            connection.execute(ins)
+
             if support_whereclause and self.supports_whereclause:
                 if compare:
                     stmt = t.select().where(
@@ -108,7 +113,7 @@ class _LiteralRoundTripFixture:
                         )
                     )
             else:
-                stmt = t.select()
+                stmt = t.select().where(t.c.x.is_not(None))
 
             rows = connection.execute(stmt).all()
             assert rows, "No rows returned"
@@ -118,6 +123,10 @@ class _LiteralRoundTripFixture:
                     value = filter_(value)
                 assert value in output
 
+            stmt = t.select().where(t.c.x.is_(None))
+            rows = connection.execute(stmt).all()
+            eq_(rows, [(None,)])
+
         return run