]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Make if_exists and if_not_exists flags on ddl statements match compiler
authorJesse Bakker <github@jessebakker.com>
Tue, 6 Sep 2022 20:00:10 +0000 (16:00 -0400)
committermike bayer <mike_mp@zzzcomputing.com>
Tue, 4 Oct 2022 02:40:20 +0000 (02:40 +0000)
Added ``if_exists`` and ``if_not_exists`` parameters for all "Create" /
"Drop" constructs including :class:`.CreateSequence`,
:class:`.DropSequence`, :class:`.CreateIndex`, :class:`.DropIndex`, etc.
allowing generic "IF EXISTS" / "IF NOT EXISTS" phrases to be rendered
within DDL. Pull request courtesy Jesse Bakker.

Fixes: #7354
Closes: #8492
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/8492
Pull-request-sha: d107c6ce553bd430111607815f5b3938ffc4770c

Change-Id: I367e57b2d9216f5180bcc44e86ca6f3dc794e5ca

doc/build/changelog/unreleased_20/7354.rst [new file with mode: 0644]
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/ddl.py
test/sql/test_constraints.py
test/sql/test_metadata.py
test/sql/test_sequences.py

diff --git a/doc/build/changelog/unreleased_20/7354.rst b/doc/build/changelog/unreleased_20/7354.rst
new file mode 100644 (file)
index 0000000..dfbd0e8
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 7354
+
+    Added ``if_exists`` and ``if_not_exists`` parameters for all "Create" /
+    "Drop" constructs including :class:`.CreateSequence`,
+    :class:`.DropSequence`, :class:`.CreateIndex`, :class:`.DropIndex`, etc.
+    allowing generic "IF EXISTS" / "IF NOT EXISTS" phrases to be rendered
+    within DDL. Pull request courtesy Jesse Bakker.
+
index c7e226fcc639de220c0a067c708c50cc03d1dbaa..dd40bfe345505a823ef5dcf1383aa3d190b55c4a 100644 (file)
@@ -5396,12 +5396,16 @@ class DDLCompiler(Compiled):
         return self.sql_compiler.post_process_text(ddl.statement % context)
 
     def visit_create_schema(self, create, **kw):
-        schema = self.preparer.format_schema(create.element)
-        return "CREATE SCHEMA " + schema
+        text = "CREATE SCHEMA "
+        if create.if_not_exists:
+            text += "IF NOT EXISTS "
+        return text + self.preparer.format_schema(create.element)
 
     def visit_drop_schema(self, drop, **kw):
-        schema = self.preparer.format_schema(drop.element)
-        text = "DROP SCHEMA " + schema
+        text = "DROP SCHEMA "
+        if drop.if_exists:
+            text += "IF EXISTS "
+        text += self.preparer.format_schema(drop.element)
         if drop.cascade:
             text += " CASCADE"
         return text
@@ -5650,9 +5654,11 @@ class DDLCompiler(Compiled):
         return " ".join(text)
 
     def visit_create_sequence(self, create, prefix=None, **kw):
-        text = "CREATE SEQUENCE %s" % self.preparer.format_sequence(
-            create.element
-        )
+        text = "CREATE SEQUENCE "
+        if create.if_not_exists:
+            text += "IF NOT EXISTS "
+        text += self.preparer.format_sequence(create.element)
+
         if prefix:
             text += prefix
         if create.element.start is None:
@@ -5663,7 +5669,10 @@ class DDLCompiler(Compiled):
         return text
 
     def visit_drop_sequence(self, drop, **kw):
-        return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element)
+        text = "DROP SEQUENCE "
+        if drop.if_exists:
+            text += "IF EXISTS "
+        return text + self.preparer.format_sequence(drop.element)
 
     def visit_drop_constraint(self, drop, **kw):
         constraint = drop.element
index 3c7c674f5082cd966f8e21535e587f4e6c71b45c..e744f0c1de3c733ed9cda3b2ac13f3589d751ed2 100644 (file)
@@ -40,6 +40,7 @@ if typing.TYPE_CHECKING:
     from .schema import Constraint
     from .schema import ForeignKeyConstraint
     from .schema import SchemaItem
+    from .schema import Sequence
     from .schema import Table
     from ..engine.base import _CompiledCacheType
     from ..engine.base import Connection
@@ -434,12 +435,8 @@ class _CreateDropBase(ExecutableDDLElement):
     def __init__(
         self,
         element,
-        if_exists=False,
-        if_not_exists=False,
     ):
         self.element = self.target = element
-        self.if_exists = if_exists
-        self.if_not_exists = if_not_exists
         self._ddl_if = getattr(element, "_ddl_if", None)
 
     @property
@@ -457,7 +454,19 @@ class _CreateDropBase(ExecutableDDLElement):
         return False
 
 
-class CreateSchema(_CreateDropBase):
+class _CreateBase(_CreateDropBase):
+    def __init__(self, element, if_not_exists=False):
+        super().__init__(element)
+        self.if_not_exists = if_not_exists
+
+
+class _DropBase(_CreateDropBase):
+    def __init__(self, element, if_exists=False):
+        super().__init__(element)
+        self.if_exists = if_exists
+
+
+class CreateSchema(_CreateBase):
     """Represent a CREATE SCHEMA statement.
 
     The argument here is the string name of the schema.
@@ -469,19 +478,14 @@ class CreateSchema(_CreateDropBase):
     def __init__(
         self,
         name,
-        quote=None,
-        if_exists=False,
         if_not_exists=False,
     ):
         """Create a new :class:`.CreateSchema` construct."""
 
-        self.quote = quote
-        self.element = name
-        self.if_exists = if_exists
-        self.if_not_exists = if_not_exists
+        super().__init__(element=name, if_not_exists=if_not_exists)
 
 
-class DropSchema(_CreateDropBase):
+class DropSchema(_DropBase):
     """Represent a DROP SCHEMA statement.
 
     The argument here is the string name of the schema.
@@ -493,22 +497,16 @@ class DropSchema(_CreateDropBase):
     def __init__(
         self,
         name,
-        quote=None,
         cascade=False,
         if_exists=False,
-        if_not_exists=False,
     ):
         """Create a new :class:`.DropSchema` construct."""
 
-        self.quote = quote
+        super().__init__(element=name, if_exists=if_exists)
         self.cascade = cascade
-        self.quote = quote
-        self.element = name
-        self.if_exists = if_exists
-        self.if_not_exists = if_not_exists
 
 
-class CreateTable(_CreateDropBase):
+class CreateTable(_CreateBase):
     """Represent a CREATE TABLE statement."""
 
     __visit_name__ = "create_table"
@@ -544,7 +542,7 @@ class CreateTable(_CreateDropBase):
         self.include_foreign_key_constraints = include_foreign_key_constraints
 
 
-class _DropView(_CreateDropBase):
+class _DropView(_DropBase):
     """Semi-public 'DROP VIEW' construct.
 
     Used by the test suite for dialect-agnostic drops of views.
@@ -669,7 +667,7 @@ class CreateColumn(BaseDDLElement):
         self.element = element
 
 
-class DropTable(_CreateDropBase):
+class DropTable(_DropBase):
     """Represent a DROP TABLE statement."""
 
     __visit_name__ = "drop_table"
@@ -689,19 +687,25 @@ class DropTable(_CreateDropBase):
         super().__init__(element, if_exists=if_exists)
 
 
-class CreateSequence(_CreateDropBase):
+class CreateSequence(_CreateBase):
     """Represent a CREATE SEQUENCE statement."""
 
     __visit_name__ = "create_sequence"
 
+    def __init__(self, element: Sequence, if_not_exists: bool = False):
+        super().__init__(element, if_not_exists=if_not_exists)
+
 
-class DropSequence(_CreateDropBase):
+class DropSequence(_DropBase):
     """Represent a DROP SEQUENCE statement."""
 
     __visit_name__ = "drop_sequence"
 
+    def __init__(self, element: Sequence, if_exists: bool = False):
+        super().__init__(element, if_exists=if_exists)
+
 
-class CreateIndex(_CreateDropBase):
+class CreateIndex(_CreateBase):
     """Represent a CREATE INDEX statement."""
 
     __visit_name__ = "create_index"
@@ -711,7 +715,6 @@ class CreateIndex(_CreateDropBase):
 
         :param element: a :class:`_schema.Index` that's the subject
          of the CREATE.
-        :param on: See the description for 'on' in :class:`.DDL`.
         :param if_not_exists: if True, an IF NOT EXISTS operator will be
          applied to the construct.
 
@@ -721,7 +724,7 @@ class CreateIndex(_CreateDropBase):
         super().__init__(element, if_not_exists=if_not_exists)
 
 
-class DropIndex(_CreateDropBase):
+class DropIndex(_DropBase):
     """Represent a DROP INDEX statement."""
 
     __visit_name__ = "drop_index"
@@ -731,7 +734,6 @@ class DropIndex(_CreateDropBase):
 
         :param element: a :class:`_schema.Index` that's the subject
          of the DROP.
-        :param on: See the description for 'on' in :class:`.DDL`.
         :param if_exists: if True, an IF EXISTS operator will be applied to the
          construct.
 
@@ -741,26 +743,26 @@ class DropIndex(_CreateDropBase):
         super().__init__(element, if_exists=if_exists)
 
 
-class AddConstraint(_CreateDropBase):
+class AddConstraint(_CreateBase):
     """Represent an ALTER TABLE ADD CONSTRAINT statement."""
 
     __visit_name__ = "add_constraint"
 
-    def __init__(self, element, *args, **kw):
-        super().__init__(element, *args, **kw)
+    def __init__(self, element):
+        super().__init__(element)
         element._create_rule = util.portable_instancemethod(
             self._create_rule_disable
         )
 
 
-class DropConstraint(_CreateDropBase):
+class DropConstraint(_DropBase):
     """Represent an ALTER TABLE DROP CONSTRAINT statement."""
 
     __visit_name__ = "drop_constraint"
 
-    def __init__(self, element, cascade=False, **kw):
+    def __init__(self, element, cascade=False, if_exists=False, **kw):
         self.cascade = cascade
-        super().__init__(element, **kw)
+        super().__init__(element, if_exists=if_exists, **kw)
         element._create_rule = util.portable_instancemethod(
             self._create_rule_disable
         )
index 462667bedcedb875894b6ff929d3a78218697424..b1b731d66f5b5fd80c7817d8c5526d39ec959cdd 100644 (file)
@@ -765,6 +765,14 @@ class ConstraintCompilationTest(fixtures.TestBase, AssertsCompiledSQL):
         i = Index("xyz", t.c.x)
         self.assert_compile(schema.CreateIndex(i), "CREATE INDEX xyz ON t (x)")
 
+    def test_create_index_if_not_exists(self):
+        t = Table("t", MetaData(), Column("x", Integer))
+        i = Index("xyz", t.c.x)
+        self.assert_compile(
+            schema.CreateIndex(i, if_not_exists=True),
+            "CREATE INDEX IF NOT EXISTS xyz ON t (x)",
+        )
+
     def test_drop_index_plain_unattached(self):
         self.assert_compile(
             schema.DropIndex(Index(name="xyz")), "DROP INDEX xyz"
@@ -775,6 +783,12 @@ class ConstraintCompilationTest(fixtures.TestBase, AssertsCompiledSQL):
             schema.DropIndex(Index(name="xyz")), "DROP INDEX xyz"
         )
 
+    def test_drop_index_if_exists(self):
+        self.assert_compile(
+            schema.DropIndex(Index(name="xyz"), if_exists=True),
+            "DROP INDEX IF EXISTS xyz",
+        )
+
     def test_create_index_schema(self):
         t = Table("t", MetaData(), Column("x", Integer), schema="foo")
         i = Index("xyz", t.c.x)
index 38255f9775a3e51d1781495b87810982c3cc477f..6d93cb234e547e59f8d51758314504d073076e45 100644 (file)
@@ -2620,9 +2620,17 @@ class SchemaTest(fixtures.TestBase, AssertsCompiledSQL):
         self.assert_compile(
             schema.CreateSchema("sa_schema"), "CREATE SCHEMA sa_schema"
         )
+        self.assert_compile(
+            schema.CreateSchema("sa_schema", if_not_exists=True),
+            "CREATE SCHEMA IF NOT EXISTS sa_schema",
+        )
         self.assert_compile(
             schema.DropSchema("sa_schema"), "DROP SCHEMA sa_schema"
         )
+        self.assert_compile(
+            schema.DropSchema("sa_schema", if_exists=True),
+            "DROP SCHEMA IF EXISTS sa_schema",
+        )
         self.assert_compile(
             schema.DropSchema("sa_schema", cascade=True),
             "DROP SCHEMA sa_schema CASCADE",
index 19f95c66197f15fc89c1a0ee27aaec9ba5a9967c..457aeb960b986ad03e20c846a86db214d7e200b3 100644 (file)
@@ -92,10 +92,20 @@ class SequenceDDLTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             "CREATE SEQUENCE foo_seq START WITH 1 ORDER",
         )
 
+        self.assert_compile(
+            CreateSequence(Sequence("foo_seq"), if_not_exists=True),
+            "CREATE SEQUENCE IF NOT EXISTS foo_seq START WITH 1",
+        )
+
         self.assert_compile(
             DropSequence(Sequence("foo_seq")), "DROP SEQUENCE foo_seq"
         )
 
+        self.assert_compile(
+            DropSequence(Sequence("foo_seq"), if_exists=True),
+            "DROP SEQUENCE IF EXISTS foo_seq",
+        )
+
 
 class SequenceExecTest(fixtures.TestBase):
     __requires__ = ("sequences",)