]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
add test for public schema
authorzhouyizhen <zhouyizhen@metrodata.tech>
Wed, 5 Jun 2024 03:39:02 +0000 (11:39 +0800)
committerzhouyizhen <zhouyizhen@metrodata.tech>
Wed, 5 Jun 2024 03:39:02 +0000 (11:39 +0800)
tests/test_postgresql.py

index 6ce4a7b91037d347784602dc69cd638f714fa721..e42ea9d3b2beb46d352596e4d8534a1d38958d3b 100644 (file)
@@ -859,8 +859,8 @@ class PostgresqlDetectSerialTest(TestBase):
         clear_staging_env()
 
     @provide_metadata
-    def _expect_default(self, c_expected, col, seq=None):
-        Table("t", self.metadata, col, schema="test_schema")
+    def _expect_default(self, c_expected, col, schema=None, seq=None):
+        Table("t", self.metadata, col, schema=schema)
 
         self.autogen_context.metadata = self.metadata
 
@@ -871,9 +871,7 @@ class PostgresqlDetectSerialTest(TestBase):
         insp = inspect(config.db)
 
         uo = ops.UpgradeOps(ops=[])
-        _compare_tables(
-            {("test_schema", "t")}, set(), insp, uo, self.autogen_context
-        )
+        _compare_tables({(schema, "t")}, set(), insp, uo, self.autogen_context)
         diffs = uo.as_diffs()
         tab = diffs[0][1]
 
@@ -886,12 +884,12 @@ class PostgresqlDetectSerialTest(TestBase):
 
         insp = inspect(config.db)
         uo = ops.UpgradeOps(ops=[])
-        m2 = MetaData(schema="test_schema")
+        m2 = MetaData(schema=schema)
         Table("t", m2, Column("x", BigInteger()))
         self.autogen_context.metadata = m2
         _compare_tables(
-            {("test_schema", "t")},
-            {("test_schema", "t")},
+            {(schema, "t")},
+            {(schema, "t")},
             insp,
             uo,
             self.autogen_context,
@@ -905,35 +903,47 @@ class PostgresqlDetectSerialTest(TestBase):
             c_expected,
         )
 
-    def test_serial(self):
-        self._expect_default(None, Column("x", Integer, primary_key=True))
+    @testing.combinations((None,), ("test_schema",))
+    def test_serial(self, schema):
+        self._expect_default(
+            None, Column("x", Integer, primary_key=True), schema
+        )
 
-    def test_separate_seq(self):
-        seq = Sequence("x_id_seq", schema="test_schema")
+    @testing.combinations((None,), ("test_schema",))
+    def test_separate_seq(self, schema):
+        seq = Sequence("x_id_seq", schema=schema)
+        seq_name = seq.name if schema is None else f"{schema}.{seq.name}"
         self._expect_default(
-            "nextval('test_schema.x_id_seq'::regclass)",
+            f"nextval('{seq_name}'::regclass)",
             Column(
                 "x", Integer, server_default=seq.next_value(), primary_key=True
             ),
+            schema,
             seq,
         )
 
-    def test_numeric(self):
-        seq = Sequence("x_id_seq", schema="test_schema")
+    @testing.combinations((None,), ("test_schema",))
+    def test_numeric(self, schema):
+        seq = Sequence("x_id_seq", schema=schema)
+        seq_name = seq.name if schema is None else f"{schema}.{seq.name}"
         self._expect_default(
-            "nextval('test_schema.x_id_seq'::regclass)",
+            f"nextval('{seq_name}'::regclass)",
             Column(
                 "x",
                 Numeric(8, 2),
                 server_default=seq.next_value(),
                 primary_key=True,
             ),
+            schema,
             seq,
         )
 
-    def test_no_default(self):
+    @testing.combinations((None,), ("test_schema",))
+    def test_no_default(self, schema):
         self._expect_default(
-            None, Column("x", Integer, autoincrement=False, primary_key=True)
+            None,
+            Column("x", Integer, autoincrement=False, primary_key=True),
+            schema,
         )