]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Apply dialect_options copy fix
authorGord Thompson <gord@gordthompson.com>
Sat, 6 Jun 2020 16:04:34 +0000 (10:04 -0600)
committergordthompson <gord@gordthompson.com>
Thu, 18 Jun 2020 18:05:52 +0000 (12:05 -0600)
Fixes: #5276
Change-Id: Ic608310d4a85934fc9fa4d72daef66323c6e2525

doc/build/changelog/unreleased_13/5276.rst [new file with mode: 0644]
lib/sqlalchemy/sql/schema.py
test/sql/test_metadata.py

diff --git a/doc/build/changelog/unreleased_13/5276.rst b/doc/build/changelog/unreleased_13/5276.rst
new file mode 100644 (file)
index 0000000..d7c05d7
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, schema
+    :tickets: 5276
+
+    Fixed issue where ``dialect_options`` were omitted when a
+    database object (e.g., :class:`.Table`) was copied using
+    :func:`.tometadata`.
\ No newline at end of file
index 29ca81d26d27691334cafc3c9501510470a251ce..f6d8cfb1ceb908c4ae9837f2b0edeee8471891b2 100644 (file)
@@ -1556,6 +1556,18 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
             c.copy(**kw) for c in self.constraints if not c._type_bound
         ] + [c.copy(**kw) for c in self.foreign_keys if not c.constraint]
 
+        # ticket #5276
+        column_kwargs = {}
+        for dialect_name in self.dialect_options:
+            dialect_options = self.dialect_options[dialect_name]._non_defaults
+            for (
+                dialect_option_key,
+                dialect_option_value,
+            ) in dialect_options.items():
+                column_kwargs[
+                    dialect_name + "_" + dialect_option_key
+                ] = dialect_option_value
+
         server_default = self.server_default
         server_onupdate = self.server_onupdate
         if isinstance(server_default, Computed):
@@ -1574,7 +1586,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
             nullable=self.nullable,
             unique=self.unique,
             system=self.system,
-            # quote=self.quote,
+            # quote=self.quote,  # disabled 2013-08-27 (commit 031ef080)
             index=self.index,
             autoincrement=self.autoincrement,
             default=self.default,
@@ -1583,7 +1595,8 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
             server_onupdate=server_onupdate,
             doc=self.doc,
             comment=self.comment,
-            *args
+            *args,
+            **column_kwargs
         )
         return self._schema_item_copy(c)
 
@@ -2955,11 +2968,24 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint):
         return x in self.columns
 
     def copy(self, **kw):
+        # ticket #5276
+        constraint_kwargs = {}
+        for dialect_name in self.dialect_options:
+            dialect_options = self.dialect_options[dialect_name]._non_defaults
+            for (
+                dialect_option_key,
+                dialect_option_value,
+            ) in dialect_options.items():
+                constraint_kwargs[
+                    dialect_name + "_" + dialect_option_key
+                ] = dialect_option_value
+
         c = self.__class__(
             name=self.name,
             deferrable=self.deferrable,
             initially=self.initially,
-            *self.columns.keys()
+            *self.columns.keys(),
+            **constraint_kwargs
         )
         return self._schema_item_copy(c)
 
index 2145e72d09c9de619f569399de7c7d760800d716..4351e562ea1b3782e1b6fdd1b9bf7250e4f3ba3e 100644 (file)
@@ -5213,3 +5213,82 @@ class NamingConventionTest(fixtures.TestBase, AssertsCompiledSQL):
         self.assert_compile(
             CreateIndex(ix), "CREATE INDEX ix_t_q ON t (q + 5)"
         )
+
+
+class CopyDialectOptionsTest(fixtures.TestBase):
+    @contextmanager
+    def _fixture(self):
+        from sqlalchemy.engine.default import DefaultDialect
+
+        class CopyDialectOptionsTestDialect(DefaultDialect):
+            construct_arguments = [
+                (Table, {"some_table_arg": None}),
+                (Column, {"some_column_arg": None}),
+                (Index, {"some_index_arg": None}),
+                (PrimaryKeyConstraint, {"some_pk_arg": None}),
+                (UniqueConstraint, {"some_uq_arg": None}),
+            ]
+
+        def load(dialect_name):
+            if dialect_name == "copydialectoptionstest":
+                return CopyDialectOptionsTestDialect
+            else:
+                raise exc.NoSuchModuleError("no dialect %r" % dialect_name)
+
+        with mock.patch("sqlalchemy.dialects.registry.load", load):
+            yield
+
+    @classmethod
+    def check_dialect_options_(cls, t):
+        eq_(
+            t.dialect_kwargs["copydialectoptionstest_some_table_arg"], "a1",
+        )
+        eq_(
+            t.c.foo.dialect_kwargs["copydialectoptionstest_some_column_arg"],
+            "a2",
+        )
+        eq_(
+            t.primary_key.dialect_kwargs["copydialectoptionstest_some_pk_arg"],
+            "a3",
+        )
+        eq_(
+            list(t.indexes)[0].dialect_kwargs[
+                "copydialectoptionstest_some_index_arg"
+            ],
+            "a4",
+        )
+        eq_(
+            list(c for c in t.constraints if isinstance(c, UniqueConstraint))[
+                0
+            ].dialect_kwargs["copydialectoptionstest_some_uq_arg"],
+            "a5",
+        )
+
+    def test_dialect_options_are_copied(self):
+        with self._fixture():
+            t1 = Table(
+                "t",
+                MetaData(),
+                Column(
+                    "foo",
+                    Integer,
+                    copydialectoptionstest_some_column_arg="a2",
+                ),
+                Column("bar", Integer),
+                PrimaryKeyConstraint(
+                    "foo", copydialectoptionstest_some_pk_arg="a3"
+                ),
+                UniqueConstraint(
+                    "bar", copydialectoptionstest_some_uq_arg="a5"
+                ),
+                copydialectoptionstest_some_table_arg="a1",
+            )
+            Index(
+                "idx", t1.c.foo, copydialectoptionstest_some_index_arg="a4",
+            )
+
+            self.check_dialect_options_(t1)
+
+            m2 = MetaData()
+            t2 = t1.tometadata(m2)  # make a copy
+            self.check_dialect_options_(t2)