From a2468c8a31c8308cdb5740f2401e9dedd003836e Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 17 Aug 2012 18:35:25 -0400 Subject: [PATCH] - [feature] To complement [ticket:2547], types can now provide "bind expressions" and "column expressions" which allow compile-time injection of SQL expressions into statements on a per-column or per-bind level. This is to suit the use case of a type which needs to augment bind- and result- behavior at the SQL level, as opposed to in the Python level. Allows for schemes like transparent encryption/ decryption, usage of Postgis functions, etc. [ticket:1534] - update postgis example fully. - still need to repair the result map propagation here to be transparent for cases like "labeled column". --- CHANGES | 25 +++ examples/postgis/postgis.py | 31 ++-- lib/sqlalchemy/dialects/access/base.py | 7 - lib/sqlalchemy/dialects/firebird/base.py | 8 +- lib/sqlalchemy/dialects/mssql/base.py | 31 ++-- lib/sqlalchemy/dialects/postgresql/base.py | 5 +- lib/sqlalchemy/sql/compiler.py | 136 +++++++++----- lib/sqlalchemy/types.py | 36 +++- test/aaa_profiling/test_compiler.py | 22 ++- test/lib/profiles.txt | 36 ++-- test/sql/test_returning.py | 2 +- test/sql/test_type_expressions.py | 196 +++++++++++++++++++++ 12 files changed, 404 insertions(+), 131 deletions(-) create mode 100644 test/sql/test_type_expressions.py diff --git a/CHANGES b/CHANGES index 1ff19ce4a7..babb2ba33d 100644 --- a/CHANGES +++ b/CHANGES @@ -309,6 +309,31 @@ underneath "0.7.xx". replaced by inserted_primary_key. - sql + - [feature] Major rework of operator system + in Core, to allow redefinition of existing + operators as well as addition of new operators + at the type level. New types can be created + from existing ones which add or redefine + operations that are exported out to column + expressions, in a similar manner to how the + ORM has allowed comparator_factory. The new + architecture moves this capability into the + Core so that it is consistently usable in + all cases, propagating cleanly using existing + type propagation behavior. [ticket:2547] + + - [feature] To complement [ticket:2547], types + can now provide "bind expressions" and + "column expressions" which allow compile-time + injection of SQL expressions into statements + on a per-column or per-bind level. This is + to suit the use case of a type which needs + to augment bind- and result- behavior at the + SQL level, as opposed to in the Python level. + Allows for schemes like transparent encryption/ + decryption, usage of Postgis functions, etc. + [ticket:1534] + - [feature] Revised the rules used to determine the operator precedence for the user-defined operator, i.e. that granted using the ``op()`` diff --git a/examples/postgis/postgis.py b/examples/postgis/postgis.py index 1239d66a42..77fcacba17 100644 --- a/examples/postgis/postgis.py +++ b/examples/postgis/postgis.py @@ -38,7 +38,6 @@ class TextualGisElement(GisElement, expression.Function): """ def __init__(self, desc, srid=-1): - assert isinstance(desc, basestring) self.desc = desc expression.Function.__init__(self, "ST_GeomFromText", desc, srid) @@ -63,21 +62,30 @@ class Geometry(UserDefinedType): # override the __eq__() operator def __eq__(self, other): - return self.op('~=')(_to_postgis(other)) + return self.op('~=')(other) # add a custom operator def intersects(self, other): - return self.op('&&')(_to_postgis(other)) + return self.op('&&')(other) # any number of GIS operators can be overridden/added here # using the techniques above. + def _coerce_compared_value(self, op, value): + return self + def get_col_spec(self): return self.name + def bind_expression(self, bindvalue): + return TextualGisElement(bindvalue) + + def column_expression(self, col): + return func.ST_AsText(col, type_=self) + def bind_processor(self, dialect): def process(value): - if value is not None: + if isinstance(value, GisElement): return value.desc else: return value @@ -165,8 +173,6 @@ def setup_ddl_events(): table.columns = table.info.pop('_saved_columns') setup_ddl_events() -# ORM integration - def _to_postgis(value): """Interpret a value as a GIS-compatible construct. @@ -188,17 +194,6 @@ def _to_postgis(value): else: raise Exception("Invalid type") -# without importing "orm", the "attribute_instrument" -# event isn't even set up. -from sqlalchemy import orm - -@event.listens_for(type, "attribute_instrument") -def attribute_instrument(cls, key, inst): - type_ = getattr(inst, "type", None) - if isinstance(type_, Geometry): - @event.listens_for(inst, "set", retval=True) - def set_value(state, value, oldvalue, initiator): - return _to_postgis(value) # illustrate usage @@ -245,7 +240,7 @@ if __name__ == '__main__': session.commit() # after flush and/or commit, all the TextualGisElements become PersistentGisElements. - assert str(r.road_geom) == "01020000000200000000000000B832084100000000E813104100000000283208410000000088601041" + assert str(r.road_geom) == "LINESTRING(198231 263418,198213 268322)" r1 = session.query(Road).filter(Road.road_name=='Graeme Ave').one() diff --git a/lib/sqlalchemy/dialects/access/base.py b/lib/sqlalchemy/dialects/access/base.py index f107c9c8c1..1f119098b6 100644 --- a/lib/sqlalchemy/dialects/access/base.py +++ b/lib/sqlalchemy/dialects/access/base.py @@ -361,13 +361,6 @@ class AccessCompiler(compiler.SQLCompiler): """Access uses "mod" instead of "%" """ return binary.operator == '%' and 'mod' or binary.operator - def label_select_column(self, select, column, asfrom): - if isinstance(column, expression.Function): - return column.label() - else: - return super(AccessCompiler, self).\ - label_select_column(select, column, asfrom) - function_rewrites = {'current_date': 'now', 'current_timestamp': 'now', 'length': 'len', diff --git a/lib/sqlalchemy/dialects/firebird/base.py b/lib/sqlalchemy/dialects/firebird/base.py index ad6dcee54d..f7877a901f 100644 --- a/lib/sqlalchemy/dialects/firebird/base.py +++ b/lib/sqlalchemy/dialects/firebird/base.py @@ -278,15 +278,11 @@ class FBCompiler(sql.compiler.SQLCompiler): return "" def returning_clause(self, stmt, returning_cols): - columns = [ - self.process( - self.label_select_column(None, c, asfrom=False), - within_columns_clause=True, - result_map=self.result_map - ) + self._label_select_column(None, c, True, False, {}) for c in expression._select_iterables(returning_cols) ] + return 'RETURNING ' + ', '.join(columns) diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 83f6346a78..0dd610788a 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -849,7 +849,7 @@ class MSSQLCompiler(compiler.SQLCompiler): return ("ROLLBACK TRANSACTION %s" % self.preparer.format_savepoint(savepoint_stmt)) - def visit_column(self, column, result_map=None, **kwargs): + def visit_column(self, column, add_to_result_map=None, **kwargs): if column.table is not None and \ (not self.isupdate and not self.isdelete) or self.is_subquery(): # translate for schema-qualified table aliases @@ -858,20 +858,19 @@ class MSSQLCompiler(compiler.SQLCompiler): converted = expression._corresponding_column_or_error( t, column) - if result_map is not None: - result_map[column.name + if add_to_result_map is not None: + self.result_map[column.name if self.dialect.case_sensitive else column.name.lower()] = \ - (column.name, (column, ), + (column.name, (column, ) + add_to_result_map, column.type) return super(MSSQLCompiler, self).\ visit_column(converted, result_map=None, **kwargs) - return super(MSSQLCompiler, self).visit_column(column, - result_map=result_map, - **kwargs) + return super(MSSQLCompiler, self).visit_column( + column, add_to_result_map=add_to_result_map, **kwargs) def visit_binary(self, binary, **kwargs): """Move bind parameters to the right-hand side of an operator, where @@ -898,21 +897,13 @@ class MSSQLCompiler(compiler.SQLCompiler): target = stmt.table.alias("deleted") adapter = sql_util.ClauseAdapter(target) - def col_label(col): - adapted = adapter.traverse(col) - if isinstance(col, expression.Label): - return adapted.label(c.key) - else: - return self.label_select_column(None, adapted, asfrom=False) columns = [ - self.process( - col_label(c), - within_columns_clause=True, - result_map=self.result_map - ) - for c in expression._select_iterables(returning_cols) - ] + self._label_select_column(None, adapter.traverse(c), + True, False, {}) + for c in expression._select_iterables(returning_cols) + ] + return 'OUTPUT ' + ', '.join(columns) def get_cte_preamble(self, recursive): diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 36da14d333..d159649e06 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -691,10 +691,7 @@ class PGCompiler(compiler.SQLCompiler): def returning_clause(self, stmt, returning_cols): columns = [ - self.process( - self.label_select_column(None, c, asfrom=False), - within_columns_clause=True, - result_map=self.result_map) + self._label_select_column(None, c, True, False, {}) for c in expression._select_iterables(returning_cols) ] diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 300cdb6b4d..f975225d60 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -172,6 +172,7 @@ class _CompileLabel(visitors.Visitable): def quote(self): return self.element.quote + class SQLCompiler(engine.Compiled): """Default implementation of Compiled. @@ -373,7 +374,8 @@ class SQLCompiler(engine.Compiled): def visit_grouping(self, grouping, asfrom=False, **kwargs): return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")" - def visit_label(self, label, result_map=None, + def visit_label(self, label, + add_to_result_map = None, within_label_clause=False, within_columns_clause=False, **kw): # only render labels within the columns clause @@ -385,14 +387,18 @@ class SQLCompiler(engine.Compiled): else: labelname = label.name - if result_map is not None: - result_map[labelname + if add_to_result_map is not None: + self.result_map[ + labelname if self.dialect.case_sensitive - else labelname.lower()] = ( - label.name, - (label, label.element, labelname, ) + - label._alt_names, - label.type) + else labelname.lower() + ] = ( + label.name, + (label, label.element, labelname, ) + + label._alt_names + + add_to_result_map, + label.type, + ) return label.element._compiler_dispatch(self, within_columns_clause=True, @@ -405,7 +411,7 @@ class SQLCompiler(engine.Compiled): within_columns_clause=False, **kw) - def visit_column(self, column, result_map=None, **kwargs): + def visit_column(self, column, add_to_result_map=None, **kwargs): name = orig_name = column.name if name is None: raise exc.CompileError("Cannot compile Column object until " @@ -415,12 +421,16 @@ class SQLCompiler(engine.Compiled): if not is_literal and isinstance(name, sql._truncated_label): name = self._truncated_identifier("colident", name) - if result_map is not None: - result_map[name + if add_to_result_map is not None: + self.result_map[ + name if self.dialect.case_sensitive - else name.lower()] = (orig_name, - (column, name, column.key), - column.type) + else name.lower() + ] = ( + orig_name, + (column, name, column.key) + add_to_result_map, + column.type + ) if is_literal: name = self.escape_literal_column(name) @@ -527,7 +537,7 @@ class SQLCompiler(engine.Compiled): cast.typeclause._compiler_dispatch(self, **kwargs)) def visit_over(self, over, **kwargs): - x ="%s OVER (" % over.func._compiler_dispatch(self, **kwargs) + x = "%s OVER (" % over.func._compiler_dispatch(self, **kwargs) if over.partition_by is not None: x += "PARTITION BY %s" % \ over.partition_by._compiler_dispatch(self, **kwargs) @@ -544,12 +554,13 @@ class SQLCompiler(engine.Compiled): return "EXTRACT(%s FROM %s)" % (field, extract.expr._compiler_dispatch(self, **kwargs)) - def visit_function(self, func, result_map=None, **kwargs): - if result_map is not None: - result_map[func.name + def visit_function(self, func, add_to_result_map=None, **kwargs): + if add_to_result_map is not None: + self.result_map[ + func.name if self.dialect.case_sensitive - else func.name.lower()] = \ - (func.name, None, func.type) + else func.name.lower() + ] = (func.name, add_to_result_map, func.type) disp = getattr(self, "visit_%s_func" % func.name.lower(), None) if disp: @@ -557,14 +568,15 @@ class SQLCompiler(engine.Compiled): else: name = FUNCTIONS.get(func.__class__, func.name + "%(expr)s") return ".".join(list(func.packagenames) + [name]) % \ - {'expr':self.function_argspec(func, **kwargs)} + {'expr': self.function_argspec(func, **kwargs)} def visit_next_value_func(self, next_value, **kw): return self.visit_sequence(next_value.sequence) def visit_sequence(self, sequence): raise NotImplementedError( - "Dialect '%s' does not support sequence increments." % self.dialect.name + "Dialect '%s' does not support sequence increments." % + self.dialect.name ) def function_argspec(self, func, **kwargs): @@ -704,7 +716,14 @@ class SQLCompiler(engine.Compiled): def visit_bindparam(self, bindparam, within_columns_clause=False, - literal_binds=False, **kwargs): + literal_binds=False, + skip_bind_expression=False, + **kwargs): + + if not skip_bind_expression and bindparam.type._has_bind_expression: + bind_expression = bindparam.type.bind_expression(bindparam) + return self.process(bind_expression, + skip_bind_expression=True) if literal_binds or \ (within_columns_clause and \ @@ -912,17 +931,31 @@ class SQLCompiler(engine.Compiled): else: return alias.original._compiler_dispatch(self, **kwargs) - def label_select_column(self, select, column, asfrom): - """label columns present in a select().""" + def _label_select_column(self, select, column, populate_result_map, + asfrom, column_clause_args): + """produce labeled columns present in a select().""" + + if column.type._has_column_expression: + col_expr = column.type.column_expression(column) + if populate_result_map: + add_to_result_map = (column, ) + else: + add_to_result_map = None + else: + col_expr = column + if populate_result_map: + add_to_result_map = () + else: + add_to_result_map = None - if isinstance(column, sql.Label): - return column + if isinstance(col_expr, sql.Label): + result_expr = col_expr elif select is not None and \ select.use_labels and \ column._label: - return _CompileLabel( - column, + result_expr = _CompileLabel( + col_expr, column._label, alt_names=(column._key_label, ) ) @@ -933,15 +966,25 @@ class SQLCompiler(engine.Compiled): not column.is_literal and \ column.table is not None and \ not isinstance(column.table, sql.Select): - return _CompileLabel(column, sql._as_truncated(column.name), - alt_names=(column.key,)) + result_expr = _CompileLabel(col_expr, + sql._as_truncated(column.name), + alt_names=(column.key,)) elif not isinstance(column, (sql.UnaryExpression, sql.TextClause)) \ and (not hasattr(column, 'name') or \ isinstance(column, sql.Function)): - return _CompileLabel(column, column.anon_label) + result_expr = _CompileLabel(col_expr, column.anon_label) + elif col_expr is not column: + result_expr = _CompileLabel(col_expr, column.anon_label) else: - return column + result_expr = col_expr + + return result_expr._compiler_dispatch( + self, within_columns_clause=True, + add_to_result_map=add_to_result_map, + **column_clause_args + ) + def format_from_hint_text(self, sqltext, table, hint, iscrud): hinttext = self.get_from_hint_text(table, hint) @@ -976,24 +1019,21 @@ class SQLCompiler(engine.Compiled): # to outermost if existingfroms: correlate_froms = # correlate_froms.union(existingfroms) - self.stack.append({'from': correlate_froms, 'iswrapper' - : iswrapper}) + self.stack.append({'from': correlate_froms, + 'iswrapper': iswrapper}) - if compound_index==1 and not entry or entry.get('iswrapper', False): - column_clause_args = {'result_map':self.result_map, - 'positional_names':positional_names} - else: - column_clause_args = {'positional_names':positional_names} + populate_result_map = compound_index == 1 and not entry or \ + entry.get('iswrapper', False) + column_clause_args = {'positional_names': positional_names} # the actual list of columns to print in the SELECT column list. inner_columns = [ c for c in [ - self.label_select_column(select, co, asfrom=asfrom).\ - _compiler_dispatch(self, - within_columns_clause=True, - **column_clause_args) - for co in util.unique_list(select.inner_columns) - ] + self._label_select_column(select, column, + populate_result_map, asfrom, + column_clause_args) + for column in util.unique_list(select.inner_columns) + ] if c is not None ] @@ -1059,8 +1099,8 @@ class SQLCompiler(engine.Compiled): text += self.for_update_clause(select) if self.ctes and \ - compound_index==1 and not entry: - text = self._render_cte_clause() + text + compound_index == 1 and not entry: + text = self._render_cte_clause() + text self.stack.pop(-1) diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index bbeebf5d36..ee262b56b3 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -25,7 +25,7 @@ import codecs from . import exc, schema, util, processors, events, event from .sql import operators -from .sql.expression import _DefaultColumnComparator +from .sql.expression import _DefaultColumnComparator, column, bindparam from .util import pickle from .util.compat import decimal from .sql.visitors import Visitable @@ -163,6 +163,40 @@ class TypeEngine(AbstractType): """ return None + def column_expression(self, colexpr): + """Given a SELECT column expression, return a wrapping SQL expression.""" + + return None + + @util.memoized_property + def _has_column_expression(self): + """memoized boolean, check if column_expression is implemented.""" + return self.column_expression(column('x')) is not None + + def bind_expression(self, bindvalue): + """"Given a bind value (i.e. a :class:`.BindParameter` instance), + return a SQL expression in its place. + + This is typically a SQL function that wraps the existing value + in a bind. It is used for special data types that require + literals being wrapped in some special database function in all + cases, such as Postgis GEOMETRY types. + + The method is evaluated at statement compile time, as opposed + to statement construction time. + + Note that this method, when implemented, should always return + the exact same structure, without any conditional logic, as it + will be used during executemany() calls as well. + + """ + return None + + @util.memoized_property + def _has_bind_expression(self): + """memoized boolean, check if bind_expression is implemented.""" + return self.bind_expression(bindparam('x')) is not None + def compare_values(self, x, y): """Compare two values for equality.""" diff --git a/test/aaa_profiling/test_compiler.py b/test/aaa_profiling/test_compiler.py index 129d0bf06c..deff49f0fc 100644 --- a/test/aaa_profiling/test_compiler.py +++ b/test/aaa_profiling/test_compiler.py @@ -34,20 +34,30 @@ class CompileTest(fixtures.TestBase, AssertsExecutionResults): cls.dialect = default.DefaultDialect() - @profiling.function_call_count(62) + @profiling.function_call_count() def test_insert(self): t1.insert().compile(dialect=self.dialect) - @profiling.function_call_count(56) + @profiling.function_call_count() def test_update(self): t1.update().compile(dialect=self.dialect) - @profiling.function_call_count(110) def test_update_whereclause(self): - t1.update().where(t1.c.c2==12).compile(dialect=self.dialect) + t1.update().where(t1.c.c2 == 12).compile(dialect=self.dialect) + + @profiling.function_call_count() + def go(): + t1.update().where(t1.c.c2 == 12).compile(dialect=self.dialect) + go() - @profiling.function_call_count(139) def test_select(self): - s = select([t1], t1.c.c2==t2.c.c1) + # give some of the cached type values + # a chance to warm up + s = select([t1], t1.c.c2 == t2.c.c1) s.compile(dialect=self.dialect) + @profiling.function_call_count() + def go(): + s = select([t1], t1.c.c2 == t2.c.c1) + s.compile(dialect=self.dialect) + go() \ No newline at end of file diff --git a/test/lib/profiles.txt b/test/lib/profiles.txt index edbee75e95..707ecb6219 100644 --- a/test/lib/profiles.txt +++ b/test/lib/profiles.txt @@ -1,15 +1,15 @@ -# /Users/classic/dev/sqla_comparators/./test/lib/profiles.txt +# /Users/classic/dev/sqlalchemy/./test/lib/profiles.txt # This file is written out on a per-environment basis. -# For each test in aaa_profiling, the corresponding function and +# For each test in aaa_profiling, the corresponding function and # environment is located within this file. If it doesn't exist, # the test is skipped. -# If a callcount does exist, it is compared to what we received. +# If a callcount does exist, it is compared to what we received. # assertions are raised if the counts do not match. -# -# To add a new callcount test, apply the function_call_count -# decorator and re-run the tests using the --write-profiles option - +# +# To add a new callcount test, apply the function_call_count +# decorator and re-run the tests using the --write-profiles option - # this file will be rewritten including the new count. -# +# # TEST: test.aaa_profiling.test_compiler.CompileTest.test_insert @@ -26,16 +26,12 @@ test.aaa_profiling.test_compiler.CompileTest.test_insert 3.2_sqlite_pysqlite_noc # TEST: test.aaa_profiling.test_compiler.CompileTest.test_select -test.aaa_profiling.test_compiler.CompileTest.test_select 2.5_sqlite_pysqlite_nocextensions 149 -test.aaa_profiling.test_compiler.CompileTest.test_select 2.6_sqlite_pysqlite_nocextensions 149 -test.aaa_profiling.test_compiler.CompileTest.test_select 2.7_mysql_mysqldb_cextensions 149 -test.aaa_profiling.test_compiler.CompileTest.test_select 2.7_mysql_mysqldb_nocextensions 149 -test.aaa_profiling.test_compiler.CompileTest.test_select 2.7_postgresql_psycopg2_cextensions 149 -test.aaa_profiling.test_compiler.CompileTest.test_select 2.7_postgresql_psycopg2_nocextensions 149 -test.aaa_profiling.test_compiler.CompileTest.test_select 2.7_sqlite_pysqlite_cextensions 149 -test.aaa_profiling.test_compiler.CompileTest.test_select 2.7_sqlite_pysqlite_nocextensions 149 -test.aaa_profiling.test_compiler.CompileTest.test_select 3.2_postgresql_psycopg2_nocextensions 161 -test.aaa_profiling.test_compiler.CompileTest.test_select 3.2_sqlite_pysqlite_nocextensions 161 +test.aaa_profiling.test_compiler.CompileTest.test_select 2.7_postgresql_psycopg2_nocextensions 133 +test.aaa_profiling.test_compiler.CompileTest.test_select 2.7_sqlite_pysqlite_nocextensions 133 + +# TEST: test.aaa_profiling.test_compiler.CompileTest.test_select_second_time + +test.aaa_profiling.test_compiler.CompileTest.test_select_second_time 2.7_sqlite_pysqlite_nocextensions 133 # TEST: test.aaa_profiling.test_compiler.CompileTest.test_update @@ -195,6 +191,8 @@ test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_engine_execute 3.2_ # TEST: test.aaa_profiling.test_resultset.ResultSetTest.test_contains_doesnt_compile +test.aaa_profiling.test_resultset.ResultSetTest.test_contains_doesnt_compile 2.7_postgresql_psycopg2_nocextensions 14 +test.aaa_profiling.test_resultset.ResultSetTest.test_contains_doesnt_compile 2.7_sqlite_pysqlite_nocextensions 14 # TEST: test.aaa_profiling.test_resultset.ResultSetTest.test_string @@ -236,9 +234,7 @@ test.aaa_profiling.test_zoomark.ZooMarkTest.test_profile_2_insert 3.2_postgresql # TEST: test.aaa_profiling.test_zoomark.ZooMarkTest.test_profile_3_properties -test.aaa_profiling.test_zoomark.ZooMarkTest.test_profile_3_properties 2.7_postgresql_psycopg2_cextensions 3093 -test.aaa_profiling.test_zoomark.ZooMarkTest.test_profile_3_properties 2.7_postgresql_psycopg2_nocextensions 3317 -test.aaa_profiling.test_zoomark.ZooMarkTest.test_profile_3_properties 3.2_postgresql_psycopg2_nocextensions 3094 +test.aaa_profiling.test_zoomark.ZooMarkTest.test_profile_3_properties 2.7_postgresql_psycopg2_nocextensions 3526 # TEST: test.aaa_profiling.test_zoomark.ZooMarkTest.test_profile_4_expressions diff --git a/test/sql/test_returning.py b/test/sql/test_returning.py index 4be2d74f79..ac5c69e523 100644 --- a/test/sql/test_returning.py +++ b/test/sql/test_returning.py @@ -143,7 +143,7 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults): eq_(result2.fetchall(), [(2,False),]) class SequenceReturningTest(fixtures.TestBase): - __requires__ = 'returning', + __requires__ = 'returning', 'sequences' def setup(self): meta = MetaData(testing.db) diff --git a/test/sql/test_type_expressions.py b/test/sql/test_type_expressions.py new file mode 100644 index 0000000000..d64ee7f0e4 --- /dev/null +++ b/test/sql/test_type_expressions.py @@ -0,0 +1,196 @@ +from sqlalchemy import Table, Column, String, func, MetaData, select +from test.lib import fixtures, AssertsCompiledSQL, testing +from test.lib.testing import eq_ + +class SelectTest(fixtures.TestBase, AssertsCompiledSQL): + __dialect__ = 'default' + + def _fixture(self): + class MyString(String): + def bind_expression(self, bindvalue): + return func.lower(bindvalue) + + def column_expression(self, col): + return func.lower(col) + + test_table = Table( + 'test_table', + MetaData(), Column('x', String), Column('y', MyString) + ) + return test_table + + def test_select_cols(self): + table = self._fixture() + + self.assert_compile( + select([table]), + "SELECT test_table.x, lower(test_table.y) AS y_1 FROM test_table" + ) + + def test_select_cols_use_labels(self): + table = self._fixture() + + self.assert_compile( + select([table]).apply_labels(), + "SELECT test_table.x AS test_table_x, " + "lower(test_table.y) AS test_table_y FROM test_table" + ) + + def test_select_cols_use_labels_result_map_targeting(self): + table = self._fixture() + + compiled = select([table]).apply_labels().compile() + assert table.c.y in compiled.result_map['test_table_y'][1] + assert table.c.x in compiled.result_map['test_table_x'][1] + + # the lower() function goes into the result_map, we don't really + # need this but it's fine + self.assert_compile( + compiled.result_map['test_table_y'][1][1], + "lower(test_table.y)" + ) + # then the original column gets put in there as well. + # it's not important that it's the last value. + self.assert_compile( + compiled.result_map['test_table_y'][1][-1], + "test_table.y" + ) + + def test_insert_binds(self): + table = self._fixture() + + self.assert_compile( + table.insert(), + "INSERT INTO test_table (x, y) VALUES (:x, lower(:y))" + ) + + self.assert_compile( + table.insert().values(y="hi"), + "INSERT INTO test_table (y) VALUES (lower(:y))" + ) + + def test_select_binds(self): + table = self._fixture() + self.assert_compile( + select([table]).where(table.c.y == "hi"), + "SELECT test_table.x, lower(test_table.y) AS y_1 FROM " + "test_table WHERE test_table.y = lower(:y_2)" + ) + +class RoundTripTest(fixtures.TablesTest): + @classmethod + def define_tables(cls, metadata): + class MyString(String): + def bind_expression(self, bindvalue): + return func.lower(bindvalue) + + def column_expression(self, col): + return func.upper(col) + + Table( + 'test_table', + metadata, + Column('x', String(50)), + Column('y', MyString(50) + ) + ) + + def test_round_trip(self): + testing.db.execute( + self.tables.test_table.insert(), + {"x": "X1", "y": "Y1"}, + {"x": "X2", "y": "Y2"}, + {"x": "X3", "y": "Y3"}, + ) + + # test insert coercion alone + eq_( + testing.db.execute( + "select * from test_table order by y").fetchall(), + [ + ("X1", "y1"), + ("X2", "y2"), + ("X3", "y3"), + ] + ) + + # conversion back to upper + eq_( + testing.db.execute( + select([self.tables.test_table]).\ + order_by(self.tables.test_table.c.y) + ).fetchall(), + [ + ("X1", "Y1"), + ("X2", "Y2"), + ("X3", "Y3"), + ] + ) + + def test_targeting_no_labels(self): + testing.db.execute( + self.tables.test_table.insert(), + {"x": "X1", "y": "Y1"}, + ) + row = testing.db.execute(select([self.tables.test_table])).first() + eq_( + row[self.tables.test_table.c.y], + "Y1" + ) + + def test_targeting_apply_labels(self): + testing.db.execute( + self.tables.test_table.insert(), + {"x": "X1", "y": "Y1"}, + ) + row = testing.db.execute(select([self.tables.test_table]). + apply_labels()).first() + eq_( + row[self.tables.test_table.c.y], + "Y1" + ) + + @testing.fails_if(lambda: True, "still need to propagate " + "result_map more effectively") + def test_targeting_individual_labels(self): + testing.db.execute( + self.tables.test_table.insert(), + {"x": "X1", "y": "Y1"}, + ) + row = testing.db.execute(select([ + self.tables.test_table.c.x.label('xbar'), + self.tables.test_table.c.y.label('ybar') + ])).first() + eq_( + row[self.tables.test_table.c.y], + "Y1" + ) + +class ReturningTest(fixtures.TablesTest): + __requires__ = 'returning', + + @classmethod + def define_tables(cls, metadata): + class MyString(String): + def column_expression(self, col): + return func.lower(col) + + Table( + 'test_table', + metadata, Column('x', String(50)), + Column('y', MyString(50), server_default="YVALUE") + ) + + @testing.provide_metadata + def test_insert_returning(self): + table = self.tables.test_table + result = testing.db.execute( + table.insert().returning(table.c.y), + {"x": "xvalue"} + ) + eq_( + result.first(), + ("yvalue",) + ) + + -- 2.47.3