]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- rewrite SQLite reflection tests into one consistent fixture, which tests
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 16 Feb 2014 23:14:10 +0000 (18:14 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 16 Feb 2014 23:14:10 +0000 (18:14 -0500)
both _resolve_type_affinity() directly as well as round trip tests fully.

lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/testing/warnings.py
test/dialect/test_sqlite.py

index a1bd05d38b59f61e886768404a1d81bc671d6ae3..70caef1e42bcd238c40de7adceee0419f4b20dce 100644 (file)
@@ -857,24 +857,7 @@ class SQLiteDialect(default.DefaultDialect):
         return columns
 
     def _get_column_info(self, name, type_, nullable, default, primary_key):
-        match = re.match(r'([\w ]+)(\(.*?\))?', type_)
-        if match:
-            coltype = match.group(1)
-            args = match.group(2)
-        else:
-            coltype = ''
-            args = ''
-        coltype = self._resolve_type_affinity(coltype)
-        if args is not None:
-            args = re.findall(r'(\d+)', args)
-            try:
-                coltype = coltype(*[int(a) for a in args])
-            except TypeError:
-                util.warn(
-                        "Could not instantiate type %s with "
-                        "reflected arguments %s; using no arguments." %
-                        (coltype, args))
-                coltype = coltype()
+        coltype = self._resolve_type_affinity(type_)
 
         if default is not None:
             default = util.text_type(default)
@@ -888,7 +871,7 @@ class SQLiteDialect(default.DefaultDialect):
             'primary_key': primary_key,
         }
 
-    def _resolve_type_affinity(self, coltype):
+    def _resolve_type_affinity(self, type_):
         """Return a data type from a reflected column, using affinity tules.
 
         SQLite's goal for universal compatability introduces some complexity
@@ -905,18 +888,41 @@ class SQLiteDialect(default.DefaultDialect):
         DATE and DOUBLE).
 
         """
+        match = re.match(r'([\w ]+)(\(.*?\))?', type_)
+        if match:
+            coltype = match.group(1)
+            args = match.group(2)
+        else:
+            coltype = ''
+            args = ''
+
         if coltype in self.ischema_names:
-            return self.ischema_names[coltype]
+            coltype = self.ischema_names[coltype]
         elif 'INT' in coltype:
-            return sqltypes.INTEGER
+            coltype = sqltypes.INTEGER
         elif 'CHAR' in coltype or 'CLOB' in coltype or 'TEXT' in coltype:
-            return sqltypes.TEXT
+            coltype = sqltypes.TEXT
         elif 'BLOB' in coltype or not coltype:
-            return sqltypes.NullType
+            coltype = sqltypes.NullType
         elif 'REAL' in coltype or 'FLOA' in coltype or 'DOUB' in coltype:
-            return sqltypes.REAL
+            coltype = sqltypes.REAL
         else:
-            return sqltypes.NUMERIC
+            coltype = sqltypes.NUMERIC
+
+        if args is not None:
+            args = re.findall(r'(\d+)', args)
+            try:
+                coltype = coltype(*[int(a) for a in args])
+            except TypeError:
+                util.warn(
+                        "Could not instantiate type %s with "
+                        "reflected arguments %s; using no arguments." %
+                        (coltype, args))
+                coltype = coltype()
+        else:
+            coltype = coltype()
+
+        return coltype
 
     @reflection.cache
     def get_pk_constraint(self, connection, table_name, schema=None, **kw):
index 74a8933a62031888d4a13d85a021dc8a705330d6..849b1b5b49e008e788aaaa67a4674aef69477327 100644 (file)
@@ -9,7 +9,7 @@ from __future__ import absolute_import
 import warnings
 from .. import exc as sa_exc
 from .. import util
-
+import re
 
 def testing_warn(msg, stacklevel=3):
     """Replaces sqlalchemy.util.warn during tests."""
@@ -33,7 +33,7 @@ def resetwarnings():
     warnings.filterwarnings('error', category=sa_exc.SAWarning)
 
 
-def assert_warnings(fn, warnings):
+def assert_warnings(fn, warnings, regex=False):
     """Assert that each of the given warnings are emitted by fn."""
 
     from .assertions import eq_, emits_warning
@@ -45,7 +45,10 @@ def assert_warnings(fn, warnings):
         orig_warn(*args, **kw)
         popwarn = warnings.pop(0)
         canary.append(popwarn)
-        eq_(args[0], popwarn)
+        if regex:
+            assert re.match(popwarn, args[0])
+        else:
+            eq_(args[0], popwarn)
     util.warn = util.langhelpers.warn = capture_warnings
 
     result = emits_warning()(fn)()
index 38bb783044bddf3414bb4d347fabea8ef8d04d78..e7fcce859b3f8752b5acb40c327d5d80a428c497 100644 (file)
@@ -6,16 +6,14 @@ from collections import Counter
 import datetime
 
 from sqlalchemy.testing import eq_, assert_raises, \
-    assert_raises_message
-from sqlalchemy import Table, String, select, Text, CHAR, bindparam, Column,\
-    Unicode, Date, MetaData, UnicodeText, Time, Integer, TIMESTAMP, \
-    Boolean, func, NUMERIC, DateTime, extract, ForeignKey, text, Numeric,\
-    DefaultClause, and_, DECIMAL, TypeDecorator, create_engine, Float, \
-    INTEGER, UniqueConstraint, DATETIME, DATE, TIME, BOOLEAN, BIGINT, \
-    VARCHAR
-from sqlalchemy.types import UserDefinedType
+    assert_raises_message, is_
+from sqlalchemy import Table, select, bindparam, Column,\
+    MetaData, func, extract, ForeignKey, text, DefaultClause, and_, create_engine,\
+    UniqueConstraint
+from sqlalchemy.types import Integer, String, Boolean, DateTime, Date, Time
+from sqlalchemy import types as sqltypes
 from sqlalchemy.util import u, ue
-from sqlalchemy import exc, sql, schema, pool, types as sqltypes, util
+from sqlalchemy import exc, sql, schema, pool, util
 from sqlalchemy.dialects.sqlite import base as sqlite, \
     pysqlite as pysqlite_dialect
 from sqlalchemy.engine.url import make_url
@@ -81,9 +79,9 @@ class TestTypes(fixtures.TestBase, AssertsExecutionResults):
                         | dbapi.PARSE_COLNAMES}
         engine = engines.testing_engine(options={'connect_args'
                 : connect_args, 'native_datetime': True})
-        t = Table('datetest', MetaData(), Column('id', Integer,
-                  primary_key=True), Column('d1', Date), Column('d2',
-                  TIMESTAMP))
+        t = Table('datetest', MetaData(),
+                    Column('id', Integer, primary_key=True),
+                    Column('d1', Date), Column('d2', sqltypes.TIMESTAMP))
         t.create(engine)
         try:
             engine.execute(t.insert(), {'d1': datetime.date(2010, 5,
@@ -151,98 +149,18 @@ class TestTypes(fixtures.TestBase, AssertsExecutionResults):
         dialect = sqlite.dialect()
         for t in (
             String(convert_unicode=True),
-            CHAR(convert_unicode=True),
-            Unicode(),
-            UnicodeText(),
+            sqltypes.CHAR(convert_unicode=True),
+            sqltypes.Unicode(),
+            sqltypes.UnicodeText(),
             String(convert_unicode=True),
-            CHAR(convert_unicode=True),
-            Unicode(),
-            UnicodeText(),
+            sqltypes.CHAR(convert_unicode=True),
+            sqltypes.Unicode(),
+            sqltypes.UnicodeText(),
             ):
             bindproc = t.dialect_impl(dialect).bind_processor(dialect)
             assert not bindproc or \
                 isinstance(bindproc(util.u('some string')), util.text_type)
 
-    @testing.emits_warning("Could not instantiate")
-    @testing.provide_metadata
-    def test_type_reflection(self):
-        metadata = self.metadata
-
-        # (ask_for, roundtripped_as_if_different)
-
-        class AnyType(UserDefinedType):
-            def __init__(self, spec):
-                self.spec = spec
-
-            def get_col_spec(self):
-              return self.spec
-
-        specs = [
-            (String(), String()),
-            (String(1), String(1)),
-            (String(3), String(3)),
-            (Text(), Text()),
-            (Unicode(), String()),
-            (Unicode(1), String(1)),
-            (Unicode(3), String(3)),
-            (UnicodeText(), Text()),
-            (CHAR(1), ),
-            (CHAR(3), CHAR(3)),
-            (NUMERIC, NUMERIC()),
-            (NUMERIC(10, 2), NUMERIC(10, 2)),
-            (Numeric, NUMERIC()),
-            (Numeric(10, 2), NUMERIC(10, 2)),
-            (DECIMAL, DECIMAL()),
-            (DECIMAL(10, 2), DECIMAL(10, 2)),
-            (INTEGER, INTEGER()),
-            (BIGINT, BIGINT()),
-            (Float, Float()),
-            (NUMERIC(), ),
-            (TIMESTAMP, TIMESTAMP()),
-            (DATETIME, DATETIME()),
-            (DateTime, DateTime()),
-            (DateTime(), ),
-            (DATE, DATE()),
-            (Date, Date()),
-            (TIME, TIME()),
-            (Time, Time()),
-            (BOOLEAN, BOOLEAN()),
-            (Boolean, Boolean()),
-            # types with unsupported arguments
-            (AnyType("INTEGER(5)"), INTEGER()),
-            (AnyType("DATETIME(6, 12)"), DATETIME()),
-            ]
-        columns = [Column('c%i' % (i + 1), t[0]) for (i, t) in
-                   enumerate(specs)]
-        db = testing.db
-        Table('types', metadata, *columns)
-        metadata.create_all()
-        m2 = MetaData(db)
-        rt = Table('types', m2, autoload=True)
-        try:
-            db.execute('CREATE VIEW types_v AS SELECT * from types')
-            rv = Table('types_v', m2, autoload=True)
-            expected = [len(c) > 1 and c[1] or c[0] for c in specs]
-            for table in rt, rv:
-                for i, reflected in enumerate(table.c):
-                    assert isinstance(reflected.type,
-                            type(expected[i])), '%d: %r' % (i,
-                            type(expected[i]))
-        finally:
-            db.execute('DROP VIEW types_v')
-
-    @testing.provide_metadata
-    def test_unknown_reflection(self):
-        metadata = self.metadata
-        t = Table('t', metadata,
-            Column('x', sqltypes.BINARY(16)),
-            Column('y', sqltypes.BINARY())
-        )
-        t.create()
-        t2 = Table('t', MetaData(), autoload=True, autoload_with=testing.db)
-        assert isinstance(t2.c.x.type, sqltypes.Numeric)
-        assert isinstance(t2.c.y.type, sqltypes.Numeric)
-
 
 class DateTimeTest(fixtures.TestBase, AssertsCompiledSQL):
 
@@ -349,14 +267,14 @@ class DefaultsTest(fixtures.TestBase, AssertsCompiledSQL):
 
         # (ask_for, roundtripped_as_if_different)
 
-        specs = [(String(3), '"foo"'), (NUMERIC(10, 2), '100.50'),
+        specs = [(String(3), '"foo"'), (sqltypes.NUMERIC(10, 2), '100.50'),
                  (Integer, '5'), (Boolean, 'False')]
         columns = [Column('c%i' % (i + 1), t[0],
                    server_default=text(t[1])) for (i, t) in
                    enumerate(specs)]
         db = testing.db
         m = MetaData(db)
-        t_table = Table('t_defaults', m, *columns)
+        Table('t_defaults', m, *columns)
         try:
             m.create_all()
             m2 = MetaData(db)
@@ -429,13 +347,8 @@ class DefaultsTest(fixtures.TestBase, AssertsCompiledSQL):
         """test non-quoted integer value on older sqlite pragma"""
 
         dialect = sqlite.dialect()
-        eq_(
-            dialect._get_column_info("foo", "INTEGER", False, 3, False),
-            {'primary_key': False, 'nullable': False,
-                'default': '3', 'autoincrement': False,
-                'type': INTEGER, 'name': 'foo'}
-        )
-
+        info = dialect._get_column_info("foo", "INTEGER", False, 3, False)
+        eq_(info['default'], '3')
 
 
 
@@ -957,7 +870,7 @@ class AutoIncrementTest(fixtures.TestBase, AssertsCompiledSQL):
                             dialect=sqlite.dialect())
 
     def test_sqlite_autoincrement_int_affinity(self):
-        class MyInteger(TypeDecorator):
+        class MyInteger(sqltypes.TypeDecorator):
             impl = Integer
         table = Table(
             'autoinctable',
@@ -1037,58 +950,130 @@ class ReflectFKConstraintTest(fixtures.TestBase):
         )
 
 
-class ColumnTypeAffinityReflectionTest(fixtures.TestBase):
-    """Tests on data type affinities for SQLite during relection.
+class TypeReflectionTest(fixtures.TestBase):
 
-    See http://www.sqlite.org/datatype3.html - section 2.
-    """
     __only_on__ = 'sqlite'
 
-    def setup(self):
-        testing.db.execute("""
-            CREATE TABLE a (
-                "id" INTEGER PRIMARY KEY,
-                "foo" DOUBLE,
-                "bar" DECIMAL(19,4),
-                "baz" VARCHAR(200),
-                "boff",
-                "biff" LONGTEXT
-            )""")
-        # These example names come from section 2.2 of the datatype docs,
-        # after pruning out types which we convert to more convenient types
-        self.example_typenames_integer = ["TINYINT", "MEDIUMINT", "INT2",
-            "UNSIGNED BIG INT", "INT8"]
-        self.example_typenames_text = ["CHARACTER(20)", "CLOB",
-            "VARYING CHARACTER(70)", "NATIVE CHARACTER(70)"]
-        self.example_typenames_none = [""]
-        self.example_typenames_real = ["DOUBLE PRECISION"]
-        cols = ["i%d %s" % (n, t) for n, t in enumerate(
-            self.example_typenames_integer)]
-        cols += ["t%d %s" % (n, t) for n, t in enumerate(
-            self.example_typenames_text)]
-        cols += ["o%d %s" % (n, t) for n, t in enumerate(
-            self.example_typenames_none)]
-        cols += ["n%d %s" % (n, t) for n, t in enumerate(
-            self.example_typenames_real)]
-        cols = ','.join(cols)
-        testing.db.execute("CREATE TABLE b (%s)" % (cols,))
-
-    def teardown(self):
-        testing.db.execute("drop table a")
-        testing.db.execute("drop table b")
-
-    def test_can_reflect_with_affinity(self):
-        "Test that 'affinity-types' don't break reflection outright."
-        meta = MetaData()
-        a = Table('a', meta, autoload=True, autoload_with=testing.db)
-        eq_(len(a.columns), 6)
+    def _fixed_lookup_fixture(self):
+        return [
+            (sqltypes.String(), sqltypes.VARCHAR()),
+            (sqltypes.String(1), sqltypes.VARCHAR(1)),
+            (sqltypes.String(3), sqltypes.VARCHAR(3)),
+            (sqltypes.Text(), sqltypes.TEXT()),
+            (sqltypes.Unicode(), sqltypes.VARCHAR()),
+            (sqltypes.Unicode(1), sqltypes.VARCHAR(1)),
+            (sqltypes.UnicodeText(), sqltypes.TEXT()),
+            (sqltypes.CHAR(3), sqltypes.CHAR(3)),
+            (sqltypes.NUMERIC, sqltypes.NUMERIC()),
+            (sqltypes.NUMERIC(10, 2), sqltypes.NUMERIC(10, 2)),
+            (sqltypes.Numeric, sqltypes.NUMERIC()),
+            (sqltypes.Numeric(10, 2), sqltypes.NUMERIC(10, 2)),
+            (sqltypes.DECIMAL, sqltypes.DECIMAL()),
+            (sqltypes.DECIMAL(10, 2), sqltypes.DECIMAL(10, 2)),
+            (sqltypes.INTEGER, sqltypes.INTEGER()),
+            (sqltypes.BIGINT, sqltypes.BIGINT()),
+            (sqltypes.Float, sqltypes.FLOAT()),
+            (sqltypes.TIMESTAMP, sqltypes.TIMESTAMP()),
+            (sqltypes.DATETIME, sqltypes.DATETIME()),
+            (sqltypes.DateTime, sqltypes.DATETIME()),
+            (sqltypes.DateTime(), sqltypes.DATETIME()),
+            (sqltypes.DATE, sqltypes.DATE()),
+            (sqltypes.Date, sqltypes.DATE()),
+            (sqltypes.TIME, sqltypes.TIME()),
+            (sqltypes.Time, sqltypes.TIME()),
+            (sqltypes.BOOLEAN, sqltypes.BOOLEAN()),
+            (sqltypes.Boolean, sqltypes.BOOLEAN()),
+        ]
+
+    def _unsupported_args_fixture(self):
+        return [
+            ("INTEGER(5)", sqltypes.INTEGER(),),
+            ("DATETIME(6, 12)", sqltypes.DATETIME())
+        ]
+
+    def _type_affinity_fixture(self):
+        return [
+            ("LONGTEXT", sqltypes.TEXT()),
+            ("TINYINT", sqltypes.INTEGER()),
+            ("MEDIUMINT", sqltypes.INTEGER()),
+            ("INT2", sqltypes.INTEGER()),
+            ("UNSIGNED BIG INT", sqltypes.INTEGER()),
+            ("INT8", sqltypes.INTEGER()),
+            ("CHARACTER(20)", sqltypes.TEXT()),
+            ("CLOB", sqltypes.TEXT()),
+            ("CLOBBER", sqltypes.TEXT()),
+            ("VARYING CHARACTER(70)", sqltypes.TEXT()),
+            ("NATIVE CHARACTER(70)", sqltypes.TEXT()),
+            ("BLOB", sqltypes.BLOB()),
+            ("BLOBBER", sqltypes.NullType()),
+            ("DOUBLE PRECISION", sqltypes.REAL()),
+            ("FLOATY", sqltypes.REAL()),
+            ("NOTHING WE KNOW", sqltypes.NUMERIC()),
+        ]
+
+    def _fixture_as_string(self, fixture):
+        for from_, to_ in fixture:
+            if isinstance(from_, sqltypes.TypeEngine):
+                from_ = str(from_.compile())
+            elif isinstance(from_, type):
+                from_ = str(from_().compile())
+            yield from_, to_
+
+    def _test_lookup_direct(self, fixture, warnings=False):
+        dialect = sqlite.dialect()
+        for from_, to_ in self._fixture_as_string(fixture):
+            if warnings:
+                def go():
+                    return dialect._resolve_type_affinity(from_)
+                final_type = testing.assert_warnings(go,
+                                ["Could not instantiate"], regex=True)
+            else:
+                final_type = dialect._resolve_type_affinity(from_)
+            expected_type = type(to_)
+            is_(type(final_type), expected_type)
+
+    def _test_round_trip(self, fixture, warnings=False):
+        from sqlalchemy import inspect
+        conn = testing.db.connect()
+        for from_, to_ in self._fixture_as_string(fixture):
+            inspector = inspect(conn)
+            conn.execute("CREATE TABLE foo (data %s)" % from_)
+            try:
+                if warnings:
+                    def go():
+                        return inspector.get_columns("foo")[0]
+                    col_info = testing.assert_warnings(go,
+                                    ["Could not instantiate"], regex=True)
+                else:
+                    col_info = inspector.get_columns("foo")[0]
+                expected_type = type(to_)
+                is_(type(col_info['type']), expected_type)
+
+                # test args
+                for attr in ("scale", "precision", "length"):
+                    if getattr(to_, attr, None) is not None:
+                        eq_(
+                            getattr(col_info['type'], attr),
+                            getattr(to_, attr, None)
+                        )
+            finally:
+                conn.execute("DROP TABLE foo")
+
+    def test_lookup_direct_lookup(self):
+        self._test_lookup_direct(self._fixed_lookup_fixture())
+
+    def test_lookup_direct_unsupported_args(self):
+        self._test_lookup_direct(self._unsupported_args_fixture(), warnings=True)
+
+    def test_lookup_direct_type_affinity(self):
+        self._test_lookup_direct(self._type_affinity_fixture())
+
+    def test_round_trip_direct_lookup(self):
+        self._test_round_trip(self._fixed_lookup_fixture())
+
+    def test_round_trip_direct_unsupported_args(self):
+        self._test_round_trip(self._unsupported_args_fixture(), warnings=True)
+
+    def test_round_trip_direct_type_affinity(self):
+        self._test_round_trip(self._type_affinity_fixture())
 
-    def test_correct_reflection_with_affinity(self):
-        "Test that coltypes are detected correctly from affinity rules."
-        meta = MetaData()
-        b = Table('b', meta, autoload=True, autoload_with=testing.db)
-        typecounts = Counter(type(col.type) for col in b.columns)
-        eq_(typecounts[sqltypes.INTEGER], len(self.example_typenames_integer))
-        eq_(typecounts[sqltypes.TEXT], len(self.example_typenames_text))
-        eq_(typecounts[sqltypes.NullType], len(self.example_typenames_none))
-        eq_(typecounts[sqltypes.REAL], len(self.example_typenames_real))