From: Mike Bayer Date: Sat, 3 Mar 2007 21:02:26 +0000 (+0000) Subject: - bindparam() names are now repeatable! specify two X-Git-Tag: rel_0_3_6~44 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=e736817a92797a3a3ce7b1c2cc9622643186f65b;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - bindparam() names are now repeatable! specify two distinct bindparam()s with the same name in a single statement, and the key will be shared. proper positional/named args translate at compile time. for the old behavior of "aliasing" bind parameters with conflicting names, specify "unique=True" - this option is still used internally for all the auto-genererated (value-based) bind parameters. --- diff --git a/CHANGES b/CHANGES index 6323fcbeec..e554003e52 100644 --- a/CHANGES +++ b/CHANGES @@ -1,4 +1,11 @@ - sql: + - bindparam() names are now repeatable! specify two + distinct bindparam()s with the same name in a single statement, + and the key will be shared. proper positional/named args translate + at compile time. for the old behavior of "aliasing" bind parameters + with conflicting names, specify "unique=True" - this option is + still used internally for all the auto-genererated (value-based) + bind parameters. - exists() becomes useable as a standalone selectable, not just in a WHERE clause - correlated subqueries work inside of ORDER BY, GROUP BY @@ -15,7 +22,7 @@ 'duplicate' columns from the resulting column clause that are known to be equivalent based on the join condition. this is of great usage when constructing subqueries of joins which Postgres complains about if - duplicate column names are present. + duplicate column names are present. - orm: - a full select() construct can be passed to query.select() (which worked anyway), but also query.selectfirst(), query.selectone() which diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index f96bf7abef..19cde38628 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -10,7 +10,7 @@ Contains default implementations for the abstract objects in the sql module. """ -from sqlalchemy import schema, sql, engine, util, sql_util +from sqlalchemy import schema, sql, engine, util, sql_util, exceptions from sqlalchemy.engine import default import string, re, sets, weakref @@ -353,20 +353,27 @@ class ANSICompiler(sql.Compiled): def visit_bindparam(self, bindparam): if bindparam.shortname != bindparam.key: self.binds.setdefault(bindparam.shortname, bindparam) - count = 1 - key = bindparam.key - - # redefine the generated name of the bind param in the case - # that we have multiple conflicting bind parameters. - while self.binds.setdefault(key, bindparam) is not bindparam: - # ensure the name doesn't expand the length of the string - # in case we're at the edge of max identifier length - tag = "_%d" % count - key = bindparam.key[0 : len(bindparam.key) - len(tag)] + tag - count += 1 - bindparam.key = key - self.strings[bindparam] = self.bindparam_string(key) - + if bindparam.unique: + count = 1 + key = bindparam.key + + # redefine the generated name of the bind param in the case + # that we have multiple conflicting bind parameters. + while self.binds.setdefault(key, bindparam) is not bindparam: + # ensure the name doesn't expand the length of the string + # in case we're at the edge of max identifier length + tag = "_%d" % count + key = bindparam.key[0 : len(bindparam.key) - len(tag)] + tag + count += 1 + bindparam.key = key + self.strings[bindparam] = self.bindparam_string(key) + else: + existing = self.binds.get(bindparam.key) + if existing is not None and existing.unique: + raise exceptions.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key) + self.strings[bindparam] = self.bindparam_string(bindparam.key) + self.binds[bindparam.key] = bindparam + def bindparam_string(self, name): return self.bindtemplate % name @@ -702,7 +709,7 @@ class ANSICompiler(sql.Compiled): if parameters.has_key(c): value = parameters[c] if sql._is_literal(value): - value = sql.bindparam(c.key, value, type=c.type) + value = sql.bindparam(c.key, value, type=c.type, unique=True) values.append((c, value)) return values diff --git a/lib/sqlalchemy/exceptions.py b/lib/sqlalchemy/exceptions.py index e9d7d0c442..08908cdb60 100644 --- a/lib/sqlalchemy/exceptions.py +++ b/lib/sqlalchemy/exceptions.py @@ -34,6 +34,11 @@ class ArgumentError(SQLAlchemyError): pass +class CompileError(SQLAlchemyError): + """Raised when an error occurs during SQL compilation""" + + pass + class TimeoutError(SQLAlchemyError): """Raised when a connection pool times out on getting a connection.""" diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 55edf0f41a..10fab3ba3c 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1151,9 +1151,9 @@ class Mapper(object): mapper = table_to_mapper[table] clause = sql.and_() for col in mapper.pks_by_table[table]: - clause.clauses.append(col == sql.bindparam(col._label, type=col.type)) + clause.clauses.append(col == sql.bindparam(col._label, type=col.type, unique=True)) if mapper.version_id_col is not None: - clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col._label, type=col.type)) + clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col._label, type=col.type, unique=True)) statement = table.update(clause) rows = 0 supports_sane_rowcount = True @@ -1277,9 +1277,9 @@ class Mapper(object): delete.sort(comparator) clause = sql.and_() for col in self.pks_by_table[table]: - clause.clauses.append(col == sql.bindparam(col.key, type=col.type)) + clause.clauses.append(col == sql.bindparam(col.key, type=col.type, unique=True)) if self.version_id_col is not None: - clause.clauses.append(self.version_id_col == sql.bindparam(self.version_id_col.key, type=self.version_id_col.type)) + clause.clauses.append(self.version_id_col == sql.bindparam(self.version_id_col.key, type=self.version_id_col.type, unique=True)) statement = table.delete(clause) c = connection.execute(statement, delete) if c.supports_sane_rowcount() and c.rowcount != len(delete): diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index da1354c242..8df5628d15 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -32,7 +32,7 @@ class Query(object): if not hasattr(self.mapper, '_get_clause'): _get_clause = sql.and_() for primary_key in self.primary_key_columns: - _get_clause.clauses.append(primary_key == sql.bindparam(primary_key._label, type=primary_key.type)) + _get_clause.clauses.append(primary_key == sql.bindparam(primary_key._label, type=primary_key.type, unique=True)) self.mapper._get_clause = _get_clause self._get_clause = self.mapper._get_clause for opt in util.flatten_iterator(self.with_options): diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 115b53bfd0..8e19be5367 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -281,7 +281,7 @@ class LazyLoader(AbstractRelationLoader): if should_bind(leftcol, rightcol): col = leftcol binary.left = binds.setdefault(leftcol, - sql.bindparam(bind_label(), None, shortname=leftcol.name, type=binary.right.type)) + sql.bindparam(bind_label(), None, shortname=leftcol.name, type=binary.right.type, unique=True)) reverse[rightcol] = binds[col] # the "left is not right" compare is to handle part of a join clause that is "table.c.col1==table.c.col1", @@ -289,7 +289,7 @@ class LazyLoader(AbstractRelationLoader): if leftcol is not rightcol and should_bind(rightcol, leftcol): col = rightcol binary.right = binds.setdefault(rightcol, - sql.bindparam(bind_label(), None, shortname=rightcol.name, type=binary.left.type)) + sql.bindparam(bind_label(), None, shortname=rightcol.name, type=binary.left.type, unique=True)) reverse[leftcol] = binds[col] lazywhere = primaryjoin.copy_container() diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 9c8d5db08d..da1afe7992 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -308,7 +308,7 @@ def literal(value, type=None): for this literal. """ - return _BindParamClause('literal', value, type=type) + return _BindParamClause('literal', value, type=type, unique=True) def label(name, obj): """Return a ``_Label`` object for the given selectable, used in @@ -343,19 +343,30 @@ def table(name, *columns): return TableClause(name, *columns) -def bindparam(key, value=None, type=None, shortname=None): +def bindparam(key, value=None, type=None, shortname=None, unique=False): """Create a bind parameter clause with the given key. - An optional default value can be specified by the value parameter, - and the optional type parameter is a - ``sqlalchemy.types.TypeEngine`` object which indicates - bind-parameter and result-set translation for this bind parameter. + value + a default value for this bind parameter. a bindparam with a value + is called a ``value-based bindparam``. + + shortname + an ``alias`` for this bind parameter. usually used to alias the ``key`` and + ``label`` of a column, i.e. ``somecolname`` and ``sometable_somecolname`` + + type + a sqlalchemy.types.TypeEngine object indicating the type of this bind param, will + invoke type-specific bind parameter processing + + unique + if True, bind params sharing the same name will have their underlying ``key`` modified + to a uniquely generated name. mostly useful with value-based bind params. """ if isinstance(key, _ColumnClause): - return _BindParamClause(key.name, value, type=key.type, shortname=shortname) + return _BindParamClause(key.name, value, type=key.type, shortname=shortname, unique=unique) else: - return _BindParamClause(key, value, type=type, shortname=shortname) + return _BindParamClause(key, value, type=type, shortname=shortname, unique=unique) def text(text, engine=None, *args, **kwargs): """Create literal text to be inserted into a query. @@ -817,7 +828,7 @@ class _CompareMixin(object): return self._operate('/', other) def _bind_param(self, obj): - return _BindParamClause('literal', obj, shortname=None, type=self.type) + return _BindParamClause('literal', obj, shortname=None, type=self.type, unique=True) def _check_literal(self, other): if _is_literal(other): @@ -1120,7 +1131,7 @@ class _BindParamClause(ClauseElement, _CompareMixin): Public constructor is the ``bindparam()`` function. """ - def __init__(self, key, value, shortname=None, type=None): + def __init__(self, key, value, shortname=None, type=None, unique=False): """Construct a _BindParamClause. key @@ -1144,15 +1155,21 @@ class _BindParamClause(ClauseElement, _CompareMixin): corresponding ``_BindParamClause`` objects. type - A ``TypeEngine`` object that will be used to pre-process the value corresponding to this ``_BindParamClause`` at execution time. + + unique + if True, the key name of this BindParamClause will be + modified if another ``_BindParamClause`` of the same + name already has been located within the containing + ``ClauseElement``. """ self.key = key self.value = value self.shortname = shortname or key + self.unique = unique self.type = sqltypes.to_instance(type) def accept_visitor(self, visitor): @@ -1162,7 +1179,7 @@ class _BindParamClause(ClauseElement, _CompareMixin): return [] def copy_container(self): - return _BindParamClause(self.key, self.value, self.shortname, self.type) + return _BindParamClause(self.key, self.value, self.shortname, self.type, unique=self.unique) def typeprocess(self, value, dialect): return self.type.dialect_impl(dialect).convert_bind_param(value, dialect) @@ -1353,7 +1370,7 @@ class _CalculatedClause(ClauseList, ColumnElement): visitor.visit_calculatedclause(self) def _bind_param(self, obj): - return _BindParamClause(self.name, obj, type=self.type) + return _BindParamClause(self.name, obj, type=self.type, unique=True) def select(self): return select([self]) @@ -1388,7 +1405,7 @@ class _Function(_CalculatedClause, FromClause): if clause is None: clause = null() else: - clause = _BindParamClause(self.name, clause, shortname=self.name, type=None) + clause = _BindParamClause(self.name, clause, shortname=self.name, type=None, unique=True) self.clauses.append(clause) def copy_container(self): @@ -1753,7 +1770,7 @@ class _ColumnClause(ColumnElement): return [] def _bind_param(self, obj): - return _BindParamClause(self._label, obj, shortname = self.name, type=self.type) + return _BindParamClause(self._label, obj, shortname = self.name, type=self.type, unique=True) def _make_proxy(self, selectable, name = None): # propigate the "is_literal" flag only if we are keeping our name, @@ -2208,7 +2225,7 @@ class _UpdateBase(ClauseElement): else: col = key try: - parameters[key] = bindparam(col, value) + parameters[key] = bindparam(col, value, unique=True) except KeyError: del parameters[key] return parameters diff --git a/test/sql/select.py b/test/sql/select.py index a021bd5b99..b6f7699597 100644 --- a/test/sql/select.py +++ b/test/sql/select.py @@ -597,25 +597,87 @@ myothertable.othername != :myothertable_othername AND EXISTS (select yay from fo self.runtest(query.select(), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable, myothertable, thirdtable WHERE mytable.myid = myothertable.otherid(+) AND thirdtable.userid(+) = myothertable.otherid", dialect=oracle.dialect(use_ansi=False)) def testbindparam(self): - for stmt, assertion in [ - ( - select( - [table1, table2], - and_(table1.c.myid == table2.c.otherid, - table1.c.name == bindparam('mytablename'))), - "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid AND mytable.name = :mytablename" - ) - ]: - - self.runtest(stmt, assertion) - + for ( + stmt, + expected_named_stmt, + expected_positional_stmt, + expected_default_params_dict, + expected_default_params_list, + test_param_dict, + expected_test_params_dict, + expected_test_params_list + ) in [ + ( + select( + [table1, table2], + and_( + table1.c.myid == table2.c.otherid, + table1.c.name == bindparam('mytablename') + )), + """SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid AND mytable.name = :mytablename""", + """SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid AND mytable.name = ?""", + {'mytablename':None}, [None], + {'mytablename':5}, {'mytablename':5}, [5] + ), + ( + select([table1], or_(table1.c.myid==bindparam('myid'), table2.c.otherid==bindparam('myid'))), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable, myothertable WHERE mytable.myid = :myid OR myothertable.otherid = :myid", + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable, myothertable WHERE mytable.myid = ? OR myothertable.otherid = ?", + {'myid':None}, [None, None], + {'myid':5}, {'myid':5}, [5,5] + ), + ( + text("SELECT mytable.myid, mytable.name, mytable.description FROM mytable, myothertable WHERE mytable.myid = :myid OR myothertable.otherid = :myid"), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable, myothertable WHERE mytable.myid = :myid OR myothertable.otherid = :myid", + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable, myothertable WHERE mytable.myid = ? OR myothertable.otherid = ?", + {'myid':None}, [None, None], + {'myid':5}, {'myid':5}, [5,5] + ), + ( + select([table1], or_(table1.c.myid==bindparam('myid', unique=True), table2.c.otherid==bindparam('myid', unique=True))), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable, myothertable WHERE mytable.myid = :myid OR myothertable.otherid = :my_1", + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable, myothertable WHERE mytable.myid = ? OR myothertable.otherid = ?", + {'myid':None, 'my_1':None}, [None, None], + {'myid':5, 'my_1': 6}, {'myid':5, 'my_1':6}, [5,6] + ), + ( + select([table1], or_(table1.c.myid==bindparam('myid', value=7, unique=True), table2.c.otherid==bindparam('myid', value=8, unique=True))), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable, myothertable WHERE mytable.myid = :myid OR myothertable.otherid = :my_1", + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable, myothertable WHERE mytable.myid = ? OR myothertable.otherid = ?", + {'myid':7, 'my_1':8}, [7,8], + {'myid':5, 'my_1':6}, {'myid':5, 'my_1':6}, [5,6] + ), + ][2:3]: + + self.runtest(stmt, expected_named_stmt, params=expected_default_params_dict) + self.runtest(stmt, expected_positional_stmt, dialect=sqlite.dialect()) + nonpositional = stmt.compile() + positional = stmt.compile(dialect=sqlite.dialect()) + assert positional.get_params().get_raw_list() == expected_default_params_list + assert nonpositional.get_params(**test_param_dict).get_raw_dict() == expected_test_params_dict, "expected :%s got %s" % (str(expected_test_params_dict), str(nonpositional.get_params(**test_param_dict).get_raw_dict())) + assert positional.get_params(**test_param_dict).get_raw_list() == expected_test_params_list + + # check that conflicts with "unique" params are caught + s = select([table1], or_(table1.c.myid==7, table1.c.myid==bindparam('mytable_myid'))) + try: + str(s) + assert False + except exceptions.CompileError, err: + assert str(err) == "Bind parameter 'mytable_myid' conflicts with unique bind parameter of the same name" + + s = select([table1], or_(table1.c.myid==7, table1.c.myid==8, table1.c.myid==bindparam('mytable_my_1'))) + try: + str(s) + assert False + except exceptions.CompileError, err: + assert str(err) == "Bind parameter 'mytable_my_1' conflicts with unique bind parameter of the same name" + # check that the bind params sent along with a compile() call # get preserved when the params are retreived later s = select([table1], table1.c.myid == bindparam('test')) c = s.compile(parameters = {'test' : 7}) self.assert_(c.get_params() == {'test' : 7}) - def testin(self): self.runtest(select([table1], table1.c.myid.in_(1, 2, 3)), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:mytable_myid, :mytable_my_1, :mytable_my_2)")