]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Handle case where InstanceState.obj returns None
authorFederico Caselli <cfederico87@gmail.com>
Tue, 6 Oct 2020 19:21:46 +0000 (21:21 +0200)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 7 Oct 2020 04:13:35 +0000 (00:13 -0400)
Fixed bug where a call ``InstanceState.obj()`` could return None when
synchronizing the instance states of the objects in the session in case
they become out of scope but are not yet finalized by the gc. This
case does not happen in cPython, but it may present itself in pypy.

The approach is to allow None to be gracefully handled by the
evaluator itself, ensuring it returns None in all cases when None
is passed in.

Fixes: #5631
Change-Id: I53d38fbea2e72b2e677c6e7f70bf075a58e58945

lib/sqlalchemy/orm/evaluator.py
test/orm/test_evaluator.py

index f7f12ce127210aee7340b8bbc3ea68dbb5c0792e..23c48329da8ebb22f3bfa10839d6c777df9f8d7e 100644 (file)
@@ -17,6 +17,16 @@ class UnevaluatableError(Exception):
     pass
 
 
+class _NoObject(operators.ColumnOperators):
+    def operate(self, *arg, **kw):
+        return None
+
+    def reverse_operate(self, *arg, **kw):
+        return None
+
+
+_NO_OBJECT = _NoObject()
+
 _straight_ops = set(
     getattr(operators, op)
     for op in (
@@ -36,8 +46,10 @@ _straight_ops = set(
 )
 
 _extended_ops = {
-    operators.in_op: (lambda a, b: a in b),
-    operators.not_in_op: (lambda a, b: a not in b),
+    operators.in_op: (lambda a, b: a in b if a is not _NO_OBJECT else None),
+    operators.not_in_op: (
+        lambda a, b: a not in b if a is not _NO_OBJECT else None
+    ),
 }
 
 _notimplemented_ops = set(
@@ -111,7 +123,11 @@ class EvaluatorCompiler(object):
                 raise UnevaluatableError("Cannot evaluate column: %s" % clause)
 
         get_corresponding_attr = operator.attrgetter(key)
-        return lambda obj: get_corresponding_attr(obj)
+        return (
+            lambda obj: get_corresponding_attr(obj)
+            if obj is not None
+            else _NO_OBJECT
+        )
 
     def visit_tuple(self, clause):
         return self.visit_clauselist(clause)
@@ -137,7 +153,7 @@ class EvaluatorCompiler(object):
                 for sub_evaluate in evaluators:
                     value = sub_evaluate(obj)
                     if not value:
-                        if value is None:
+                        if value is None or value is _NO_OBJECT:
                             return None
                         return False
                 return True
@@ -148,7 +164,7 @@ class EvaluatorCompiler(object):
                 values = []
                 for sub_evaluate in evaluators:
                     value = sub_evaluate(obj)
-                    if value is None:
+                    if value is None or value is _NO_OBJECT:
                         return None
                     values.append(value)
                 return tuple(values)
index a6c889aa742919b81cbfbc48e93bb1da6fcae559..955e5134fc00af03ecf3a43be2130113ced90cbc 100644 (file)
@@ -100,7 +100,11 @@ class EvaluateTest(fixtures.MappedTest):
 
         eval_eq(
             User.name == None,  # noqa
-            testcases=[(User(name="foo"), False), (User(name=None), True)],
+            testcases=[
+                (User(name="foo"), False),
+                (User(name=None), True),
+                (None, None),
+            ],
         )
 
     def test_warn_on_unannotated_matched_column(self):
@@ -144,6 +148,7 @@ class EvaluateTest(fixtures.MappedTest):
                 (User(name="foo"), False),
                 (User(name=True), False),
                 (User(name=False), True),
+                (None, None),
             ],
         )
 
@@ -153,6 +158,7 @@ class EvaluateTest(fixtures.MappedTest):
                 (User(name="foo"), False),
                 (User(name=True), True),
                 (User(name=False), False),
+                (None, None),
             ],
         )
 
@@ -167,6 +173,7 @@ class EvaluateTest(fixtures.MappedTest):
                 (User(id=1, name="bar"), False),
                 (User(id=2, name="bar"), False),
                 (User(id=1, name=None), None),
+                (None, None),
             ],
         )
 
@@ -179,6 +186,7 @@ class EvaluateTest(fixtures.MappedTest):
                 (User(id=2, name="bar"), False),
                 (User(id=1, name=None), True),
                 (User(id=2, name=None), None),
+                (None, None),
             ],
         )
 
@@ -201,6 +209,7 @@ class EvaluateTest(fixtures.MappedTest):
                 (User(id=2, name="bat"), False),
                 (User(id=1, name="bar"), True),
                 (User(id=1, name=None), None),
+                (None, None),
             ],
         )
 
@@ -211,6 +220,7 @@ class EvaluateTest(fixtures.MappedTest):
                 (User(id=2, name="bat"), True),
                 (User(id=1, name="bar"), False),
                 (User(id=1, name=None), None),
+                (None, None),
             ],
         )
 
@@ -225,6 +235,7 @@ class EvaluateTest(fixtures.MappedTest):
                 (User(id=1, name="bar"), False),
                 (User(id=2, name="bar"), True),
                 (User(id=1, name=None), None),
+                (None, None),
             ],
         )
 
@@ -236,6 +247,7 @@ class EvaluateTest(fixtures.MappedTest):
                 (User(id=1, name="bar"), True),
                 (User(id=2, name="bar"), False),
                 (User(id=1, name=None), None),
+                (None, None),
             ],
         )
 
@@ -251,6 +263,7 @@ class EvaluateTest(fixtures.MappedTest):
                 (User(id=2, name="bar"), True),
                 (User(id=None, name="foo"), None),
                 (User(id=None, name=None), None),
+                (None, None),
             ],
         )