]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Added reflection method :meth:`.Inspector.get_sequence_names`
authorFederico Caselli <cfederico87@gmail.com>
Mon, 13 Apr 2020 10:16:21 +0000 (12:16 +0200)
committerFederico Caselli <cfederico87@gmail.com>
Wed, 3 Jun 2020 18:53:47 +0000 (20:53 +0200)
Added new reflection method :meth:`.Inspector.get_sequence_names` which
returns all the sequences defined. Support for this method has been added
to the backend that support :class:`.Sequence`: PostgreSql, Oracle,
MSSQL and MariaDB >= 10.3.

Fixes: #2056
Change-Id: I0949696a39aa28c849edf2504779241f7443778a

doc/build/changelog/unreleased_14/2056.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/engine/interfaces.py
lib/sqlalchemy/engine/reflection.py
lib/sqlalchemy/testing/fixtures.py
lib/sqlalchemy/testing/suite/test_sequence.py

diff --git a/doc/build/changelog/unreleased_14/2056.rst b/doc/build/changelog/unreleased_14/2056.rst
new file mode 100644 (file)
index 0000000..6c26afe
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: reflection, usecase
+    :tickets: 2056
+
+    Added new reflection method :meth:`.Inspector.get_sequence_names` which
+    returns all the sequences defined and :meth:`.Inspector.has_sequence` to
+    check if a particular sequence exits.
+    Support for this method has been added to the backend that support
+    :class:`.Sequence`: PostgreSQL, Oracle and MariaDB >= 10.3.
index c35ab288079a103cce5dcf0380b7fbbea0524aff..5aaecf23a82aaa4fbaf1745cff9536b793354360 100644 (file)
@@ -2639,6 +2639,19 @@ class MSDialect(default.DefaultDialect):
 
         return c.first() is not None
 
+    @reflection.cache
+    @_db_plus_owner_listing
+    def get_sequence_names(self, connection, dbname, owner, schema, **kw):
+        sequences = ischema.sequences
+
+        s = sql.select([sequences.c.sequence_name])
+        if owner:
+            s = s.where(sequences.c.sequence_schema == owner)
+
+        c = connection.execute(s)
+
+        return [row[0] for row in c]
+
     @reflection.cache
     def get_schema_names(self, connection, **kw):
         s = sql.select(
index d009d656edefac2f1ae4b5bd02efa8a8bd008c56..b34422e65530f02ad49d4539c39cdd8bfa61e2a3 100644 (file)
@@ -2517,6 +2517,8 @@ class MySQLDialect(default.DefaultDialect):
                 rs.close()
 
     def has_sequence(self, connection, sequence_name, schema=None):
+        if not self.supports_sequences:
+            self._sequences_not_supported()
         if not schema:
             schema = self.default_schema_name
         # MariaDB implements sequences as a special type of table
@@ -2524,13 +2526,40 @@ class MySQLDialect(default.DefaultDialect):
         cursor = connection.execute(
             sql.text(
                 "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES "
-                "WHERE TABLE_NAME=:name AND "
+                "WHERE TABLE_TYPE='SEQUENCE' and TABLE_NAME=:name AND "
                 "TABLE_SCHEMA=:schema_name"
             ),
             dict(name=sequence_name, schema_name=schema),
         )
         return cursor.first() is not None
 
+    def _sequences_not_supported(self):
+        raise NotImplementedError(
+            "Sequences are supported only by the "
+            "MariaDB series 10.3 or greater"
+        )
+
+    @reflection.cache
+    def get_sequence_names(self, connection, schema=None, **kw):
+        if not self.supports_sequences:
+            self._sequences_not_supported()
+        if not schema:
+            schema = self.default_schema_name
+        # MariaDB implements sequences as a special type of table
+        cursor = connection.execute(
+            sql.text(
+                "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES "
+                "WHERE TABLE_TYPE='SEQUENCE' and TABLE_SCHEMA=:schema_name"
+            ),
+            dict(schema_name=schema),
+        )
+        return [
+            row[0]
+            for row in self._compat_fetchall(
+                cursor, charset=self._connection_charset
+            )
+        ]
+
     def initialize(self, connection):
         self._connection_charset = self._detect_charset(connection)
         self._detect_sql_mode(connection)
index 481ea726333bce93727629a501d8ea32ec2e29f0..5e912a0c2f13edf7ab6a2e7f4193b609f606eae6 100644 (file)
@@ -1634,6 +1634,19 @@ class OracleDialect(default.DefaultDialect):
         )
         return [self.normalize_name(row[0]) for row in cursor]
 
+    @reflection.cache
+    def get_sequence_names(self, connection, schema=None, **kw):
+        if not schema:
+            schema = self.default_schema_name
+        cursor = connection.execute(
+            sql.text(
+                "SELECT sequence_name FROM all_sequences "
+                "WHERE sequence_owner = :schema_name"
+            ),
+            schema_name=self.denormalize_name(schema),
+        )
+        return [self.normalize_name(row[0]) for row in cursor]
+
     @reflection.cache
     def get_table_options(self, connection, table_name, schema=None, **kw):
         options = {}
index 441e77a378eb0dcf61c456c4703a67c45d3e573d..2bfb9b4942d988c1dc4be1cdc14202ac856feb92 100644 (file)
@@ -2682,39 +2682,23 @@ class PGDialect(default.DefaultDialect):
 
     def has_sequence(self, connection, sequence_name, schema=None):
         if schema is None:
-            cursor = connection.execute(
-                sql.text(
-                    "SELECT relname FROM pg_class c join pg_namespace n on "
-                    "n.oid=c.relnamespace where relkind='S' and "
-                    "n.nspname=current_schema() "
-                    "and relname=:name"
-                ).bindparams(
-                    sql.bindparam(
-                        "name",
-                        util.text_type(sequence_name),
-                        type_=sqltypes.Unicode,
-                    )
-                )
-            )
-        else:
-            cursor = connection.execute(
-                sql.text(
-                    "SELECT relname FROM pg_class c join pg_namespace n on "
-                    "n.oid=c.relnamespace where relkind='S' and "
-                    "n.nspname=:schema and relname=:name"
-                ).bindparams(
-                    sql.bindparam(
-                        "name",
-                        util.text_type(sequence_name),
-                        type_=sqltypes.Unicode,
-                    ),
-                    sql.bindparam(
-                        "schema",
-                        util.text_type(schema),
-                        type_=sqltypes.Unicode,
-                    ),
-                )
+            schema = self.default_schema_name
+        cursor = connection.execute(
+            sql.text(
+                "SELECT relname FROM pg_class c join pg_namespace n on "
+                "n.oid=c.relnamespace where relkind='S' and "
+                "n.nspname=:schema and relname=:name"
+            ).bindparams(
+                sql.bindparam(
+                    "name",
+                    util.text_type(sequence_name),
+                    type_=sqltypes.Unicode,
+                ),
+                sql.bindparam(
+                    "schema", util.text_type(schema), type_=sqltypes.Unicode,
+                ),
             )
+        )
 
         return bool(cursor.first())
 
@@ -2870,6 +2854,23 @@ class PGDialect(default.DefaultDialect):
         )
         return [name for name, in result]
 
+    @reflection.cache
+    def get_sequence_names(self, connection, schema=None, **kw):
+        if not schema:
+            schema = self.default_schema_name
+        cursor = connection.execute(
+            sql.text(
+                "SELECT relname FROM pg_class c join pg_namespace n on "
+                "n.oid=c.relnamespace where relkind='S' and "
+                "n.nspname=:schema"
+            ).bindparams(
+                sql.bindparam(
+                    "schema", util.text_type(schema), type_=sqltypes.Unicode,
+                ),
+            )
+        )
+        return [row[0] for row in cursor]
+
     @reflection.cache
     def get_view_definition(self, connection, view_name, schema=None, **kw):
         view_def = connection.scalar(
index 49d9af9661e62112f00fe859da617e1ebc02af2f..59b9cd4ce293167ed92d2f681a4abaf7bdce47cb 100644 (file)
@@ -304,8 +304,17 @@ class Dialect(object):
     def get_view_names(self, connection, schema=None, **kw):
         """Return a list of all view names available in the database.
 
-        schema:
-          Optional, retrieve names from a non-default schema.
+        :param schema: schema name to query, if not the default schema.
+        """
+
+        raise NotImplementedError()
+
+    def get_sequence_names(self, connection, schema=None, **kw):
+        """Return a list of all sequence names available in the database.
+
+        :param schema: schema name to query, if not the default schema.
+
+        .. versionadded:: 1.4
         """
 
         raise NotImplementedError()
index 344d5511d19c54242bb85cfa486f8efb83bf7d12..fded37b2aaec8b7a29f536b64a0907948e2fc5a4 100644 (file)
@@ -255,11 +255,6 @@ class Inspector(object):
          support named schemas, behavior is undefined if ``schema`` is not
          passed as ``None``.  For special quoting, use :class:`.quoted_name`.
 
-        :param order_by: Optional, may be the string "foreign_key" to sort
-         the result on foreign key dependencies.  Does not automatically
-         resolve cycles, and will raise :class:`.CircularDependencyError`
-         if cycles exist.
-
         .. seealso::
 
             :meth:`_reflection.Inspector.get_sorted_table_and_fkc_names`
@@ -276,6 +271,10 @@ class Inspector(object):
     def has_table(self, table_name, schema=None):
         """Return True if the backend has a table of the given name.
 
+
+        :param table_name: name of the table to check
+        :param schema: schema name to query, if not the default schema.
+
         .. versionadded:: 1.4
 
         """
@@ -283,6 +282,19 @@ class Inspector(object):
         with self._operation_context() as conn:
             return self.dialect.has_table(conn, table_name, schema)
 
+    def has_sequence(self, sequence_name, schema=None):
+        """Return True if the backend has a table of the given name.
+
+        :param sequence_name: name of the table to check
+        :param schema: schema name to query, if not the default schema.
+
+        .. versionadded:: 1.4
+
+        """
+        # TODO: info_cache?
+        with self._operation_context() as conn:
+            return self.dialect.has_sequence(conn, sequence_name, schema)
+
     def get_sorted_table_and_fkc_names(self, schema=None):
         """Return dependency-sorted table and foreign key constraint names in
         referred to within a particular schema.
@@ -401,6 +413,19 @@ class Inspector(object):
                 conn, schema, info_cache=self.info_cache
             )
 
+    def get_sequence_names(self, schema=None):
+        """Return all sequence names in `schema`.
+
+        :param schema: Optional, retrieve names from a non-default schema.
+         For special quoting, use :class:`.quoted_name`.
+
+        """
+
+        with self._operation_context() as conn:
+            return self.dialect.get_sequence_names(
+                conn, schema, info_cache=self.info_cache
+            )
+
     def get_view_definition(self, view_name, schema=None):
         """Return definition for `view_name`.
 
index 041daf35e1d8f2e03f9214f50040e53d9478660a..931fa6b7761842458fab5323f357cdb2d9bcc325 100644 (file)
@@ -135,6 +135,7 @@ class TablesTest(TestBase):
     metadata = None
     tables = None
     other = None
+    sequences = None
 
     @classmethod
     def setup_class(cls):
@@ -153,6 +154,7 @@ class TablesTest(TestBase):
 
         cls.other = adict()
         cls.tables = adict()
+        cls.sequences = adict()
 
         cls.bind = cls.setup_bind()
         cls.metadata = sa.MetaData()
@@ -172,6 +174,7 @@ class TablesTest(TestBase):
             if cls.run_create_tables == "once":
                 cls.metadata.create_all(cls.bind)
             cls.tables.update(cls.metadata.tables)
+            cls.sequences.update(cls.metadata._sequences)
 
     def _setup_each_tables(self):
         if self.run_define_tables == "each":
@@ -179,6 +182,7 @@ class TablesTest(TestBase):
             if self.run_create_tables == "each":
                 self.metadata.create_all(self.bind)
             self.tables.update(self.metadata.tables)
+            self.sequences.update(self.metadata._sequences)
         elif self.run_create_tables == "each":
             self.metadata.create_all(self.bind)
 
index dda447c0d0be2313742edf89dff44341f1f04449..55e8e84062f7e3ef2fa3f4e45724d9bd05b68c1f 100644 (file)
@@ -1,12 +1,13 @@
 from .. import config
 from .. import fixtures
 from ..assertions import eq_
+from ..assertions import is_true
 from ..config import requirements
 from ..schema import Column
 from ..schema import Table
+from ... import inspect
 from ... import Integer
 from ... import MetaData
-from ... import schema
 from ... import Sequence
 from ... import String
 from ... import testing
@@ -88,69 +89,108 @@ class SequenceCompilerTest(testing.AssertsCompiledSQL, fixtures.TestBase):
         )
 
 
-class HasSequenceTest(fixtures.TestBase):
+class HasSequenceTest(fixtures.TablesTest):
+    run_deletes = None
+
     __requires__ = ("sequences",)
     __backend__ = True
 
-    def test_has_sequence(self, connection):
-        s1 = Sequence("user_id_seq")
-        connection.execute(schema.CreateSequence(s1))
-        try:
-            eq_(
-                connection.dialect.has_sequence(connection, "user_id_seq"),
-                True,
+    @classmethod
+    def define_tables(cls, metadata):
+        Sequence("user_id_seq", metadata=metadata)
+        Sequence("other_seq", metadata=metadata)
+        if testing.requires.schemas.enabled:
+            Sequence(
+                "user_id_seq", schema=config.test_schema, metadata=metadata
+            )
+            Sequence(
+                "schema_seq", schema=config.test_schema, metadata=metadata
             )
-        finally:
-            connection.execute(schema.DropSequence(s1))
+        Table(
+            "user_id_table", metadata, Column("id", Integer, primary_key=True),
+        )
+
+    def test_has_sequence(self, connection):
+        eq_(
+            inspect(connection).has_sequence("user_id_seq"), True,
+        )
+
+    def test_has_sequence_other_object(self, connection):
+        eq_(
+            inspect(connection).has_sequence("user_id_table"), False,
+        )
 
     @testing.requires.schemas
     def test_has_sequence_schema(self, connection):
-        s1 = Sequence("user_id_seq", schema=config.test_schema)
-        connection.execute(schema.CreateSequence(s1))
-        try:
-            eq_(
-                connection.dialect.has_sequence(
-                    connection, "user_id_seq", schema=config.test_schema
-                ),
-                True,
-            )
-        finally:
-            connection.execute(schema.DropSequence(s1))
+        eq_(
+            inspect(connection).has_sequence(
+                "user_id_seq", schema=config.test_schema
+            ),
+            True,
+        )
 
     def test_has_sequence_neg(self, connection):
-        eq_(connection.dialect.has_sequence(connection, "user_id_seq"), False)
+        eq_(
+            inspect(connection).has_sequence("some_sequence"), False,
+        )
 
     @testing.requires.schemas
     def test_has_sequence_schemas_neg(self, connection):
         eq_(
-            connection.dialect.has_sequence(
-                connection, "user_id_seq", schema=config.test_schema
+            inspect(connection).has_sequence(
+                "some_sequence", schema=config.test_schema
             ),
             False,
         )
 
     @testing.requires.schemas
     def test_has_sequence_default_not_in_remote(self, connection):
-        s1 = Sequence("user_id_seq")
-        connection.execute(schema.CreateSequence(s1))
-        try:
-            eq_(
-                connection.dialect.has_sequence(
-                    connection, "user_id_seq", schema=config.test_schema
-                ),
-                False,
-            )
-        finally:
-            connection.execute(schema.DropSequence(s1))
+        eq_(
+            inspect(connection).has_sequence(
+                "other_sequence", schema=config.test_schema
+            ),
+            False,
+        )
 
     @testing.requires.schemas
     def test_has_sequence_remote_not_in_default(self, connection):
-        s1 = Sequence("user_id_seq", schema=config.test_schema)
-        connection.execute(schema.CreateSequence(s1))
-        try:
-            eq_(
-                connection.dialect.has_sequence(connection, "user_id_seq"),
-                False,
-            )
-        finally:
-            connection.execute(schema.DropSequence(s1))
+        eq_(
+            inspect(connection).has_sequence("schema_seq"), False,
+        )
+
+    def test_get_sequence_names(self, connection):
+        exp = {"other_seq", "user_id_seq"}
+
+        res = set(inspect(connection).get_sequence_names())
+        is_true(res.intersection(exp) == exp)
+        is_true("schema_seq" not in res)
+
+    @testing.requires.schemas
+    def test_get_sequence_names_no_sequence_schema(self, connection):
+        eq_(
+            inspect(connection).get_sequence_names(
+                schema=config.test_schema_2
+            ),
+            [],
+        )
+
+    @testing.requires.schemas
+    def test_get_sequence_names_sequences_schema(self, connection):
+        eq_(
+            sorted(
+                inspect(connection).get_sequence_names(
+                    schema=config.test_schema
+                )
+            ),
+            ["schema_seq", "user_id_seq"],
+        )
+
+
+class HasSequenceTestEmpty(fixtures.TestBase):
+    __requires__ = ("sequences",)
+    __backend__ = True
+
+    def test_get_sequence_names_no_sequence(self, connection):
+        eq_(
+            inspect(connection).get_sequence_names(), [],
+        )