]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
Remove usage of no longer needed compat code
authorCaselIT <cfederico87@gmail.com>
Tue, 23 Nov 2021 20:51:23 +0000 (21:51 +0100)
committerCaselIT <cfederico87@gmail.com>
Tue, 23 Nov 2021 20:52:10 +0000 (21:52 +0100)
Change-Id: I3180931673496260614e69e95f7da09d68b51714

14 files changed:
alembic/autogenerate/render.py
alembic/config.py
alembic/ddl/impl.py
alembic/ddl/postgresql.py
alembic/operations/schemaobj.py
alembic/script/revision.py
alembic/testing/assertions.py
alembic/testing/fixtures.py
alembic/testing/schemacompare.py
alembic/util/compat.py
alembic/util/langhelpers.py
alembic/util/messaging.py
alembic/util/sqla_compat.py
tests/test_script_consumption.py

index b8226f7a7007bcbf00fc59bb3adcda419af69847..77f0a8666edd62d1180138e84cc34befe7609bb6 100644 (file)
@@ -18,9 +18,7 @@ from sqlalchemy.sql.elements import conv
 
 from .. import util
 from ..operations import ops
-from ..util import compat
 from ..util import sqla_compat
-from ..util.compat import string_types
 
 if TYPE_CHECKING:
     from typing import Literal
@@ -567,8 +565,8 @@ def _ident(name: Optional[Union["quoted_name", str]]) -> Optional[str]:
     if name is None:
         return name
     elif isinstance(name, sql.elements.quoted_name):
-        return compat.text_type(name)
-    elif isinstance(name, compat.string_types):
+        return str(name)
+    elif isinstance(name, str):
         return name
 
 
@@ -757,14 +755,14 @@ def _render_server_default(
     elif sqla_compat._server_default_is_identity(default):
         return _render_identity(cast("Identity", default), autogen_context)
     elif isinstance(default, sa_schema.DefaultClause):
-        if isinstance(default.arg, compat.string_types):
+        if isinstance(default.arg, str):
             default = default.arg
         else:
             return _render_potential_expr(
                 default.arg, autogen_context, is_server_default=True
             )
 
-    if isinstance(default, string_types) and repr_:
+    if isinstance(default, str) and repr_:
         default = repr(re.sub(r"^'|'$", "", default))
 
     return cast(str, default)
@@ -1109,7 +1107,7 @@ def _render_check_constraint(
 def _execute_sql(
     autogen_context: "AutogenContext", op: "ops.ExecuteSQLOp"
 ) -> str:
-    if not isinstance(op.sqltext, string_types):
+    if not isinstance(op.sqltext, str):
         raise NotImplementedError(
             "Autogenerate rendering of SQL Expression language constructs "
             "not supported here; please use a plain SQL string"
index 273acbb325d4afa798559eb8cccd01a4f069323b..f868bf7375a918beb55ddc1b8b5f3ecac6666ff9 100644 (file)
@@ -167,9 +167,9 @@ class Config:
         """
 
         if arg:
-            output = compat.text_type(text) % arg
+            output = str(text) % arg
         else:
-            output = compat.text_type(text)
+            output = str(text)
 
         util.write_outstream(self.stdout, output, "\n")
 
index 2ca316c7f47226293369486b10d654a5e34dea64..10dcc7344cf4acf5fed16652731ed629b3c0afef 100644 (file)
@@ -19,8 +19,6 @@ from sqlalchemy import text
 from . import base
 from .. import util
 from ..util import sqla_compat
-from ..util.compat import string_types
-from ..util.compat import text_type
 
 if TYPE_CHECKING:
     from io import StringIO
@@ -124,7 +122,7 @@ class DefaultImpl(metaclass=ImplMeta):
 
     def static_output(self, text: str) -> None:
         assert self.output_buffer is not None
-        self.output_buffer.write(text_type(text + "\n\n"))
+        self.output_buffer.write(text + "\n\n")
         self.output_buffer.flush()
 
     def requires_recreate_in_batch(
@@ -162,7 +160,7 @@ class DefaultImpl(metaclass=ImplMeta):
         multiparams: Sequence[dict] = (),
         params: Dict[str, int] = util.immutabledict(),
     ) -> Optional[Union["LegacyCursorResult", "CursorResult"]]:
-        if isinstance(construct, string_types):
+        if isinstance(construct, str):
             construct = text(construct)
         if self.as_sql:
             if multiparams or params:
@@ -177,9 +175,7 @@ class DefaultImpl(metaclass=ImplMeta):
                 compile_kw = {}
 
             self.static_output(
-                text_type(
-                    construct.compile(dialect=self.dialect, **compile_kw)
-                )
+                str(construct.compile(dialect=self.dialect, **compile_kw))
                 .replace("\t", "    ")
                 .strip()
                 + self.command_terminator
@@ -554,8 +550,8 @@ class DefaultImpl(metaclass=ImplMeta):
 
     def correct_for_autogen_constraints(
         self,
-        conn_uniques: Union[Set["UniqueConstraint"]],
-        conn_indexes: Union[Set["Index"]],
+        conn_uniques: Set["UniqueConstraint"],
+        conn_indexes: Set["Index"],
         metadata_unique_constraints: Set["UniqueConstraint"],
         metadata_indexes: Set["Index"],
     ) -> None:
@@ -580,7 +576,7 @@ class DefaultImpl(metaclass=ImplMeta):
         compile_kw = dict(
             compile_kwargs={"literal_binds": True, "include_table": False}
         )
-        return text_type(expr.compile(dialect=self.dialect, **compile_kw))
+        return str(expr.compile(dialect=self.dialect, **compile_kw))
 
     def _compat_autogen_column_reflect(
         self, inspector: "Inspector"
index 9fb9ac980582c1b35c25e8f8af9f8e12adab0aac..6174f382a05026e872cb82e1c2e68464cf7c3c4a 100644 (file)
@@ -38,7 +38,6 @@ from ..operations import ops
 from ..operations import schemaobj
 from ..operations.base import BatchOperations
 from ..operations.base import Operations
-from ..util import compat
 from ..util import sqla_compat
 
 if TYPE_CHECKING:
@@ -118,9 +117,7 @@ class PostgresqlImpl(DefaultImpl):
         if (
             not isinstance(inspector_column.type, Numeric)
             and metadata_column.server_default is not None
-            and isinstance(
-                metadata_column.server_default.arg, compat.string_types
-            )
+            and isinstance(metadata_column.server_default.arg, str)
             and not re.match(r"^'.*'$", rendered_metadata_default)
         ):
             rendered_metadata_default = "'%s'" % rendered_metadata_default
index 3bff50837a63f255b095858394cc31259de23ede..c8fab933929fd4bfd7b874a31f79eb6e0b984fc1 100644 (file)
@@ -16,7 +16,6 @@ from sqlalchemy.types import NULLTYPE
 
 from .. import util
 from ..util import sqla_compat
-from ..util.compat import string_types
 
 if TYPE_CHECKING:
     from sqlalchemy.sql.elements import ColumnElement
@@ -269,7 +268,7 @@ class SchemaObjects:
         ForeignKey.
 
         """
-        if isinstance(fk._colspec, string_types):  # type:ignore[attr-defined]
+        if isinstance(fk._colspec, str):  # type:ignore[attr-defined]
             table_key, cname = fk._colspec.rsplit(  # type:ignore[attr-defined]
                 ".", 1
             )
index eccb98ec8a8b1cce8c36922b6e7d48e67f0622e9..4b4e29c9c7fa5a2fedbb55c41dc55ab2bfd71747 100644 (file)
@@ -20,7 +20,6 @@ from typing import Union
 from sqlalchemy import util as sqlautil
 
 from .. import util
-from ..util import compat
 
 if TYPE_CHECKING:
     from typing import Literal
@@ -724,16 +723,12 @@ class RevisionMap:
         self, id_: Optional[str]
     ) -> Tuple[Tuple[str, ...], Optional[str]]:
         branch_label: Optional[str]
-        if isinstance(id_, compat.string_types) and "@" in id_:
+        if isinstance(id_, str) and "@" in id_:
             branch_label, id_ = id_.split("@", 1)
 
         elif id_ is not None and (
-            (
-                isinstance(id_, tuple)
-                and id_
-                and not isinstance(id_[0], compat.string_types)
-            )
-            or not isinstance(id_, compat.string_types + (tuple,))
+            (isinstance(id_, tuple) and id_ and not isinstance(id_[0], str))
+            or not isinstance(id_, (str, tuple))
         ):
             raise RevisionError(
                 "revision identifier %r is not a string; ensure database "
@@ -1029,7 +1024,7 @@ class RevisionMap:
         walk to.
         """
         initial: Optional[_RevisionOrBase]
-        if isinstance(start, compat.string_types):
+        if isinstance(start, str):
             initial = self.get_revision(start)
         else:
             initial = start
@@ -1092,7 +1087,7 @@ class RevisionMap:
         if target is None:
             return None, None
         assert isinstance(
-            target, compat.string_types
+            target, str
         ), "Expected downgrade target in string form"
         match = _relative_destination.match(target)
         if match:
@@ -1183,7 +1178,7 @@ class RevisionMap:
         to. The target may be specified in absolute form, or relative to
         :current_revisions.
         """
-        if isinstance(target, compat.string_types):
+        if isinstance(target, str):
             match = _relative_destination.match(target)
         else:
             match = None
@@ -1400,7 +1395,7 @@ class RevisionMap:
 
         # Handled named bases (e.g. branch@... -> heads should only produce
         # targets on the given branch)
-        if isinstance(lower, compat.string_types) and "@" in lower:
+        if isinstance(lower, str) and "@" in lower:
             branch, _, _ = lower.partition("@")
             branch_rev = self.get_revision(branch)
             if branch_rev is not None and branch_rev.revision == branch:
index ed532062d1b81c54669f6023b384375c9f090ba9..e7a12c65ae02729f4d70e3c3e3eb96b04b4b7ef6 100644 (file)
@@ -5,7 +5,6 @@ from typing import Any
 from typing import Dict
 
 from sqlalchemy import exc as sa_exc
-from sqlalchemy import util
 from sqlalchemy.engine import default
 from sqlalchemy.testing.assertions import _expect_warnings
 from sqlalchemy.testing.assertions import eq_  # noqa
@@ -85,12 +84,10 @@ def _expect_raises(except_cls, msg=None, check_context=False):
         ec.error = err
         success = True
         if msg is not None:
-            assert re.search(
-                msg, util.text_type(err), re.UNICODE
-            ), "%r !~ %s" % (msg, err)
+            assert re.search(msg, str(err), re.UNICODE), f"{msg} !~ {err}"
         if check_context and not are_we_already_in_a_traceback:
             _assert_proper_exception_context(err)
-        print(util.text_type(err).encode("utf-8"))
+        print(str(err).encode("utf-8"))
 
     # assert outside the block so it works for AssertionError too !
     assert success, "Callable did not raise an exception"
index 5e6ba89cb45d3b215ad29da2aaa04ebb11687072..849bc83089b31861264737ff03efcdb8dba87e96 100644 (file)
@@ -25,8 +25,6 @@ from ..environment import EnvironmentContext
 from ..migration import MigrationContext
 from ..operations import Operations
 from ..util import sqla_compat
-from ..util.compat import string_types
-from ..util.compat import text_type
 from ..util.sqla_compat import create_mock_engine
 from ..util.sqla_compat import sqla_14
 from ..util.sqla_compat import sqla_1x
@@ -203,10 +201,10 @@ def op_fixture(
     if not as_sql:
 
         def execute(stmt, *multiparam, **param):
-            if isinstance(stmt, string_types):
+            if isinstance(stmt, str):
                 stmt = text(stmt)
             assert stmt.supports_execution
-            sql = text_type(stmt.compile(dialect=ctx_dialect))
+            sql = str(stmt.compile(dialect=ctx_dialect))
 
             buf.write(sql)
 
index 500cee80622cd779eaff478decc437bb882cd841..44094216913755268d5554b80cd58a74deb508ce 100644 (file)
@@ -1,5 +1,6 @@
+from itertools import zip_longest
+
 from sqlalchemy import schema
-from sqlalchemy import util
 
 
 class CompareTable:
@@ -10,7 +11,7 @@ class CompareTable:
         if self.table.name != other.name or self.table.schema != other.schema:
             return False
 
-        for c1, c2 in util.zip_longest(self.table.c, other.c):
+        for c1, c2 in zip_longest(self.table.c, other.c):
             if (c1 is None and c2 is not None) or (
                 c2 is None and c1 is not None
             ):
@@ -86,7 +87,7 @@ class CompareForeignKey:
         )
         if not r1:
             return False
-        for c1, c2 in util.zip_longest(self.constraint.columns, other.columns):
+        for c1, c2 in zip_longest(self.constraint.columns, other.columns):
             if (c1 is None and c2 is not None) or (
                 c2 is None and c1 is not None
             ):
@@ -113,7 +114,7 @@ class ComparePrimaryKey:
         if not r1:
             return False
 
-        for c1, c2 in util.zip_longest(self.constraint.columns, other.columns):
+        for c1, c2 in zip_longest(self.constraint.columns, other.columns):
             if (c1 is None and c2 is not None) or (
                 c2 is None and c1 is not None
             ):
@@ -141,7 +142,7 @@ class CompareUniqueConstraint:
         if not r1:
             return False
 
-        for c1, c2 in util.zip_longest(self.constraint.columns, other.columns):
+        for c1, c2 in zip_longest(self.constraint.columns, other.columns):
             if (c1 is None and c2 is not None) or (
                 c2 is None and c1 is not None
             ):
index 48218ab2772533524534e807965c76ca2f90891f..54420cbc95480a3040cb739d8761fb1bdee66fe6 100644 (file)
@@ -12,10 +12,6 @@ py39 = sys.version_info >= (3, 9)
 py38 = sys.version_info >= (3, 8)
 py37 = sys.version_info >= (3, 7)
 
-string_types = (str,)
-binary_type = bytes
-text_type = str
-
 
 # produce a wrapper that allows encoded text to stream
 # into a given buffer, but doesn't close it.
index 4db9a5f06d37a270d1a7fcb3eccf03324b606d78..fd7ccb8fd1e3151810797d80ef356cb51bc89f9f 100644 (file)
@@ -21,7 +21,6 @@ from sqlalchemy.util import to_list  # noqa
 from sqlalchemy.util import unique_list  # noqa
 
 from .compat import inspect_getfullargspec
-from .compat import string_types
 
 
 _T = TypeVar("_T")
@@ -209,7 +208,7 @@ def to_tuple(x: Any, default: Optional[tuple] = None) -> tuple:
 def to_tuple(x, default=None):
     if x is None:
         return default
-    elif isinstance(x, string_types):
+    elif isinstance(x, str):
         return (x,)
     elif isinstance(x, Iterable):
         return tuple(x)
@@ -241,7 +240,7 @@ class Dispatcher:
 
     def dispatch(self, obj: Any, qualifier: str = "default") -> Any:
 
-        if isinstance(obj, string_types):
+        if isinstance(obj, str):
             targets: Sequence = [obj]
         elif isinstance(obj, type):
             targets = obj.__mro__
index 062890a32ee92c9464870b5d40c57eddbf278f36..66f8cc256a3938c486f12abdfee03abe8e7bed77 100644 (file)
@@ -12,8 +12,6 @@ import warnings
 from sqlalchemy.engine import url
 
 from . import sqla_compat
-from .compat import binary_type
-from .compat import string_types
 
 log = logging.getLogger(__name__)
 
@@ -37,7 +35,7 @@ except (ImportError, IOError):
 def write_outstream(stream: TextIO, *text) -> None:
     encoding = getattr(stream, "encoding", "ascii") or "ascii"
     for t in text:
-        if not isinstance(t, binary_type):
+        if not isinstance(t, bytes):
             t = t.encode(encoding, "replace")
         t = t.decode(encoding)
         try:
@@ -100,7 +98,7 @@ def msg(msg: str, newline: bool = True, flush: bool = False) -> None:
 def format_as_comma(value: Optional[Union[str, "Iterable[str]"]]) -> str:
     if value is None:
         return ""
-    elif isinstance(value, string_types):
+    elif isinstance(value, str):
         return value
     elif isinstance(value, Iterable):
         return ", ".join(value)
index 221e20e86c71f7b91114d7c44526e0f49719fa1c..4d0041e10cc23181c4cccec9737512a333ad162b 100644 (file)
@@ -23,8 +23,6 @@ from sqlalchemy.sql.elements import quoted_name
 from sqlalchemy.sql.elements import TextClause
 from sqlalchemy.sql.visitors import traverse
 
-from . import compat
-
 if TYPE_CHECKING:
     from sqlalchemy import Index
     from sqlalchemy import Table
@@ -338,7 +336,7 @@ def _textual_index_column(
     table: "Table", text_: Union[str, "TextClause", "ColumnElement"]
 ) -> Union["ColumnElement", "Column"]:
     """a workaround for the Index construct's severe lack of flexibility"""
-    if isinstance(text_, compat.string_types):
+    if isinstance(text_, str):
         c = Column(text_, sqltypes.NULLTYPE)
         table.append_column(c)
         return c
index b3146d3cba095a3913e63765e40d2bac01d3546d..96161f6dece8e8357319d51460676b84aae9c490 100644 (file)
@@ -28,7 +28,6 @@ from alembic.testing.env import write_script
 from alembic.testing.fixtures import capture_context_buffer
 from alembic.testing.fixtures import FutureEngineMixin
 from alembic.testing.fixtures import TestBase
-from alembic.util import compat
 
 
 class PatchEnvironment:
@@ -383,7 +382,7 @@ class CallbackEnvironmentTest(ApplyVersionsFunctionalTest):
             assert isinstance(step.is_upgrade, bool)
             assert isinstance(step.is_stamp, bool)
             assert isinstance(step.is_migration, bool)
-            assert isinstance(step.up_revision_id, compat.string_types)
+            assert isinstance(step.up_revision_id, str)
             assert isinstance(step.up_revision, Script)
 
             for revtype in "up", "down", "source", "destination":
@@ -393,12 +392,12 @@ class CallbackEnvironmentTest(ApplyVersionsFunctionalTest):
                     assert isinstance(rev, Script)
                 revids = getattr(step, "%s_revision_ids" % revtype)
                 for revid in revids:
-                    assert isinstance(revid, compat.string_types)
+                    assert isinstance(revid, str)
 
             heads = kw["heads"]
             assert hasattr(heads, "__iter__")
             for h in heads:
-                assert h is None or isinstance(h, compat.string_types)
+                assert h is None or isinstance(h, str)
 
 
 class OfflineTransactionalDDLTest(TestBase):