]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Deannoate functions before matching .__class__
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 19 Mar 2021 05:18:06 +0000 (01:18 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 19 Mar 2021 05:19:32 +0000 (01:19 -0400)
Fixed regression where the SQL compilation of a :class:`.Function` would
not work correctly if the object had been "annotated", which is an internal
memoization process used mostly by the ORM. In particular it could affect
ORM lazy loads which make greater use of this feature in 1.4.

Fixes: #6095
Change-Id: I7a6527df651f440a04d911ba78ee0b0dd4436dcd

doc/build/changelog/unreleased_14/6095.rst [new file with mode: 0644]
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/testing/fixtures.py
test/orm/test_lazy_relations.py
test/sql/test_functions.py

diff --git a/doc/build/changelog/unreleased_14/6095.rst b/doc/build/changelog/unreleased_14/6095.rst
new file mode 100644 (file)
index 0000000..bcf44b1
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, regression, orm
+    :tickets: 6095
+
+    Fixed regression where the SQL compilation of a :class:`.Function` would
+    not work correctly if the object had been "annotated", which is an internal
+    memoization process used mostly by the ORM. In particular it could affect
+    ORM lazy loads which make greater use of this feature in 1.4.
index 0ea251fb4b217a28d86b418a76670b4b2deed8b0..2762091039314d4b767d0d4cf170f36b761366eb 100644 (file)
@@ -1720,7 +1720,7 @@ class SQLCompiler(Compiled):
         if disp:
             text = disp(func, **kwargs)
         else:
-            name = FUNCTIONS.get(func.__class__, None)
+            name = FUNCTIONS.get(func._deannotate().__class__, None)
             if name:
                 if func._has_args:
                     name += "%(expr)s"
index 95dce02a9d4b2687725e091c15695b8c54933355..f47277b4aea4f2042ccab1be6edf724544f73029 100644 (file)
@@ -71,6 +71,12 @@ class TestBase(object):
         # run a close all connections.
         conn.close()
 
+    @config.fixture()
+    def registry(self, metadata):
+        reg = registry(metadata=metadata)
+        yield reg
+        reg.dispose()
+
     @config.fixture()
     def future_connection(self, future_engine, connection):
         # integrate the future_engine and connection fixtures so
index 154e119529aa14dd6706a47b2790b3e32bdd1365..1cbe1060620bbad0b16b34dd3fac7564608448d9 100644 (file)
@@ -6,8 +6,10 @@ import sqlalchemy as sa
 from sqlalchemy import and_
 from sqlalchemy import bindparam
 from sqlalchemy import Boolean
+from sqlalchemy import Date
 from sqlalchemy import ForeignKey
 from sqlalchemy import ForeignKeyConstraint
+from sqlalchemy import func
 from sqlalchemy import Integer
 from sqlalchemy import orm
 from sqlalchemy import select
@@ -25,6 +27,7 @@ from sqlalchemy.orm import Session
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_false
 from sqlalchemy.testing import is_true
 from sqlalchemy.testing.assertsql import CompiledSQL
@@ -761,6 +764,43 @@ class LazyTest(_fixtures.FixtureTest):
 
             eq_(a1.user.id, 8)
 
+    @testing.only_on("sqlite")
+    def test_annotated_fn_criteria(self, registry, connection):
+        """this test is a secondary test for the compilation of functions
+        that are annotated.
+
+        """
+
+        @registry.mapped
+        class A(object):
+            __tablename__ = "a"
+
+            id = Column(Integer, primary_key=True)
+            _date = Column(Date, default=func.current_date())
+            b_id = Column(Integer, ForeignKey("b.id"))
+            b = relationship("B")
+
+        @registry.mapped
+        class B(object):
+            __tablename__ = "b"
+
+            id = Column(Integer, primary_key=True)
+            a_s = relationship(
+                "A",
+                primaryjoin="and_(B.id == A.b_id, "
+                "A._date >= func.current_date())",
+                viewonly=True,
+            )
+
+        registry.metadata.create_all(connection)
+        with Session(connection) as sess:
+            b1 = B(id=1)
+            a1 = A(b=b1)
+            sess.add_all([a1, b1])
+            sess.commit()
+
+            is_(sess.get(B, 1).a_s[0], a1)
+
     def test_uses_get_compatible_types(self):
         """test the use_get optimization with compatible
         but non-identical types"""
index 96e0a91291b593670302a0605027bc53f9a16bea..c5aca5d7fa87144bf17207e568c3c2472130ed26 100644 (file)
@@ -256,6 +256,13 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
         fn = func.coalesce("x", "y")._annotate({"foo": "bar"})
         self.assert_compile(fn, "coalesce(:coalesce_1, :coalesce_2)")
 
+    def test_annotation_dialect_specific(self):
+        fn = func.current_date()
+        self.assert_compile(fn, "CURRENT_DATE", dialect="sqlite")
+
+        fn = fn._annotate({"foo": "bar"})
+        self.assert_compile(fn, "CURRENT_DATE", dialect="sqlite")
+
     def test_custom_default_namespace(self):
         class myfunc(GenericFunction):
             pass