]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Support reflecting no inherit check constraint in pg.
authorEllis Valentiner <ellisvalentiner@gmail.com>
Mon, 8 Jan 2024 16:16:21 +0000 (11:16 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 11 Jan 2024 16:58:03 +0000 (11:58 -0500)
Added support for reflection of PostgreSQL CHECK constraints marked with
"NO INHERIT", setting the key ``no_inherit=True`` in the reflected data.
Pull request courtesy Ellis Valentiner.

Fixes: #10777
Closes: #10778
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/10778
Pull-request-sha: 058082ff6297f9ccdc4977e65ef024e9a093426e

Change-Id: Ia33e29c0c57cf0076e8819311f4628d712fdc332

doc/build/changelog/unreleased_20/10777.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/base.py
test/dialect/postgresql/test_reflection.py

diff --git a/doc/build/changelog/unreleased_20/10777.rst b/doc/build/changelog/unreleased_20/10777.rst
new file mode 100644 (file)
index 0000000..cee5092
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: usecase, postgresql, reflection
+    :tickets: 10777
+
+    Added support for reflection of PostgreSQL CHECK constraints marked with
+    "NO INHERIT", setting the key ``no_inherit=True`` in the reflected data.
+    Pull request courtesy Ellis Valentiner.
index a7cd0ca8293014f8324e991eee5ea33c47d59c9b..c39e8be75cfd9a5975e49bc535f4100880992fe7 100644 (file)
@@ -4696,9 +4696,13 @@ class PGDialect(default.DefaultDialect):
             # "CHECK (((a > 1) AND (a < 5))) NOT VALID"
             # "CHECK (some_boolean_function(a))"
             # "CHECK (((a\n < 1)\n OR\n (a\n >= 5))\n)"
+            # "CHECK (a NOT NULL) NO INHERIT"
+            # "CHECK (a NOT NULL) NO INHERIT NOT VALID"
 
             m = re.match(
-                r"^CHECK *\((.+)\)( NOT VALID)?$", src, flags=re.DOTALL
+                r"^CHECK *\((.+)\)( NO INHERIT)?( NOT VALID)?$",
+                src,
+                flags=re.DOTALL,
             )
             if not m:
                 util.warn("Could not parse CHECK constraint text: %r" % src)
@@ -4712,8 +4716,14 @@ class PGDialect(default.DefaultDialect):
                 "sqltext": sqltext,
                 "comment": comment,
             }
-            if m and m.group(2):
-                entry["dialect_options"] = {"not_valid": True}
+            if m:
+                do = {}
+                if " NOT VALID" in m.groups():
+                    do["not_valid"] = True
+                if " NO INHERIT" in m.groups():
+                    do["no_inherit"] = True
+                if do:
+                    entry["dialect_options"] = do
 
             check_constraints[(schema, table_name)].append(entry)
         return check_constraints.items()
index ab4fa2c038d78a57e04c0e9e8f3c7f5317dbf4dc..dd6c8aa88ee35567b0a5fa9e1560234322456311 100644 (file)
@@ -2197,6 +2197,42 @@ class ReflectionTest(
             ],
         )
 
+    def test_reflect_with_no_inherit_check_constraint(self):
+        rows = [
+            ("foo", "some name", "CHECK ((a IS NOT NULL)) NO INHERIT", None),
+            (
+                "foo",
+                "some name",
+                "CHECK ((a IS NOT NULL)) NO INHERIT NOT VALID",
+                None,
+            ),
+        ]
+        conn = mock.Mock(
+            execute=lambda *arg, **kw: mock.MagicMock(
+                fetchall=lambda: rows, __iter__=lambda self: iter(rows)
+            )
+        )
+        check_constraints = testing.db.dialect.get_check_constraints(
+            conn, "foo"
+        )
+        eq_(
+            check_constraints,
+            [
+                {
+                    "name": "some name",
+                    "sqltext": "a IS NOT NULL",
+                    "dialect_options": {"no_inherit": True},
+                    "comment": None,
+                },
+                {
+                    "name": "some name",
+                    "sqltext": "a IS NOT NULL",
+                    "dialect_options": {"not_valid": True, "no_inherit": True},
+                    "comment": None,
+                },
+            ],
+        )
+
     def _apply_stm(self, connection, use_map):
         if use_map:
             return connection.execution_options(