]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Genericize str() for types
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 1 Aug 2020 17:57:04 +0000 (13:57 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 1 Aug 2020 23:49:15 +0000 (19:49 -0400)
Remove lookup logic that attempts to locate a dialect for a type,
just use StrSQLTypeCompiler.

Cleaned up the internal ``str()`` for datatypes so that all types produce a
string representation without any dialect present, including that it works
for third-party dialect types without that dialect being present.  The
string representation defaults to being the UPPERCASE name of that type
with nothing else.

Fixes: #4262
Change-Id: I02149e8a1ba1e7336149e962939b07ae0df83c6b

doc/build/changelog/unreleased_14/4262.rst [new file with mode: 0644]
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/type_api.py
test/dialect/mssql/test_types.py
test/sql/test_compiler.py
test/sql/test_metadata.py
test/sql/test_types.py

diff --git a/doc/build/changelog/unreleased_14/4262.rst b/doc/build/changelog/unreleased_14/4262.rst
new file mode 100644 (file)
index 0000000..8377dac
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, types
+    :tickets: 4262
+
+    Cleaned up the internal ``str()`` for datatypes so that all types produce a
+    string representation without any dialect present, including that it works
+    for third-party dialect types without that dialect being present.  The
+    string representation defaults to being the UPPERCASE name of that type
+    with nothing else.
+
index 61e26b003c0c29e9308fbb8954ecc7353a432c4b..a8bd1de33723cf4dec5c4acc2fddd2b626166ae0 100644 (file)
@@ -4270,6 +4270,14 @@ class GenericTypeCompiler(TypeCompiler):
 
 
 class StrSQLTypeCompiler(GenericTypeCompiler):
+    def process(self, type_, **kw):
+        try:
+            _compiler_dispatch = type_._compiler_dispatch
+        except AttributeError:
+            return self._visit_unknown(type_, **kw)
+        else:
+            return _compiler_dispatch(self, **kw)
+
     def __getattr__(self, key):
         if key.startswith("visit_"):
             return self._visit_unknown
@@ -4277,7 +4285,21 @@ class StrSQLTypeCompiler(GenericTypeCompiler):
             raise AttributeError(key)
 
     def _visit_unknown(self, type_, **kw):
-        return "%s" % type_.__class__.__name__
+        if type_.__class__.__name__ == type_.__class__.__name__.upper():
+            return type_.__class__.__name__
+        else:
+            return repr(type_)
+
+    def visit_null(self, type_, **kw):
+        return "NULL"
+
+    def visit_user_defined(self, type_, **kw):
+        try:
+            get_col_spec = type_.get_col_spec
+        except AttributeError:
+            return repr(type_)
+        else:
+            return get_col_spec(**kw)
 
 
 class IdentifierPreparer(object):
index 2d23c56e182c412181bb8e48b792b88bdc75aa27..1284ef5155a3d6a2228c0121f2108d45ed01413d 100644 (file)
@@ -626,12 +626,7 @@ class TypeEngine(Traversible):
     @util.preload_module("sqlalchemy.engine.default")
     def _default_dialect(self):
         default = util.preloaded.engine_default
-        if self.__class__.__module__.startswith("sqlalchemy.dialects"):
-            tokens = self.__class__.__module__.split(".")[0:3]
-            mod = ".".join(tokens)
-            return getattr(__import__(mod).dialects, tokens[-1]).dialect()
-        else:
-            return default.DefaultDialect()
+        return default.StrCompileDialect()
 
     def __str__(self):
         if util.py2k:
index e28a4249833fb8b69b3408eb042e1918f3cbb0c4..399e0ca9059217869ee40c2af9f01d7db101d4f7 100644 (file)
@@ -953,9 +953,14 @@ class TypeRoundTripTest(
         )
         for col, spec in zip(reflected_binary.c, columns):
             eq_(
-                str(col.type),
+                col.type.compile(dialect=mssql.dialect()),
                 spec[3],
-                "column %s %s != %s" % (col.key, str(col.type), spec[3]),
+                "column %s %s != %s"
+                % (
+                    col.key,
+                    col.type.compile(dialect=mssql.dialect()),
+                    spec[3],
+                ),
             )
             c1 = testing.db.dialect.type_descriptor(col.type).__class__
             c2 = testing.db.dialect.type_descriptor(
index d79d00555419c4593c9cec4b78e38ab2fa279d85..1d31f1ea5ef9c51232bd424c93e43f3184cfe1c7 100644 (file)
@@ -3940,7 +3940,7 @@ class StringifySpecialTest(fixtures.TestBase):
 
         eq_ignore_whitespace(
             str(stmt),
-            "SELECT CAST(mytable.myid AS MyType) AS myid FROM mytable",
+            "SELECT CAST(mytable.myid AS MyType()) AS myid FROM mytable",
         )
 
     def test_within_group(self):
index 3303eac1d12d1303e29a631a61f476b434119dfc..dc4e342fda82a40290e54ba76b9d8861e5dd5889 100644 (file)
@@ -3970,13 +3970,9 @@ class ColumnOptionsTest(fixtures.TestBase):
         assert Column(String, default=g2).default is g2
         assert Column(String, onupdate=g2).onupdate is g2
 
-    def _null_type_error(self, col):
-        t = Table("t", MetaData(), col)
-        assert_raises_message(
-            exc.CompileError,
-            r"\(in table 't', column 'foo'\): Can't generate DDL for NullType",
-            schema.CreateTable(t).compile,
-        )
+    def _null_type_no_error(self, col):
+        c_str = str(schema.CreateColumn(col).compile())
+        assert "NULL" in c_str
 
     def _no_name_error(self, col):
         assert_raises_message(
@@ -3997,13 +3993,13 @@ class ColumnOptionsTest(fixtures.TestBase):
 
     def test_argument_signatures(self):
         self._no_name_error(Column())
-        self._null_type_error(Column("foo"))
+        self._null_type_no_error(Column("foo"))
         self._no_name_error(Column(default="foo"))
 
         self._no_name_error(Column(Sequence("a")))
-        self._null_type_error(Column("foo", default="foo"))
+        self._null_type_no_error(Column("foo", default="foo"))
 
-        self._null_type_error(Column("foo", Sequence("a")))
+        self._null_type_no_error(Column("foo", Sequence("a")))
 
         self._no_name_error(Column(ForeignKey("bar.id")))
 
index 30e5d1fca175bdae7c5106083a92a1cbc757dfe7..fac9fd1399aaa224b1eda90fe46b562592371642 100644 (file)
@@ -294,6 +294,33 @@ class AdaptTest(fixtures.TestBase):
             t1 = typ()
         repr(t1)
 
+    @testing.uses_deprecated()
+    @testing.combinations(*[(t,) for t in _all_types(omit_special_types=True)])
+    def test_str(self, typ):
+        if issubclass(typ, ARRAY):
+            t1 = typ(String)
+        else:
+            t1 = typ()
+        str(t1)
+
+    def test_str_third_party(self):
+        class TINYINT(types.TypeEngine):
+            __visit_name__ = "TINYINT"
+
+        eq_(str(TINYINT()), "TINYINT")
+
+    def test_str_third_party_uppercase_no_visit_name(self):
+        class TINYINT(types.TypeEngine):
+            pass
+
+        eq_(str(TINYINT()), "TINYINT")
+
+    def test_str_third_party_camelcase_no_visit_name(self):
+        class TinyInt(types.TypeEngine):
+            pass
+
+        eq_(str(TinyInt()), "TinyInt()")
+
     def test_adapt_constructor_copy_override_kw(self):
         """test that adapt() can accept kw args that override
         the state of the original object.
@@ -2878,10 +2905,16 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
     def test_default_compile_mysql_integer(self):
         self.assert_compile(
             dialects.mysql.INTEGER(display_width=5),
-            "INTEGER(5)",
+            "INTEGER",
             allow_dialect_select=True,
         )
 
+        self.assert_compile(
+            dialects.mysql.INTEGER(display_width=5),
+            "INTEGER(5)",
+            dialect="mysql",
+        )
+
     def test_numeric_plain(self):
         self.assert_compile(types.NUMERIC(), "NUMERIC")