]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Allow bind processors to work with expanding IN
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 23 Feb 2018 00:47:24 +0000 (19:47 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 23 Feb 2018 18:10:08 +0000 (13:10 -0500)
Fixed bug in new "expanding IN parameter" feature where the bind parameter
processors for values wasn't working at all, tests failed to cover this
pretty basic case which includes that ENUM values weren't working.

Change-Id: I8e2420d7229a3e253e43b5227ebb98f9fe0bd14a
Fixes: #4198
doc/build/changelog/unreleased_12/4198.rst [new file with mode: 0644]
lib/sqlalchemy/engine/default.py
test/sql/test_types.py

diff --git a/doc/build/changelog/unreleased_12/4198.rst b/doc/build/changelog/unreleased_12/4198.rst
new file mode 100644 (file)
index 0000000..e5cc3cf
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 4198
+
+    Fixed bug in new "expanding IN parameter" feature where the bind parameter
+    processors for values wasn't working at all, tests failed to cover this
+    pretty basic case which includes that ENUM values weren't working.
index ed2ed050916828067c40b75680c1f6546ab9d06f..5806fe2a93887031dac6e133d22144948ca6c17c 100644 (file)
@@ -766,7 +766,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
                 compiled_params.update(to_update)
                 processors.update(
                     (key, processors[name])
-                    for key in to_update if name in processors
+                    for key, value in to_update if name in processors
                 )
                 if compiled.positional:
                     positiontup.extend(name for name, value in to_update)
index 002094f7bb416d6cbcd867bbcc26552b56fcceb6..5fbfcd2d82121a011aedaa4d3035121d49f5aea7 100644 (file)
@@ -299,21 +299,137 @@ class PickleTypesTest(fixtures.TestBase):
                 loads(dumps(meta))
 
 
-class UserDefinedTest(fixtures.TablesTest, AssertsCompiledSQL):
+class _UserDefinedTypeFixture(object):
+    @classmethod
+    def define_tables(cls, metadata):
+        class MyType(types.UserDefinedType):
 
-    """tests user-defined types."""
+            def get_col_spec(self):
+                return "VARCHAR(100)"
+
+            def bind_processor(self, dialect):
+                def process(value):
+                    return "BIND_IN" + value
+                return process
+
+            def result_processor(self, dialect, coltype):
+                def process(value):
+                    return value + "BIND_OUT"
+                return process
+
+            def adapt(self, typeobj):
+                return typeobj()
+
+        class MyDecoratedType(types.TypeDecorator):
+            impl = String
+
+            def bind_processor(self, dialect):
+                impl_processor = super(MyDecoratedType, self).\
+                    bind_processor(dialect) or (lambda value: value)
+
+                def process(value):
+                    return "BIND_IN" + impl_processor(value)
+                return process
+
+            def result_processor(self, dialect, coltype):
+                impl_processor = super(MyDecoratedType, self).\
+                    result_processor(dialect, coltype) or (lambda value: value)
+
+                def process(value):
+                    return impl_processor(value) + "BIND_OUT"
+                return process
+
+            def copy(self):
+                return MyDecoratedType()
+
+        class MyNewUnicodeType(types.TypeDecorator):
+            impl = Unicode
+
+            def process_bind_param(self, value, dialect):
+                return "BIND_IN" + value
+
+            def process_result_value(self, value, dialect):
+                return value + "BIND_OUT"
+
+            def copy(self):
+                return MyNewUnicodeType(self.impl.length)
+
+        class MyNewIntType(types.TypeDecorator):
+            impl = Integer
+
+            def process_bind_param(self, value, dialect):
+                return value * 10
+
+            def process_result_value(self, value, dialect):
+                return value * 10
+
+            def copy(self):
+                return MyNewIntType()
+
+        class MyNewIntSubClass(MyNewIntType):
+
+            def process_result_value(self, value, dialect):
+                return value * 15
+
+            def copy(self):
+                return MyNewIntSubClass()
+
+        class MyUnicodeType(types.TypeDecorator):
+            impl = Unicode
+
+            def bind_processor(self, dialect):
+                impl_processor = super(MyUnicodeType, self).\
+                    bind_processor(dialect) or (lambda value: value)
+
+                def process(value):
+                    return "BIND_IN" + impl_processor(value)
+                return process
+
+            def result_processor(self, dialect, coltype):
+                impl_processor = super(MyUnicodeType, self).\
+                    result_processor(dialect, coltype) or (lambda value: value)
+
+                def process(value):
+                    return impl_processor(value) + "BIND_OUT"
+                return process
+
+            def copy(self):
+                return MyUnicodeType(self.impl.length)
+
+        Table(
+            'users', metadata,
+            Column('user_id', Integer, primary_key=True),
+            # totall custom type
+            Column('goofy', MyType, nullable=False),
+
+            # decorated type with an argument, so its a String
+            Column('goofy2', MyDecoratedType(50), nullable=False),
+
+            Column('goofy4', MyUnicodeType(50), nullable=False),
+            Column('goofy7', MyNewUnicodeType(50), nullable=False),
+            Column('goofy8', MyNewIntType, nullable=False),
+            Column('goofy9', MyNewIntSubClass, nullable=False),
+        )
+
+class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest):
+    __backend__ = True
+
+    def _data_fixture(self):
+        users = self.tables.users
+        with testing.db.connect() as conn:
+            conn.execute(users.insert(), dict(
+                user_id=2, goofy='jack', goofy2='jack', goofy4=util.u('jack'),
+                goofy7=util.u('jack'), goofy8=12, goofy9=12))
+            conn.execute(users.insert(), dict(
+                user_id=3, goofy='lala', goofy2='lala', goofy4=util.u('lala'),
+                goofy7=util.u('lala'), goofy8=15, goofy9=15))
+            conn.execute(users.insert(), dict(
+                user_id=4, goofy='fred', goofy2='fred', goofy4=util.u('fred'),
+                goofy7=util.u('fred'), goofy8=9, goofy9=9))
 
     def test_processing(self):
         users = self.tables.users
-        users.insert().execute(
-            user_id=2, goofy='jack', goofy2='jack', goofy4=util.u('jack'),
-            goofy7=util.u('jack'), goofy8=12, goofy9=12)
-        users.insert().execute(
-            user_id=3, goofy='lala', goofy2='lala', goofy4=util.u('lala'),
-            goofy7=util.u('lala'), goofy8=15, goofy9=15)
-        users.insert().execute(
-            user_id=4, goofy='fred', goofy2='fred', goofy4=util.u('fred'),
-            goofy7=util.u('fred'), goofy8=9, goofy9=9)
+        self._data_fixture()
 
         result = users.select().order_by(users.c.user_id).execute().fetchall()
         for assertstr, assertint, assertint2, row in zip(
@@ -331,6 +447,36 @@ class UserDefinedTest(fixtures.TablesTest, AssertsCompiledSQL):
             for col in row[3], row[4]:
                 assert isinstance(col, util.text_type)
 
+    def test_plain_in(self):
+        users = self.tables.users
+        self._data_fixture()
+
+        stmt = select([users.c.user_id, users.c.goofy8]).where(
+                users.c.goofy8.in_([15, 9])
+            ).order_by(users.c.user_id)
+        result = testing.db.execute(stmt, {"goofy": [15, 9]})
+        eq_(result.fetchall(), [(3, 1500), (4, 900)])
+
+    def test_expanding_in(self):
+        users = self.tables.users
+        self._data_fixture()
+
+        stmt = select([users.c.user_id, users.c.goofy8]).where(
+                users.c.goofy8.in_(bindparam("goofy", expanding=True))
+            ).order_by(users.c.user_id)
+        result = testing.db.execute(stmt, {"goofy": [15, 9]})
+        eq_(result.fetchall(), [(3, 1500), (4, 900)])
+
+
+class UserDefinedTest(
+        _UserDefinedTypeFixture, fixtures.TablesTest, AssertsCompiledSQL):
+
+    run_create_tables = None
+    run_inserts = None
+    run_deletes = None
+
+    """tests user-defined types."""
+
     def test_typedecorator_literal_render(self):
         class MyType(types.TypeDecorator):
             impl = String
@@ -500,117 +646,6 @@ class UserDefinedTest(fixtures.TablesTest, AssertsCompiledSQL):
         eq_(a.foo, 'foo')
         eq_(a.dialect_specific_args['bar'], 'bar')
 
-    @classmethod
-    def define_tables(cls, metadata):
-        class MyType(types.UserDefinedType):
-
-            def get_col_spec(self):
-                return "VARCHAR(100)"
-
-            def bind_processor(self, dialect):
-                def process(value):
-                    return "BIND_IN" + value
-                return process
-
-            def result_processor(self, dialect, coltype):
-                def process(value):
-                    return value + "BIND_OUT"
-                return process
-
-            def adapt(self, typeobj):
-                return typeobj()
-
-        class MyDecoratedType(types.TypeDecorator):
-            impl = String
-
-            def bind_processor(self, dialect):
-                impl_processor = super(MyDecoratedType, self).\
-                    bind_processor(dialect) or (lambda value: value)
-
-                def process(value):
-                    return "BIND_IN" + impl_processor(value)
-                return process
-
-            def result_processor(self, dialect, coltype):
-                impl_processor = super(MyDecoratedType, self).\
-                    result_processor(dialect, coltype) or (lambda value: value)
-
-                def process(value):
-                    return impl_processor(value) + "BIND_OUT"
-                return process
-
-            def copy(self):
-                return MyDecoratedType()
-
-        class MyNewUnicodeType(types.TypeDecorator):
-            impl = Unicode
-
-            def process_bind_param(self, value, dialect):
-                return "BIND_IN" + value
-
-            def process_result_value(self, value, dialect):
-                return value + "BIND_OUT"
-
-            def copy(self):
-                return MyNewUnicodeType(self.impl.length)
-
-        class MyNewIntType(types.TypeDecorator):
-            impl = Integer
-
-            def process_bind_param(self, value, dialect):
-                return value * 10
-
-            def process_result_value(self, value, dialect):
-                return value * 10
-
-            def copy(self):
-                return MyNewIntType()
-
-        class MyNewIntSubClass(MyNewIntType):
-
-            def process_result_value(self, value, dialect):
-                return value * 15
-
-            def copy(self):
-                return MyNewIntSubClass()
-
-        class MyUnicodeType(types.TypeDecorator):
-            impl = Unicode
-
-            def bind_processor(self, dialect):
-                impl_processor = super(MyUnicodeType, self).\
-                    bind_processor(dialect) or (lambda value: value)
-
-                def process(value):
-                    return "BIND_IN" + impl_processor(value)
-                return process
-
-            def result_processor(self, dialect, coltype):
-                impl_processor = super(MyUnicodeType, self).\
-                    result_processor(dialect, coltype) or (lambda value: value)
-
-                def process(value):
-                    return impl_processor(value) + "BIND_OUT"
-                return process
-
-            def copy(self):
-                return MyUnicodeType(self.impl.length)
-
-        Table(
-            'users', metadata,
-            Column('user_id', Integer, primary_key=True),
-            # totall custom type
-            Column('goofy', MyType, nullable=False),
-
-            # decorated type with an argument, so its a String
-            Column('goofy2', MyDecoratedType(50), nullable=False),
-
-            Column('goofy4', MyUnicodeType(50), nullable=False),
-            Column('goofy7', MyNewUnicodeType(50), nullable=False),
-            Column('goofy8', MyNewIntType, nullable=False),
-            Column('goofy9', MyNewIntSubClass, nullable=False),
-        )
-
 
 class TypeCoerceCastTest(fixtures.TablesTest):
 
@@ -1565,6 +1600,34 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest):
             ]
         )
 
+    def test_pep435_enum_expanding_in(self):
+        stdlib_enum_table_custom_values =\
+            self.tables['stdlib_enum_table2']
+
+        stdlib_enum_table_custom_values.insert().execute([
+            {'id': 1, 'someotherenum': self.SomeOtherEnum.one},
+            {'id': 2, 'someotherenum': self.SomeOtherEnum.two},
+            {'id': 3, 'someotherenum': self.SomeOtherEnum.three}
+        ])
+
+        stmt = stdlib_enum_table_custom_values.select().where(
+            stdlib_enum_table_custom_values.c.someotherenum.in_(
+                bindparam("member", expanding=True)
+            )
+        ).order_by(stdlib_enum_table_custom_values.c.id)
+        eq_(
+            testing.db.execute(
+                stmt,
+                {"member": [
+                    self.SomeOtherEnum.one,
+                    self.SomeOtherEnum.three]}
+            ).fetchall(),
+            [
+                (1, self.SomeOtherEnum.one),
+                (3, self.SomeOtherEnum.three)
+            ]
+        )
+
     def test_adapt(self):
         from sqlalchemy.dialects.postgresql import ENUM
         e1 = Enum('one', 'two', 'three', native_enum=False)