From 95d2771c6fbe75d0232a092f7ff1d4cbb82ed2ac Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Thu, 19 Jul 2007 07:11:55 +0000 Subject: [PATCH] - all "type" keyword arguments, such as those to bindparam(), column(), Column(), and func.(), renamed to "type_". those objects still name their "type" attribute as "type". - new SQL operator implementation which removes all hardcoded operators from expression structures and moves them into compilation; allows greater flexibility of operator compilation; for example, "+" compiles to "||" when used in a string context, or "concat(a,b)" on MySQL; whereas in a numeric context it compiles to "+". fixes [ticket:475]. - major cruft cleanup in ANSICompiler regarding its processing of update/insert bind parameters. code is actually readable ! - a clause element embedded in an UPDATE, i.e. for a correlated update, uses standard "grouping" rules now to place parenthesis. Doesn't change much, except if you embed a text() clause in there, it will not be automatically parenthesized (place parens in the text() manually). --- CHANGES | 8 + lib/sqlalchemy/ansisql.py | 237 ++++++++--------- lib/sqlalchemy/databases/mysql.py | 15 +- lib/sqlalchemy/databases/oracle.py | 24 +- lib/sqlalchemy/databases/postgres.py | 30 +-- lib/sqlalchemy/databases/sqlite.py | 6 - lib/sqlalchemy/engine/default.py | 4 + lib/sqlalchemy/orm/dependency.py | 2 +- lib/sqlalchemy/orm/interfaces.py | 2 +- lib/sqlalchemy/orm/mapper.py | 18 +- lib/sqlalchemy/orm/properties.py | 4 +- lib/sqlalchemy/orm/strategies.py | 4 +- lib/sqlalchemy/orm/sync.py | 3 +- lib/sqlalchemy/schema.py | 6 +- lib/sqlalchemy/sql.py | 363 ++++++++++++++++----------- lib/sqlalchemy/types.py | 6 +- test/engine/bind.py | 16 +- test/engine/parseconnect.py | 1 - test/orm/inheritance/polymorph2.py | 10 +- test/orm/query.py | 13 +- test/sql/case_statement.py | 8 +- test/sql/defaults.py | 4 +- test/sql/query.py | 6 +- test/sql/select.py | 48 ++-- 24 files changed, 446 insertions(+), 392 deletions(-) diff --git a/CHANGES b/CHANGES index 235e77644a..8aa7f6dd69 100644 --- a/CHANGES +++ b/CHANGES @@ -87,6 +87,9 @@ style of Hibernate - sql + - all "type" keyword arguments, such as those to bindparam(), column(), + Column(), and func.(), renamed to "type_". those objects + still name their "type" attribute as "type". - transactions: - added context manager (with statement) support for transactions - added support for two phase commit, works with mysql and postgres so far. @@ -95,6 +98,11 @@ - MetaData: - DynamicMetaData has been renamed to ThreadLocalMetaData - BoundMetaData has been removed- regular MetaData is equivalent + - new SQL operator implementation which removes all hardcoded operators + from expression structures and moves them into compilation; + allows greater flexibility of operator compilation; for example, "+" + compiles to "||" when used in a string context, or "concat(a,b)" on + MySQL; whereas in a numeric context it compiles to "+". fixes [ticket:475]. - "anonymous" alias and label names are now generated at SQL compilation time in a completely deterministic fashion...no more random hex IDs - significant architectural overhaul to SQL elements (ClauseElement). diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 361fd7b1ea..24ee13e472 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -12,7 +12,7 @@ module. from sqlalchemy import schema, sql, engine, util, sql_util, exceptions from sqlalchemy.engine import default -import string, re, sets, weakref, random +import string, re, sets, random, operator ANSI_FUNCS = sets.ImmutableSet(['CURRENT_DATE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP', 'CURRENT_USER', 'LOCALTIME', 'LOCALTIMESTAMP', @@ -43,6 +43,38 @@ ILLEGAL_INITIAL_CHARACTERS = util.Set(string.digits + '$') BIND_PARAMS = re.compile(r'(?', + operator.ge : '>=', + operator.eq : '=', + sql.ColumnOperators.concat_op : '||', + sql.ColumnOperators.like_op : 'LIKE', + sql.ColumnOperators.notlike_op : 'NOT LIKE', + sql.ColumnOperators.ilike_op : 'ILIKE', + sql.ColumnOperators.notilike_op : 'NOT ILIKE', + sql.ColumnOperators.between_op : 'BETWEEN', + sql.ColumnOperators.in_op : 'IN', + sql.ColumnOperators.notin_op : 'NOT IN', + sql.ColumnOperators.comma_op : ', ', + sql.Operators.from_ : 'FROM', + sql.Operators.as_ : 'AS', + sql.Operators.exists : 'EXISTS', + sql.Operators.is_ : 'IS', + sql.Operators.isnot : 'IS NOT' +} + class ANSIDialect(default.DefaultDialect): def __init__(self, cache_identifiers=True, **kwargs): super(ANSIDialect,self).__init__(**kwargs) @@ -77,6 +109,8 @@ class ANSICompiler(engine.Compiled): __traverse_options__ = {'column_collections':False, 'entry':True} + operators = OPERATORS + def __init__(self, dialect, statement, parameters=None, **kwargs): """Construct a new ``ANSICompiler`` object. @@ -264,7 +298,7 @@ class ANSICompiler(engine.Compiled): if isinstance(label.obj, sql._ColumnClause): self.column_labels[label.obj._label] = labelname self.column_labels[label.name] = labelname - self.strings[label] = self.strings[label.obj] + " AS " + self.preparer.format_label(label, labelname) + self.strings[label] = " ".join([self.strings[label.obj], self.operator_string(sql.ColumnOperators.as_), self.preparer.format_label(label, labelname)]) def visit_column(self, column): # there is actually somewhat of a ruleset when you would *not* necessarily @@ -317,15 +351,15 @@ class ANSICompiler(engine.Compiled): def visit_null(self, null): self.strings[null] = 'NULL' - def visit_clauselist(self, list): - sep = list.operator - if sep == ',': - sep = ', ' - elif sep is None or sep == " ": + def visit_clauselist(self, clauselist): + sep = clauselist.operator + if sep is None: sep = " " + elif sep == sql.ColumnOperators.comma_op: + sep = ', ' else: - sep = " " + sep + " " - self.strings[list] = string.join([s for s in [self.strings[c] for c in list.clauses] if s is not None], sep) + sep = " " + self.operator_string(clauselist.operator) + " " + self.strings[clauselist] = string.join([s for s in [self.strings[c] for c in clauselist.clauses] if s is not None], sep) def apply_function_parens(self, func): return func.name.upper() not in ANSI_FUNCS or len(func.clauses) > 0 @@ -362,20 +396,20 @@ class ANSICompiler(engine.Compiled): def visit_unary(self, unary): s = self.strings[unary.element] if unary.operator: - s = unary.operator + " " + s + s = self.operator_string(unary.operator) + " " + s if unary.modifier: s = s + " " + unary.modifier self.strings[unary] = s def visit_binary(self, binary): - result = self.strings[binary.left] - if binary.operator is not None: - result += " " + self.binary_operator_string(binary) - result += " " + self.strings[binary.right] - self.strings[binary] = result - - def binary_operator_string(self, binary): - return binary.operator + op = self.operator_string(binary.operator) + if callable(op): + self.strings[binary] = op(binary.left, binary.right) + else: + self.strings[binary] = self.strings[binary.left] + " " + op + " " + self.strings[binary.right] + + def operator_string(self, operator): + return self.operators.get(operator, str(operator)) def visit_bindparam(self, bindparam): # apply truncation to the ultimate generated name @@ -610,151 +644,86 @@ class ANSICompiler(engine.Compiled): " ON " + self.strings[join.onclause]) self.strings[join] = self.froms[join] - def visit_insert_column_default(self, column, default, parameters): - """Called when visiting an ``Insert`` statement. - - For each column in the table that contains a ``ColumnDefault`` - object, add a blank *placeholder* parameter so the ``Insert`` - gets compiled with this column's name in its column and - ``VALUES`` clauses. - """ - - parameters.setdefault(column.key, None) - - def visit_update_column_default(self, column, default, parameters): - """Called when visiting an ``Update`` statement. - - For each column in the table that contains a ``ColumnDefault`` - object as an onupdate, add a blank *placeholder* parameter so - the ``Update`` gets compiled with this column's name as one of - its ``SET`` clauses. - """ - - parameters.setdefault(column.key, None) - - def visit_insert_sequence(self, column, sequence, parameters): - """Called when visiting an ``Insert`` statement. - - This may be overridden compilers that support sequences to - place a blank *placeholder* parameter for each column in the - table that contains a Sequence object, so the Insert gets - compiled with this column's name in its column and ``VALUES`` - clauses. - """ - - pass - - def visit_insert_column(self, column, parameters): - """Called when visiting an ``Insert`` statement. - - This may be overridden by compilers who disallow NULL columns - being set in an ``Insert`` where there is a default value on - the column (i.e. postgres), to remove the column for which - there is a NULL insert from the parameter list. - """ - - pass - + def uses_sequences_for_inserts(self): + return False + def visit_insert(self, insert_stmt): - # scan the table's columns for defaults that have to be pre-set for an INSERT - # add these columns to the parameter list via visit_insert_XXX methods - default_params = {} + + # search for columns who will be required to have an explicit bound value. + # for inserts, this includes Python-side defaults, columns with sequences for dialects + # that support sequences, and primary key columns for dialects that explicitly insert + # pre-generated primary key values + required_cols = util.Set() class DefaultVisitor(schema.SchemaVisitor): - def visit_column(s, c): - self.visit_insert_column(c, default_params) + def visit_column(s, cd): + if c.primary_key and self.uses_sequences_for_inserts(): + required_cols.add(c) def visit_column_default(s, cd): - self.visit_insert_column_default(c, cd, default_params) + required_cols.add(c) def visit_sequence(s, seq): - self.visit_insert_sequence(c, seq, default_params) + if self.uses_sequences_for_inserts(): + required_cols.add(c) vis = DefaultVisitor() for c in insert_stmt.table.c: if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)): vis.traverse(c) self.isinsert = True - colparams = self._get_colparams(insert_stmt, default_params) - - self.inline_params = util.Set() - def create_param(col, p): - if isinstance(p, sql._BindParamClause): - self.binds[p.key] = p - if p.shortname is not None: - self.binds[p.shortname] = p - return self.bindparam_string(self._truncate_bindparam(p)) - else: - self.inline_params.add(col) - self.traverse(p) - if isinstance(p, sql.ClauseElement) and not isinstance(p, sql.ColumnElement): - return "(" + self.strings[p] + ")" - else: - return self.strings[p] + colparams = self._get_colparams(insert_stmt, required_cols) text = ("INSERT INTO " + self.preparer.format_table(insert_stmt.table) + " (" + string.join([self.preparer.format_column(c[0]) for c in colparams], ', ') + ")" + - " VALUES (" + string.join([create_param(*c) for c in colparams], ', ') + ")") + " VALUES (" + string.join([c[1] for c in colparams], ', ') + ")") self.strings[insert_stmt] = text def visit_update(self, update_stmt): - # scan the table's columns for onupdates that have to be pre-set for an UPDATE - # add these columns to the parameter list via visit_update_XXX methods - default_params = {} + + # search for columns who will be required to have an explicit bound value. + # for updates, this includes Python-side "onupdate" defaults. + required_cols = util.Set() class OnUpdateVisitor(schema.SchemaVisitor): def visit_column_onupdate(s, cd): - self.visit_update_column_default(c, cd, default_params) + required_cols.add(c) vis = OnUpdateVisitor() for c in update_stmt.table.c: if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)): vis.traverse(c) self.isupdate = True - colparams = self._get_colparams(update_stmt, default_params) - - self.inline_params = util.Set() - def create_param(col, p): - if isinstance(p, sql._BindParamClause): - self.binds[p.key] = p - self.binds[p.shortname] = p - return self.bindparam_string(self._truncate_bindparam(p)) - else: - self.traverse(p) - self.inline_params.add(col) - if isinstance(p, sql.ClauseElement) and not isinstance(p, sql.ColumnElement): - return "(" + self.strings[p] + ")" - else: - return self.strings[p] + colparams = self._get_colparams(update_stmt, required_cols) - text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.format_column(c[0]), create_param(*c)) for c in colparams], ', ') + text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams], ', ') if update_stmt._whereclause: text += " WHERE " + self.strings[update_stmt._whereclause] self.strings[update_stmt] = text - - def _get_colparams(self, stmt, default_params): - """Organize ``UPDATE``/``INSERT`` ``SET``/``VALUES`` parameters into a list of tuples. - - Each tuple will contain the ``Column`` and a ``ClauseElement`` - representing the value to be set (usually a ``_BindParamClause``, - but could also be other SQL expressions.) - - The list of tuples will determine the columns that are - actually rendered into the ``SET``/``VALUES`` clause of the - rendered ``UPDATE``/``INSERT`` statement. It will also - determine how to generate the list/dictionary of bind - parameters at execution time (i.e. ``get_params()``). - - This list takes into account the `values` keyword specified - to the statement, the parameters sent to this Compiled - instance, and the default bind parameter values corresponding - to the dialect's behavior for otherwise unspecified primary - key columns. + def _get_colparams(self, stmt, required_cols): + """create a set of tuples representing column/string pairs for use + in an INSERT or UPDATE statement. + + This method may generate new bind params within this compiled + based on the given set of "required columns", which are required + to have a value set in the statement. """ + def create_bind_param(col, value): + bindparam = sql.bindparam(col.key, value, type_=col.type, unique=True) + self.binds[col.key] = bindparam + return self.bindparam_string(self._truncate_bindparam(bindparam)) + # no parameters in the statement, no parameters in the # compiled params - return binds for all columns if self.parameters is None and stmt.parameters is None: - return [(c, sql.bindparam(c.key, type=c.type)) for c in stmt.table.columns] + return [(c, create_bind_param(c, None)) for c in stmt.table.columns] + + def create_clause_param(col, value): + self.traverse(value) + self.inline_params.add(col) + return self.strings[value] + + self.inline_params = util.Set() def to_col(key): if not isinstance(key, sql._ColumnClause): @@ -773,18 +742,20 @@ class ANSICompiler(engine.Compiled): for k, v in stmt.parameters.iteritems(): parameters.setdefault(to_col(k), v) - for k, v in default_params.iteritems(): - parameters.setdefault(to_col(k), v) + for col in required_cols: + parameters.setdefault(col, None) # create a list of column assignment clauses as tuples values = [] for c in stmt.table.columns: - if parameters.has_key(c): + if c in parameters: value = parameters[c] if sql._is_literal(value): - value = sql.bindparam(c.key, value, type=c.type, unique=True) + value = create_bind_param(c, value) + else: + value = create_clause_param(c, value) values.append((c, value)) - + return values def visit_delete(self, delete_stmt): @@ -846,8 +817,6 @@ class ANSISchemaGenerator(ANSISchemaBase): for column in table.columns: if column.default is not None: self.traverse_single(column.default) - #if column.onupdate is not None: - # column.onupdate.accept_visitor(visitor) self.append("\nCREATE TABLE " + self.preparer.format_table(table) + " (") diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 3ff46e0942..d3f49544dc 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -4,7 +4,7 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import re, datetime, inspect, warnings, weakref +import re, datetime, inspect, warnings, weakref, operator from sqlalchemy import sql, schema, ansisql from sqlalchemy.engine import default @@ -1284,6 +1284,14 @@ class _MySQLPythonRowProxy(object): class MySQLCompiler(ansisql.ANSICompiler): + operators = ansisql.ANSICompiler.operators.copy() + operators.update( + { + sql.ColumnOperators.concat_op : lambda x, y:"concat(%s, %s)" % (x, y), + operator.mod : '%%' + } + ) + def visit_cast(self, cast): if isinstance(cast.type, (sqltypes.Date, sqltypes.Time, sqltypes.DateTime)): return super(MySQLCompiler, self).visit_cast(cast) @@ -1309,11 +1317,6 @@ class MySQLCompiler(ansisql.ANSICompiler): text += " OFFSET " + str(select._offset) return text - def binary_operator_string(self, binary): - if binary.operator == '%': - return '%%' - else: - return ansisql.ANSICompiler.binary_operator_string(self, binary) class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, override_pk=False, first_pk=False): diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index a2b469a304..82388ef871 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -5,7 +5,7 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php -import sys, StringIO, string, re, warnings +import sys, StringIO, string, re, warnings, operator from sqlalchemy import util, sql, engine, schema, ansisql, exceptions, logging from sqlalchemy.engine import default, base @@ -460,6 +460,13 @@ class OracleCompiler(ansisql.ANSICompiler): the use_ansi flag is False. """ + operators = ansisql.ANSICompiler.operators.copy() + operators.update( + { + operator.mod : lambda x, y:"mod(%s, %s)" % (x, y) + } + ) + def default_from(self): """Called when a ``SELECT`` statement has no froms, and no ``FROM`` clause is to be appended. @@ -496,13 +503,8 @@ class OracleCompiler(ansisql.ANSICompiler): self.traverse_single(self.wheres[join]) - def visit_insert_sequence(self, column, sequence, parameters): - """This is the `sequence` equivalent to ``ANSICompiler``'s - `visit_insert_column_default` which ensures that the column is - present in the generated column list. - """ - - parameters.setdefault(column.key, None) + def uses_sequences_for_inserts(self): + return True def visit_alias(self, alias): """Oracle doesn't like ``FROM table AS alias``. Is the AS standard SQL??""" @@ -571,12 +573,6 @@ class OracleCompiler(ansisql.ANSICompiler): else: return super(OracleCompiler, self).for_update_clause(select) - def visit_binary(self, binary): - if binary.operator == '%': - self.strings[binary] = ("MOD(%s,%s)"%(self.strings[binary.left], self.strings[binary.right])) - else: - return ansisql.ANSICompiler.visit_binary(self, binary) - class OracleSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, **kwargs): diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index 469614fbb3..80a56a3ca6 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -4,7 +4,7 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import datetime, string, types, re, random, warnings +import datetime, string, types, re, random, warnings, operator from sqlalchemy import util, sql, schema, ansisql, exceptions from sqlalchemy.engine import base, default @@ -83,7 +83,7 @@ class PGBoolean(sqltypes.Boolean): def get_col_spec(self): return "BOOLEAN" -class PGArray(sqltypes.TypeEngine): +class PGArray(sqltypes.TypeEngine, sqltypes.Concatenable): def __init__(self, item_type): if isinstance(item_type, type): item_type = item_type() @@ -355,7 +355,7 @@ class PGDialect(ansisql.ANSIDialect): ORDER BY a.attnum """ % schema_where_clause - s = sql.text(SQL_COLS, bindparams=[sql.bindparam('table_name', type=sqltypes.Unicode), sql.bindparam('schema', type=sqltypes.Unicode)], typemap={'attname':sqltypes.Unicode}) + s = sql.text(SQL_COLS, bindparams=[sql.bindparam('table_name', type_=sqltypes.Unicode), sql.bindparam('schema', type_=sqltypes.Unicode)], typemap={'attname':sqltypes.Unicode}) c = connection.execute(s, table_name=table.name, schema=table.schema) rows = c.fetchall() @@ -525,15 +525,15 @@ class PGDialect(ansisql.ANSIDialect): class PGCompiler(ansisql.ANSICompiler): - def visit_insert_column(self, column, parameters): - # all column primary key inserts must be explicitly present - if column.primary_key: - parameters[column.key] = None + operators = ansisql.ANSICompiler.operators.copy() + operators.update( + { + operator.mod : '%%' + } + ) - def visit_insert_sequence(self, column, sequence, parameters): - """this is the 'sequence' equivalent to ANSICompiler's 'visit_insert_column_default' which ensures - that the column is present in the generated column list""" - parameters.setdefault(column.key, None) + def uses_sequences_for_inserts(self): + return True def limit_clause(self, select): text = "" @@ -565,14 +565,6 @@ class PGCompiler(ansisql.ANSICompiler): else: return super(PGCompiler, self).for_update_clause(select) - def binary_operator_string(self, binary): - if isinstance(binary.type, (sqltypes.String, PGArray)) and binary.operator == '+': - return '||' - elif binary.operator == '%': - return '%%' - else: - return ansisql.ANSICompiler.binary_operator_string(self, binary) - class PGSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index 70cbd0c0e1..e7abc1f32b 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -347,12 +347,6 @@ class SQLiteCompiler(ansisql.ANSICompiler): # sqlite has no "FOR UPDATE" AFAICT return '' - def binary_operator_string(self, binary): - if isinstance(binary.type, sqltypes.String) and binary.operator == '+': - return '||' - else: - return ansisql.ANSICompiler.binary_operator_string(self, binary) - class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, **kwargs): diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 832b56f74a..075d51a538 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -296,6 +296,9 @@ class DefaultExecutionContext(base.ExecutionContext): statement. """ + # TODO: this calculation of defaults is one of the places SA slows down inserts. + # look into optimizing this for a list of params where theres no defaults defined + # (i.e. analyze the first batch of params). if self.compiled.isinsert: if isinstance(self.compiled_parameters, list): plist = self.compiled_parameters @@ -323,6 +326,7 @@ class DefaultExecutionContext(base.ExecutionContext): self._lastrow_has_defaults = True newid = drunner.get_column_default(c) if newid is not None: + print "GOT GENERATED DEFAULT", c, repr(newid) param.set_value(c.key, newid) if c.primary_key: last_inserted_ids.append(param.get_processed(c.key)) diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py index e8f3d4e245..c06db69631 100644 --- a/lib/sqlalchemy/orm/dependency.py +++ b/lib/sqlalchemy/orm/dependency.py @@ -366,7 +366,7 @@ class ManyToManyDP(DependencyProcessor): if len(secondary_delete): secondary_delete.sort() # TODO: precompile the delete/insert queries? - statement = self.secondary.delete(sql.and_(*[c == sql.bindparam(c.key, type=c.type) for c in self.secondary.c if c.key in associationrow])) + statement = self.secondary.delete(sql.and_(*[c == sql.bindparam(c.key, type_=c.type) for c in self.secondary.c if c.key in associationrow])) result = connection.execute(statement, secondary_delete) if result.supports_sane_rowcount() and result.rowcount != len(secondary_delete): raise exceptions.ConcurrentModificationError("Deleted rowcount %d does not match number of objects deleted %d" % (result.rowcount, len(secondary_delete))) diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index f353575d90..e1209fabf9 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -337,7 +337,7 @@ class MapperProperty(object): return operator(self.comparator, value) -class PropComparator(sql.Comparator): +class PropComparator(sql.ColumnOperators): """defines comparison operations for MapperProperty objects""" def contains_op(a, b): diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index f82f713bb8..eb69fb32c8 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -10,7 +10,7 @@ from sqlalchemy.orm import util as mapperutil from sqlalchemy.orm.util import ExtensionCarrier from sqlalchemy.orm import sync from sqlalchemy.orm.interfaces import MapperProperty, MapperOption, OperationContext, EXT_PASS, MapperExtension, SynonymProperty -import weakref, warnings +import weakref, warnings, operator __all__ = ['Mapper', 'class_mapper', 'object_mapper', 'mapper_registry'] @@ -587,7 +587,7 @@ class Mapper(object): _get_clause = sql.and_() for primary_key in self.primary_key: - _get_clause.clauses.append(primary_key == sql.bindparam(primary_key._label, type=primary_key.type, unique=True)) + _get_clause.clauses.append(primary_key == sql.bindparam(primary_key._label, type_=primary_key.type, unique=True)) self._get_clause = _get_clause def _get_equivalent_columns(self): @@ -620,7 +620,7 @@ class Mapper(object): result = {} def visit_binary(binary): - if binary.operator == '=': + if binary.operator == operator.eq: if binary.left in result: result[binary.left].add(binary.right) else: @@ -1221,9 +1221,9 @@ class Mapper(object): mapper = table_to_mapper[table] clause = sql.and_() for col in mapper.pks_by_table[table]: - clause.clauses.append(col == sql.bindparam(col._label, type=col.type, unique=True)) + clause.clauses.append(col == sql.bindparam(col._label, type_=col.type, unique=True)) if mapper.version_id_col is not None: - clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col._label, type=col.type, unique=True)) + clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col._label, type_=col.type, unique=True)) statement = table.update(clause) rows = 0 supports_sane_rowcount = True @@ -1358,9 +1358,9 @@ class Mapper(object): delete.sort(comparator) clause = sql.and_() for col in mapper.pks_by_table[table]: - clause.clauses.append(col == sql.bindparam(col.key, type=col.type, unique=True)) + clause.clauses.append(col == sql.bindparam(col.key, type_=col.type, unique=True)) if mapper.version_id_col is not None: - clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col.key, type=mapper.version_id_col.type, unique=True)) + clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col.key, type_=mapper.version_id_col.type, unique=True)) statement = table.delete(clause) c = connection.execute(statement, delete) if c.supports_sane_rowcount() and c.rowcount != len(delete): @@ -1567,10 +1567,10 @@ class Mapper(object): if leftcol is None or rightcol is None: return if leftcol.table not in needs_tables: - binary.left = sql.bindparam(leftcol.name, None, type=binary.right.type, unique=True) + binary.left = sql.bindparam(leftcol.name, None, type_=binary.right.type, unique=True) param_names.append(leftcol) elif rightcol not in needs_tables: - binary.right = sql.bindparam(rightcol.name, None, type=binary.right.type, unique=True) + binary.right = sql.bindparam(rightcol.name, None, type_=binary.right.type, unique=True) param_names.append(rightcol) cond = mapperutil.BinaryVisitor(visit_binary).traverse(cond, clone=True) return cond, param_names diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 7a3da1fdd1..99148cf614 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -384,7 +384,7 @@ class PropertyLoader(StrategizedProperty): if len(self.foreign_keys): self._opposite_side = util.Set() def visit_binary(binary): - if binary.operator != '=' or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column): + if binary.operator != operator.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column): return if binary.left in self.foreign_keys: self._opposite_side.add(binary.right) @@ -397,7 +397,7 @@ class PropertyLoader(StrategizedProperty): self.foreign_keys = util.Set() self._opposite_side = util.Set() def visit_binary(binary): - if binary.operator != '=' or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column): + if binary.operator != operator.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column): return # this check is for when the user put the "view_only" flag on and has tables that have nothing diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index fa86e450b2..c581b27c03 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -407,7 +407,7 @@ class LazyLoader(AbstractRelationLoader): if should_bind(leftcol, rightcol): col = leftcol binary.left = binds.setdefault(leftcol, - sql.bindparam(None, None, shortname=leftcol.name, type=binary.right.type, unique=True)) + sql.bindparam(None, None, shortname=leftcol.name, type_=binary.right.type, unique=True)) reverse[rightcol] = binds[col] # the "left is not right" compare is to handle part of a join clause that is "table.c.col1==table.c.col1", @@ -415,7 +415,7 @@ class LazyLoader(AbstractRelationLoader): if leftcol is not rightcol and should_bind(rightcol, leftcol): col = rightcol binary.right = binds.setdefault(rightcol, - sql.bindparam(None, None, shortname=rightcol.name, type=binary.left.type, unique=True)) + sql.bindparam(None, None, shortname=rightcol.name, type_=binary.left.type, unique=True)) reverse[leftcol] = binds[col] lazywhere = primaryjoin diff --git a/lib/sqlalchemy/orm/sync.py b/lib/sqlalchemy/orm/sync.py index 88fd980ad2..cf48202b0f 100644 --- a/lib/sqlalchemy/orm/sync.py +++ b/lib/sqlalchemy/orm/sync.py @@ -12,6 +12,7 @@ clause that compares column values. from sqlalchemy import sql, schema, exceptions from sqlalchemy import logging from sqlalchemy.orm import util as mapperutil +import operator ONETOMANY = 0 MANYTOONE = 1 @@ -42,7 +43,7 @@ class ClauseSynchronizer(object): def compile_binary(binary): """Assemble a SyncRule given a single binary condition.""" - if binary.operator != '=' or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column): + if binary.operator != operator.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column): return source_column = None diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index afb433e1e4..20160b0bf7 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -396,7 +396,7 @@ class Column(SchemaItem, sql._ColumnClause): ``TableClause``/``Table``. """ - def __init__(self, name, type, *args, **kwargs): + def __init__(self, name, type_, *args, **kwargs): """Construct a new ``Column`` object. Arguments are: @@ -405,7 +405,7 @@ class Column(SchemaItem, sql._ColumnClause): The name of this column. This should be the identical name as it appears, or will appear, in the database. - type + type_ The ``TypeEngine`` for this column. This can be any subclass of ``types.AbstractType``, including the database-agnostic types defined in the types module, @@ -495,7 +495,7 @@ class Column(SchemaItem, sql._ColumnClause): identifier contains mixed case. """ - super(Column, self).__init__(name, None, type) + super(Column, self).__init__(name, None, type_) self.args = args self.key = kwargs.pop('key', name) self._primary_key = kwargs.pop('primary_key', False) diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index f38347fc40..672d085487 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -32,45 +32,12 @@ __all__ = ['AbstractDialect', 'Alias', 'ClauseElement', 'ClauseParameters', 'ClauseVisitor', 'ColumnCollection', 'ColumnElement', 'CompoundSelect', 'Delete', 'FromClause', 'Insert', 'Join', 'Select', 'Selectable', 'TableClause', 'Update', 'alias', 'and_', 'asc', - 'between_', 'between', 'bindparam', 'case', 'cast', 'column', 'delete', + 'between', 'bindparam', 'case', 'cast', 'column', 'delete', 'desc', 'distinct', 'except_', 'except_all', 'exists', 'extract', 'func', 'modifier', 'insert', 'intersect', 'intersect_all', 'join', 'literal', 'literal_column', 'not_', 'null', 'or_', 'outerjoin', 'select', 'subquery', 'table', 'text', 'union', 'union_all', 'update',] -# precedence ordering for common operators. if an operator is not present in this list, -# it will be parenthesized when grouped against other operators -PRECEDENCE = { - 'FROM':15, - '*':7, - '/':7, - '%':7, - '+':6, - '-':6, - 'ILIKE':5, - 'NOT ILIKE':5, - 'LIKE':5, - 'NOT LIKE':5, - 'IN':5, - 'NOT IN':5, - 'IS':5, - 'IS NOT':5, - '=':5, - '!=':5, - '>':5, - '<':5, - '>=':5, - '<=':5, - 'BETWEEN':5, - 'NOT':4, - 'AND':3, - 'OR':2, - ',':-1, - 'AS':-1, - 'EXISTS':0, - '_smallest': -1000, - '_largest': 1000 -} BIND_PARAMS = re.compile(r'(?'), - operator.le : (__compare, '<=', '>'), - operator.ne : (__compare, '!=', '='), - operator.gt : (__compare, '>', '<='), - operator.ge : (__compare, '>=', '<'), - operator.eq : (__compare, '=', '!='), - Comparator.like_op : (__compare, 'LIKE', 'NOT LIKE'), + operator.add : (__operate,), + operator.mul : (__operate,), + operator.sub : (__operate,), + operator.div : (__operate,), + operator.mod : (__operate,), + operator.truediv : (__operate,), + operator.lt : (__compare, operator.ge), + operator.le : (__compare, operator.gt), + operator.ne : (__compare, operator.eq), + operator.gt : (__compare, operator.le), + operator.ge : (__compare, operator.lt), + operator.eq : (__compare, operator.ne), + ColumnOperators.like_op : (__compare, ColumnOperators.notlike_op), } def operate(self, op, other): o = _CompareMixin.operators[op] - return o[0](self, o[1], other, *o[2:]) + return o[0](self, op, other, *o[1:]) def reverse_operate(self, op, other): return self._bind_param(other).operate(op, self) def in_(self, *other): - """produce an ``IN`` clause.""" + return self._in_impl(ColumnOperators.in_op, ColumnOperators.notin_op, *other) + + def _in_impl(self, op, negate_op, *other): if len(other) == 0: return _Grouping(case([(self.__eq__(None), text('NULL'))], else_=text('0')).__eq__(text('1'))) elif len(other) == 1: @@ -1285,7 +1351,7 @@ class _CompareMixin(Comparator): return self.__eq__( o) #single item -> == else: assert hasattr( o, '_selectable') #better check? - return self.__compare( 'IN', o, negate='NOT IN') #single selectable + return self.__compare( op, o, negate=negate_op) #single selectable args = [] for o in other: @@ -1295,19 +1361,21 @@ class _CompareMixin(Comparator): else: o = self._bind_param(o) args.append(o) - return self.__compare( 'IN', ClauseList(*args).self_group(against='IN'), negate='NOT IN') + return self.__compare(op, ClauseList(*args).self_group(against=op), negate=negate_op) def startswith(self, other): """produce the clause ``LIKE '%'``""" - perc = isinstance(other,(str,unicode)) and '%' or literal('%',type= sqltypes.String) + + perc = isinstance(other,(str,unicode)) and '%' or literal('%',type_= sqltypes.String) return self.__compare('LIKE', other + perc) def endswith(self, other): """produce the clause ``LIKE '%'``""" + if isinstance(other,(str,unicode)): po = '%' + other else: - po = literal('%', type= sqltypes.String) + other - po.type = sqltypes.to_instance( sqltypes.String) #force! + po = literal('%', type_=sqltypes.String) + other + po.type = sqltypes.to_instance(sqltypes.String) #force! return self.__compare('LIKE', po) def label(self, name): @@ -1320,7 +1388,7 @@ class _CompareMixin(Comparator): def between(self, cleft, cright): """produce a BETWEEN clause, i.e. `` BETWEEN AND ``""" - return _BinaryExpression(self, ClauseList(self._check_literal(cleft), self._check_literal(cright), operator='AND', group=False), 'BETWEEN') + return _BinaryExpression(self, ClauseList(self._check_literal(cleft), self._check_literal(cright), operator=operator.and_, group=False), 'BETWEEN') def op(self, operator): """produce a generic operator function. @@ -1342,10 +1410,10 @@ class _CompareMixin(Comparator): return lambda other: self.__operate(operator, other) def _bind_param(self, obj): - return _BindParamClause('literal', obj, shortname=None, type=self.type, unique=True) + return _BindParamClause('literal', obj, shortname=None, type_=self.type, unique=True) def _check_literal(self, other): - if isinstance(other, Comparator): + if isinstance(other, Operators): return other.clause_element() elif _is_literal(other): return self._bind_param(other) @@ -1764,7 +1832,7 @@ class _BindParamClause(ClauseElement, _CompareMixin): __visit_name__ = 'bindparam' - def __init__(self, key, value, shortname=None, type=None, unique=False): + def __init__(self, key, value, shortname=None, type_=None, unique=False): """Construct a _BindParamClause. key @@ -1787,7 +1855,7 @@ class _BindParamClause(ClauseElement, _CompareMixin): execution may match either the key or the shortname of the corresponding ``_BindParamClause`` objects. - type + type_ A ``TypeEngine`` object that will be used to pre-process the value corresponding to this ``_BindParamClause`` at execution time. @@ -1803,8 +1871,20 @@ class _BindParamClause(ClauseElement, _CompareMixin): self.value = value self.shortname = shortname or key self.unique = unique - self.type = sqltypes.to_instance(type) - + type_ = sqltypes.to_instance(type_) + if isinstance(type_, sqltypes.NullType) and type(value) in _BindParamClause.type_map: + self.type = sqltypes.to_instance(_BindParamClause.type_map[type(value)]) + else: + self.type = type_ + + # TODO: move to types module, obviously + type_map = { + str : sqltypes.String, + unicode : sqltypes.Unicode, + int : sqltypes.Integer, + float : sqltypes.Numeric + } + def _get_from_objects(self, **modifiers): return [] @@ -1822,7 +1902,7 @@ class _BindParamClause(ClauseElement, _CompareMixin): return isinstance(other, _BindParamClause) and other.type.__class__ == self.type.__class__ def __repr__(self): - return "_BindParamClause(%s, %s, type=%s)" % (repr(self.key), repr(self.value), repr(self.type)) + return "_BindParamClause(%s, %s, type_=%s)" % (repr(self.key), repr(self.value), repr(self.type)) class _TypeClause(ClauseElement): """Handle a type keyword in a SQL statement. @@ -1907,10 +1987,9 @@ class ClauseList(ClauseElement): def __init__(self, *clauses, **kwargs): self.clauses = [] - self.operator = kwargs.pop('operator', ',') + self.operator = kwargs.pop('operator', ColumnOperators.comma_op) self.group = kwargs.pop('group', True) self.group_contents = kwargs.pop('group_contents', True) - self.negate_operator = kwargs.pop('negate', None) for c in clauses: if c is None: continue @@ -1932,14 +2011,6 @@ class ClauseList(ClauseElement): def _copy_internals(self): self.clauses = [clause._clone() for clause in self.clauses] - def _negate(self): - if hasattr(self, 'negation_clause'): - return self.negation_clause - elif self.negate_operator is None: - return super(ClauseList, self)._negate() - else: - return ClauseList(operator=self.negate_operator, negate=self.operator, *(not_(c) for c in self.clauses)) - def get_children(self, **kwargs): return self.clauses @@ -1950,7 +2021,7 @@ class ClauseList(ClauseElement): return f def self_group(self, against=None): - if self.group and self.operator != against and PRECEDENCE.get(self.operator, PRECEDENCE['_smallest']) <= PRECEDENCE.get(against, PRECEDENCE['_largest']): + if self.group and self.operator != against and PRECEDENCE.get(self.operator, PRECEDENCE[_smallest]) <= PRECEDENCE.get(against, PRECEDENCE[_largest]): return _Grouping(self) else: return self @@ -1981,7 +2052,7 @@ class _CalculatedClause(ColumnElement): def __init__(self, name, *clauses, **kwargs): self.name = name - self.type = sqltypes.to_instance(kwargs.get('type', None)) + self.type = sqltypes.to_instance(kwargs.get('type_', None)) self._bind = kwargs.get('bind', None) self.group = kwargs.pop('group', True) self.clauses = ClauseList(operator=kwargs.get('operator', None), group_contents=kwargs.get('group_contents', True), *clauses) @@ -2002,7 +2073,7 @@ class _CalculatedClause(ColumnElement): return self.clauses._get_from_objects(**modifiers) def _bind_param(self, obj): - return _BindParamClause(self.name, obj, type=self.type, unique=True) + return _BindParamClause(self.name, obj, type_=self.type, unique=True) def select(self): return select([self]) @@ -2024,10 +2095,8 @@ class _Function(_CalculatedClause, FromClause): """ def __init__(self, name, *clauses, **kwargs): - self.type = sqltypes.to_instance(kwargs.get('type', None)) self.packagenames = kwargs.get('packagenames', None) or [] - kwargs['operator'] = ',' - self._bind = kwargs.get('bind', None) + kwargs['operator'] = ColumnOperators.comma_op _CalculatedClause.__init__(self, name, **kwargs) for c in clauses: self.append(c) @@ -2065,7 +2134,7 @@ class _Cast(ColumnElement): def _make_proxy(self, selectable, name=None): if name is not None: - co = _ColumnClause(name, selectable, type=self.type) + co = _ColumnClause(name, selectable, type_=self.type) co._distance = self._distance + 1 co.orig_set = self.orig_set selectable.columns[name]= co @@ -2075,12 +2144,12 @@ class _Cast(ColumnElement): class _UnaryExpression(ColumnElement): - def __init__(self, element, operator=None, modifier=None, type=None, negate=None): + def __init__(self, element, operator=None, modifier=None, type_=None, negate=None): self.operator = operator self.modifier = modifier self.element = _literal_as_text(element).self_group(against=self.operator or self.modifier) - self.type = sqltypes.to_instance(type) + self.type = sqltypes.to_instance(type_) self.negate = negate def _get_from_objects(self, **modifiers): @@ -2103,12 +2172,12 @@ class _UnaryExpression(ColumnElement): def _negate(self): if self.negate is not None: - return _UnaryExpression(self.element, operator=self.negate, negate=self.operator, modifier=self.modifier, type=self.type) + return _UnaryExpression(self.element, operator=self.negate, negate=self.operator, modifier=self.modifier, type_=self.type) else: return super(_UnaryExpression, self)._negate() def self_group(self, against): - if self.operator and PRECEDENCE.get(self.operator, PRECEDENCE['_smallest']) <= PRECEDENCE.get(against, PRECEDENCE['_largest']): + if self.operator and PRECEDENCE.get(self.operator, PRECEDENCE[_smallest]) <= PRECEDENCE.get(against, PRECEDENCE[_largest]): return _Grouping(self) else: return self @@ -2117,11 +2186,11 @@ class _UnaryExpression(ColumnElement): class _BinaryExpression(ColumnElement): """Represent an expression that is ``LEFT RIGHT``.""" - def __init__(self, left, right, operator, type=None, negate=None): + def __init__(self, left, right, operator, type_=None, negate=None): self.left = _literal_as_text(left).self_group(against=operator) self.right = _literal_as_text(right).self_group(against=operator) self.operator = operator - self.type = sqltypes.to_instance(type) + self.type = sqltypes.to_instance(type_) self.negate = negate def _get_from_objects(self, **modifiers): @@ -2142,7 +2211,7 @@ class _BinaryExpression(ColumnElement): ( self.left.compare(other.left) and self.right.compare(other.right) or ( - self.operator in ['=', '!=', '+', '*'] and + self.operator in [operator.eq, operator.ne, operator.add, operator.mul] and self.left.compare(other.right) and self.right.compare(other.left) ) ) @@ -2150,14 +2219,14 @@ class _BinaryExpression(ColumnElement): def self_group(self, against=None): # use small/large defaults for comparison so that unknown operators are always parenthesized - if self.operator != against and (PRECEDENCE.get(self.operator, PRECEDENCE['_smallest']) <= PRECEDENCE.get(against, PRECEDENCE['_largest'])): + if self.operator != against and (PRECEDENCE.get(self.operator, PRECEDENCE[_smallest]) <= PRECEDENCE.get(against, PRECEDENCE[_largest])): return _Grouping(self) else: return self def _negate(self): if self.negate is not None: - return _BinaryExpression(self.left, self.right, self.negate, negate=self.operator, type=self.type) + return _BinaryExpression(self.left, self.right, self.negate, negate=self.operator, type_=self.type) else: return super(_BinaryExpression, self)._negate() @@ -2167,7 +2236,7 @@ class _Exists(_UnaryExpression): def __init__(self, *args, **kwargs): kwargs['correlate'] = True s = select(*args, **kwargs).self_group() - _UnaryExpression.__init__(self, s, operator="EXISTS") + _UnaryExpression.__init__(self, s, operator=Operators.exists) def _hide_froms(self, **modifiers): return self._get_from_objects(**modifiers) @@ -2208,7 +2277,7 @@ class Join(FromClause): class BinaryVisitor(ClauseVisitor): def visit_binary(self, binary): - if binary.operator == '=': + if binary.operator == operator.eq: add_equiv(binary.left, binary.right) BinaryVisitor().traverse(self.onclause) @@ -2290,7 +2359,7 @@ class Join(FromClause): equivs = util.Set() class LocateEquivs(NoColumnVisitor): def visit_binary(self, binary): - if binary.operator == '=' and binary.left.name == binary.right.name: + if binary.operator == operator.eq and binary.left.name == binary.right.name: equivs.add(binary.right) equivs.add(binary.left) LocateEquivs().traverse(self.onclause) @@ -2463,14 +2532,14 @@ class _Label(ColumnElement): """ - def __init__(self, name, obj, type=None): + def __init__(self, name, obj, type_=None): while isinstance(obj, _Label): obj = obj.obj self.name = name or "{ANON %d %s}" % (id(self), getattr(obj, 'name', 'anon')) - self.obj = obj.self_group(against='AS') + self.obj = obj.self_group(against=Operators.as_) self.case_sensitive = getattr(obj, "case_sensitive", True) - self.type = sqltypes.to_instance(type or getattr(obj, 'type', None)) + self.type = sqltypes.to_instance(type_ or getattr(obj, 'type', None)) key = property(lambda s: s.name) _label = property(lambda s: s.name) @@ -2528,11 +2597,11 @@ class _ColumnClause(ColumnElement): """ - def __init__(self, text, selectable=None, type=None, _is_oid=False, case_sensitive=True, is_literal=False): + def __init__(self, text, selectable=None, type_=None, _is_oid=False, case_sensitive=True, is_literal=False): self.key = self.name = text self.encodedname = isinstance(self.name, unicode) and self.name.encode('ascii', 'backslashreplace') or self.name self.table = selectable - self.type = sqltypes.to_instance(type) + self.type = sqltypes.to_instance(type_) self._is_oid = _is_oid self._distance = 0 self.__label = None @@ -2586,13 +2655,13 @@ class _ColumnClause(ColumnElement): return [] def _bind_param(self, obj): - return _BindParamClause(self._label, obj, shortname = self.name, type=self.type, unique=True) + return _BindParamClause(self._label, obj, shortname=self.name, type_=self.type, unique=True) def _make_proxy(self, selectable, name = None): # propigate the "is_literal" flag only if we are keeping our name, # otherwise its considered to be a label is_literal = self.is_literal and (name is None or name == self.name) - c = _ColumnClause(name or self.name, selectable=selectable, _is_oid=self._is_oid, type=self.type, is_literal=is_literal) + c = _ColumnClause(name or self.name, selectable=selectable, _is_oid=self._is_oid, type_=self.type, is_literal=is_literal) c.orig_set = self.orig_set c._distance = self._distance + 1 if not self._is_oid: @@ -3050,7 +3119,7 @@ class Select(_SelectBaseMixin, FromClause): column = literal_column(str(column)) if isinstance(column, Select) and column.is_scalar: - column = column.self_group(against=',') + column = column.self_group(against=ColumnOperators.comma_op) self._raw_columns.append(column) @@ -3191,7 +3260,7 @@ class _UpdateBase(ClauseElement): for key in parameters.keys(): value = parameters[key] if isinstance(value, ClauseElement): - pass + parameters[key] = value.self_group() elif _is_literal(value): if _is_literal(key): col = self.table.c[key] diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index ddaf990e7f..6e59ac16e5 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -192,7 +192,11 @@ class NullType(TypeEngine): return value NullTypeEngine = NullType -class String(TypeEngine): +class Concatenable(object): + """marks a type as supporting 'concatenation'""" + pass + +class String(TypeEngine, Concatenable): def __init__(self, length=None, convert_unicode=False): self.length = length self.convert_unicode = convert_unicode diff --git a/test/engine/bind.py b/test/engine/bind.py index 321493329a..17b243d256 100644 --- a/test/engine/bind.py +++ b/test/engine/bind.py @@ -149,9 +149,10 @@ class BindTest(testbase.PersistTest): assert False except exceptions.InvalidRequestError, e: assert str(e) == "This Compiled object is not bound to any Engine or Connection." - + finally: - bind.close() + if isinstance(bind, engine.Connection): + bind.close() metadata.drop_all(bind=testbase.db) def test_session(self): @@ -165,7 +166,9 @@ class BindTest(testbase.PersistTest): mapper(Foo, table) metadata.create_all(bind=testbase.db) try: - for bind in (testbase.db, testbase.db.connect()): + for bind in (testbase.db, + testbase.db.connect() + ): for args in ({'bind':bind},): sess = create_session(**args) assert sess.bind is bind @@ -173,6 +176,9 @@ class BindTest(testbase.PersistTest): sess.save(f) sess.flush() assert sess.get(Foo, f.foo) is f + + if isinstance(bind, engine.Connection): + bind.close() sess = create_session() f = Foo() @@ -182,9 +188,11 @@ class BindTest(testbase.PersistTest): assert False except exceptions.InvalidRequestError, e: assert str(e).startswith("Could not locate any Engine or Connection bound to mapper") + finally: - bind.close() + if isinstance(bind, engine.Connection): + bind.close() metadata.drop_all(bind=testbase.db) diff --git a/test/engine/parseconnect.py b/test/engine/parseconnect.py index f393b9f7d0..eb4f95619d 100644 --- a/test/engine/parseconnect.py +++ b/test/engine/parseconnect.py @@ -116,7 +116,6 @@ class CreateEngineTest(PersistTest): except TypeError: assert True - e = create_engine('sqlite://', echo=True) e = create_engine('mysql://', module=MockDBAPI(), connect_args={'use_unicode':True}, convert_unicode=True) e = create_engine('sqlite://', connect_args={'use_unicode':True}, convert_unicode=True) diff --git a/test/orm/inheritance/polymorph2.py b/test/orm/inheritance/polymorph2.py index 141b3abc27..fe7d77985a 100644 --- a/test/orm/inheritance/polymorph2.py +++ b/test/orm/inheritance/polymorph2.py @@ -737,18 +737,18 @@ class MultiLevelTest(testbase.ORMTest): def define_tables(self, metadata): global table_Employee, table_Engineer, table_Manager table_Employee = Table( 'Employee', metadata, - Column( 'name', type= String(100), ), - Column( 'id', primary_key= True, type= Integer, ), - Column( 'atype', type= String(100), ), + Column( 'name', type_= String(100), ), + Column( 'id', primary_key= True, type_= Integer, ), + Column( 'atype', type_= String(100), ), ) table_Engineer = Table( 'Engineer', metadata, - Column( 'machine', type= String(100), ), + Column( 'machine', type_= String(100), ), Column( 'id', Integer, ForeignKey( 'Employee.id', ), primary_key= True, ), ) table_Manager = Table( 'Manager', metadata, - Column( 'duties', type= String(100), ), + Column( 'duties', type_= String(100), ), Column( 'id', Integer, ForeignKey( 'Engineer.id', ), primary_key= True, ), ) def test_threelevels(self): diff --git a/test/orm/query.py b/test/orm/query.py index 026f808081..df4187eb4d 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -112,14 +112,14 @@ class OperatorTest(QueryTest): (operator.sub, '-'), (operator.div, '/'), ): for (lhs, rhs, res) in ( - ('a', User.id, ':users_id %s users.id'), - ('a', literal('b'), ':literal %s :literal_1'), - (User.id, 'b', 'users.id %s :users_id'), + (5, User.id, ':users_id %s users.id'), + (5, literal(6), ':literal %s :literal_1'), + (User.id, 5, 'users.id %s :users_id'), (User.id, literal('b'), 'users.id %s :literal'), (User.id, User.id, 'users.id %s users.id'), - (literal('a'), 'b', ':literal %s :literal_1'), - (literal('a'), User.id, ':literal %s users.id'), - (literal('a'), literal('b'), ':literal %s :literal_1'), + (literal(5), 'b', ':literal %s :literal_1'), + (literal(5), User.id, ':literal %s users.id'), + (literal(5), literal(6), ':literal %s :literal_1'), ): self._test(py_op(lhs, rhs), res % sql_op) @@ -503,7 +503,6 @@ class InstancesTest(QueryTest): l = q.add_column("count").from_statement(s).all() assert l == expected - @testbase.unsupported('mysql') # only because of "+" operator requiring "concat" in mysql (fix #475) def test_two_columns(self): sess = create_session() (user7, user8, user9, user10) = sess.query(User).all() diff --git a/test/sql/case_statement.py b/test/sql/case_statement.py index 5a42317d7f..bcf8849644 100644 --- a/test/sql/case_statement.py +++ b/test/sql/case_statement.py @@ -28,9 +28,9 @@ class CaseTest(testbase.PersistTest): def testcase(self): inner = select([case([ [info_table.c.pk < 3, - literal('lessthan3', type=String)], + literal('lessthan3', type_=String)], [and_(info_table.c.pk >= 3, info_table.c.pk < 7), - literal('gt3', type=String)]]).label('x'), + literal('gt3', type_=String)]]).label('x'), info_table.c.pk, info_table.c.info], from_obj=[info_table]).alias('q_inner') @@ -67,9 +67,9 @@ class CaseTest(testbase.PersistTest): w_else = select([case([ [info_table.c.pk < 3, - literal(3, type=Integer)], + literal(3, type_=Integer)], [and_(info_table.c.pk >= 3, info_table.c.pk < 6), - literal(6, type=Integer)]], + literal(6, type_=Integer)]], else_ = 0).label('x'), info_table.c.pk, info_table.c.info], from_obj=[info_table]).alias('q_inner') diff --git a/test/sql/defaults.py b/test/sql/defaults.py index 07363a402e..a9dd2f5ad2 100644 --- a/test/sql/defaults.py +++ b/test/sql/defaults.py @@ -25,7 +25,7 @@ class DefaultTest(PersistTest): # select "count(1)" returns different results on different DBs # also correct for "current_date" compatible as column default, value differences - currenttime = func.current_date(type=Date, bind=db); + currenttime = func.current_date(type_=Date, bind=db); if is_oracle: ts = db.func.trunc(func.sysdate(), literal_column("'DAY'")).scalar() f = select([func.count(1) + 5], bind=db).scalar() @@ -230,7 +230,7 @@ class SequenceTest(PersistTest): ) sometable = Table( 'Manager', metadata, Column( 'obj_id', Integer, Sequence('obj_id_seq'), ), - Column( 'name', type= String, ), + Column( 'name', String, ), Column( 'id', Integer, primary_key= True, ), ) diff --git a/test/sql/query.py b/test/sql/query.py index 772ffa793a..05b6d0419d 100644 --- a/test/sql/query.py +++ b/test/sql/query.py @@ -281,6 +281,10 @@ class QueryTest(PersistTest): y = testbase.db.func.current_date().select().execute().scalar() z = testbase.db.func.current_date().scalar() assert x == y == z + + x = testbase.db.func.current_date(type_=Date) + assert isinstance(x.type, Date) + assert isinstance(x.execute().scalar(), datetime.date) def test_conn_functions(self): conn = testbase.db.connect() @@ -351,7 +355,7 @@ class QueryTest(PersistTest): w = select(['*'], from_obj=[testbase.db.func.current_date()]).scalar() # construct a column-based FROM object out of a function, like in [ticket:172] - s = select([column('date', type=DateTime)], from_obj=[testbase.db.func.current_date()]) + s = select([column('date', type_=DateTime)], from_obj=[testbase.db.func.current_date()]) q = s.execute().fetchone()[s.c.date] r = s.alias('datequery').select().scalar() diff --git a/test/sql/select.py b/test/sql/select.py index d5b00e1dab..3d5996df9d 100644 --- a/test/sql/select.py +++ b/test/sql/select.py @@ -11,21 +11,21 @@ import unittest, re, operator # so SQLAlchemy's SQL construction engine can be used with no database dependencies at all. table1 = table('mytable', - column('myid'), - column('name'), - column('description'), + column('myid', Integer), + column('name', String), + column('description', String), ) table2 = table( 'myothertable', - column('otherid'), - column('othername'), + column('otherid', Integer), + column('othername', String), ) table3 = table( 'thirdtable', - column('userid'), - column('otherstuff'), + column('userid', Integer), + column('otherstuff', String), ) metadata = MetaData() @@ -273,14 +273,14 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A (operator.sub, '-'), (operator.div, '/'), ): for (lhs, rhs, res) in ( - ('a', table1.c.myid, ':mytable_myid %s mytable.myid'), - ('a', literal('b'), ':literal %s :literal_1'), + (5, table1.c.myid, ':mytable_myid %s mytable.myid'), + (5, literal(5), ':literal %s :literal_1'), (table1.c.myid, 'b', 'mytable.myid %s :mytable_myid'), - (table1.c.myid, literal('b'), 'mytable.myid %s :literal'), + (table1.c.myid, literal(2.7), 'mytable.myid %s :literal'), (table1.c.myid, table1.c.myid, 'mytable.myid %s mytable.myid'), - (literal('a'), 'b', ':literal %s :literal_1'), - (literal('a'), table1.c.myid, ':literal %s mytable.myid'), - (literal('a'), literal('b'), ':literal %s :literal_1'), + (literal(5), 8, ':literal %s :literal_1'), + (literal(6), table1.c.myid, ':literal %s mytable.myid'), + (literal(7), literal(5.5), ':literal %s :literal_1'), ): self.runtest(py_op(lhs, rhs), res % sql_op) @@ -328,7 +328,7 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A ) self.runtest( - literal("a") + literal("b") * literal("c"), ":literal + :literal_1 * :literal_2" + literal("a") + literal("b") * literal("c"), ":literal || :literal_1 * :literal_2" ) # test the op() function, also that its results are further usable in expressions @@ -540,7 +540,7 @@ FROM mytable, myothertable WHERE foo.id = foofoo(lala) AND datetime(foo) = Today def testliteral(self): self.runtest(select([literal("foo") + literal("bar")], from_obj=[table1]), - "SELECT :literal + :literal_1 FROM mytable") + "SELECT :literal || :literal_1 FROM mytable") def testcalculatedcolumns(self): value_tbl = table('values', @@ -866,16 +866,16 @@ myothertable.othername != :myothertable_othername OR EXISTS (select yay from foo self.runtest(select([table1], table1.c.myid.in_('a', literal('b'))), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:mytable_myid, :literal)") - self.runtest(select([table1], table1.c.myid.in_(literal('a') + 'a')), + self.runtest(select([table1], table1.c.myid.in_(literal(1) + 'a')), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :literal + :literal_1") self.runtest(select([table1], table1.c.myid.in_(literal('a') +'a', 'b')), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:literal + :literal_1, :mytable_myid)") + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:literal || :literal_1, :mytable_myid)") self.runtest(select([table1], table1.c.myid.in_(literal('a') + literal('a'), literal('b'))), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:literal + :literal_1, :literal_2)") + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:literal || :literal_1, :literal_2)") - self.runtest(select([table1], table1.c.myid.in_('a', literal('b') +'b')), + self.runtest(select([table1], table1.c.myid.in_(1, literal(3) + 4)), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:mytable_myid, :literal + :literal_1)") self.runtest(select([table1], table1.c.myid.in_(literal('a') < 'b')), @@ -893,7 +893,7 @@ myothertable.othername != :myothertable_othername OR EXISTS (select yay from foo self.runtest(select([table1], table1.c.myid.in_(literal('a'), table1.c.myid +'a')), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:literal, mytable.myid + :mytable_myid)") - self.runtest(select([table1], table1.c.myid.in_(literal('a'), 'a' + table1.c.myid)), + self.runtest(select([table1], table1.c.myid.in_(literal(1), 'a' + table1.c.myid)), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:literal, :mytable_myid + mytable.myid)") self.runtest(select([table1], table1.c.myid.in_(1, 2, 3)), @@ -1040,12 +1040,16 @@ class CRUDTest(SQLTest): values = { table1.c.name : table1.c.name + "lala", table1.c.myid : func.do_stuff(table1.c.myid, literal('hoho')) - }), "UPDATE mytable SET myid=do_stuff(mytable.myid, :literal_2), name=mytable.name + :mytable_name WHERE mytable.myid = hoho(:hoho) AND mytable.name = :literal + mytable.name + :literal_1") + }), "UPDATE mytable SET myid=do_stuff(mytable.myid, :literal_2), name=(mytable.name || :mytable_name) WHERE mytable.myid = hoho(:hoho) AND mytable.name = :literal || mytable.name || :literal_1") def testcorrelatedupdate(self): # test against a straight text subquery - u = update(table1, values = {table1.c.name : text("select name from mytable where id=mytable.id")}) + u = update(table1, values = {table1.c.name : text("(select name from mytable where id=mytable.id)")}) self.runtest(u, "UPDATE mytable SET name=(select name from mytable where id=mytable.id)") + + mt = table1.alias() + u = update(table1, values = {table1.c.name : select([mt.c.name], mt.c.myid==table1.c.myid)}) + self.runtest(u, "UPDATE mytable SET name=(SELECT mytable_1.name FROM mytable AS mytable_1 WHERE mytable_1.myid = mytable.myid)") # test against a regular constructed subquery s = select([table2], table2.c.otherid == table1.c.myid) -- 2.47.3