]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- bindparam() names are now repeatable! specify two
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 3 Mar 2007 21:02:26 +0000 (21:02 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 3 Mar 2007 21:02:26 +0000 (21:02 +0000)
distinct bindparam()s with the same name in a single statement,
and the key will be shared.  proper positional/named args translate
at compile time.  for the old behavior of "aliasing" bind parameters
with conflicting names, specify "unique=True" - this option is
still used internally for all the auto-genererated (value-based)
     bind parameters.

CHANGES
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/exceptions.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/sql.py
test/sql/select.py

diff --git a/CHANGES b/CHANGES
index 6323fcbeeced005358625840ec20fcd1b5730abc..e554003e526c014369ae407e7ddc0ed2e5d4df60 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -1,4 +1,11 @@
 - sql:
+    - bindparam() names are now repeatable!  specify two
+     distinct bindparam()s with the same name in a single statement,
+     and the key will be shared.  proper positional/named args translate
+     at compile time.  for the old behavior of "aliasing" bind parameters
+     with conflicting names, specify "unique=True" - this option is
+     still used internally for all the auto-genererated (value-based) 
+     bind parameters.    
     - exists() becomes useable as a standalone selectable, not just in a 
     WHERE clause
     - correlated subqueries work inside of ORDER BY, GROUP BY
@@ -15,7 +22,7 @@
     'duplicate' columns from the resulting column clause that are known to be 
     equivalent based on the join condition.  this is of great usage when 
     constructing subqueries of joins which Postgres complains about if 
-    duplicate column names are present.
+    duplicate column names are present.    
 - orm:
     - a full select() construct can be passed to query.select() (which
       worked anyway), but also query.selectfirst(), query.selectone() which
index f96bf7abefa093d320fcec9c1fae784cc5cbabcf..19cde38628f3739564fb7b0977520531fb4db63d 100644 (file)
@@ -10,7 +10,7 @@ Contains default implementations for the abstract objects in the sql
 module.
 """
 
-from sqlalchemy import schema, sql, engine, util, sql_util
+from sqlalchemy import schema, sql, engine, util, sql_util, exceptions
 from  sqlalchemy.engine import default
 import string, re, sets, weakref
 
@@ -353,20 +353,27 @@ class ANSICompiler(sql.Compiled):
     def visit_bindparam(self, bindparam):
         if bindparam.shortname != bindparam.key:
             self.binds.setdefault(bindparam.shortname, bindparam)
-        count = 1
-        key = bindparam.key
-
-        # redefine the generated name of the bind param in the case
-        # that we have multiple conflicting bind parameters.
-        while self.binds.setdefault(key, bindparam) is not bindparam:
-            # ensure the name doesn't expand the length of the string
-            # in case we're at the edge of max identifier length
-            tag = "_%d" % count
-            key = bindparam.key[0 : len(bindparam.key) - len(tag)] + tag
-            count += 1
-        bindparam.key = key
-        self.strings[bindparam] = self.bindparam_string(key)
-
+        if bindparam.unique:
+            count = 1
+            key = bindparam.key
+
+            # redefine the generated name of the bind param in the case
+            # that we have multiple conflicting bind parameters.
+            while self.binds.setdefault(key, bindparam) is not bindparam:
+                # ensure the name doesn't expand the length of the string
+                # in case we're at the edge of max identifier length
+                tag = "_%d" % count
+                key = bindparam.key[0 : len(bindparam.key) - len(tag)] + tag
+                count += 1
+            bindparam.key = key
+            self.strings[bindparam] = self.bindparam_string(key)
+        else:
+            existing = self.binds.get(bindparam.key)
+            if existing is not None and existing.unique:
+                raise exceptions.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key)
+            self.strings[bindparam] = self.bindparam_string(bindparam.key)
+            self.binds[bindparam.key] = bindparam
+            
     def bindparam_string(self, name):
         return self.bindtemplate % name
 
@@ -702,7 +709,7 @@ class ANSICompiler(sql.Compiled):
             if parameters.has_key(c):
                 value = parameters[c]
                 if sql._is_literal(value):
-                    value = sql.bindparam(c.key, value, type=c.type)
+                    value = sql.bindparam(c.key, value, type=c.type, unique=True)
                 values.append((c, value))
         return values
 
index e9d7d0c4427014d5307a0d9345de921956b33c53..08908cdb6043a814ab1f0c35d21502d3b4ac965a 100644 (file)
@@ -34,6 +34,11 @@ class ArgumentError(SQLAlchemyError):
 
     pass
 
+class CompileError(SQLAlchemyError):
+    """Raised when an error occurs during SQL compilation"""
+    
+    pass
+    
 class TimeoutError(SQLAlchemyError):
     """Raised when a connection pool times out on getting a connection."""
 
index 55edf0f41a92f81df370fbef1165560b4d6911b8..10fab3ba3cf41bd6689be0d13f8a67569cd39133 100644 (file)
@@ -1151,9 +1151,9 @@ class Mapper(object):
                 mapper = table_to_mapper[table]
                 clause = sql.and_()
                 for col in mapper.pks_by_table[table]:
-                    clause.clauses.append(col == sql.bindparam(col._label, type=col.type))
+                    clause.clauses.append(col == sql.bindparam(col._label, type=col.type, unique=True))
                 if mapper.version_id_col is not None:
-                    clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col._label, type=col.type))
+                    clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col._label, type=col.type, unique=True))
                 statement = table.update(clause)
                 rows = 0
                 supports_sane_rowcount = True
@@ -1277,9 +1277,9 @@ class Mapper(object):
                 delete.sort(comparator)
                 clause = sql.and_()
                 for col in self.pks_by_table[table]:
-                    clause.clauses.append(col == sql.bindparam(col.key, type=col.type))
+                    clause.clauses.append(col == sql.bindparam(col.key, type=col.type, unique=True))
                 if self.version_id_col is not None:
-                    clause.clauses.append(self.version_id_col == sql.bindparam(self.version_id_col.key, type=self.version_id_col.type))
+                    clause.clauses.append(self.version_id_col == sql.bindparam(self.version_id_col.key, type=self.version_id_col.type, unique=True))
                 statement = table.delete(clause)
                 c = connection.execute(statement, delete)
                 if c.supports_sane_rowcount() and c.rowcount != len(delete):
index da1354c242d7631ab0f8c2528227ecff7ef01a4b..8df5628d15b46391205fcdb08a5cdefb76d9b4cd 100644 (file)
@@ -32,7 +32,7 @@ class Query(object):
         if not hasattr(self.mapper, '_get_clause'):
             _get_clause = sql.and_()
             for primary_key in self.primary_key_columns:
-                _get_clause.clauses.append(primary_key == sql.bindparam(primary_key._label, type=primary_key.type))
+                _get_clause.clauses.append(primary_key == sql.bindparam(primary_key._label, type=primary_key.type, unique=True))
             self.mapper._get_clause = _get_clause
         self._get_clause = self.mapper._get_clause
         for opt in util.flatten_iterator(self.with_options):
index 115b53bfd0272bd611f5149659de57970a90a77f..8e19be5367e148366cd9c02dc49ed81f1b37230f 100644 (file)
@@ -281,7 +281,7 @@ class LazyLoader(AbstractRelationLoader):
             if should_bind(leftcol, rightcol):
                 col = leftcol
                 binary.left = binds.setdefault(leftcol,
-                        sql.bindparam(bind_label(), None, shortname=leftcol.name, type=binary.right.type))
+                        sql.bindparam(bind_label(), None, shortname=leftcol.name, type=binary.right.type, unique=True))
                 reverse[rightcol] = binds[col]
 
             # the "left is not right" compare is to handle part of a join clause that is "table.c.col1==table.c.col1",
@@ -289,7 +289,7 @@ class LazyLoader(AbstractRelationLoader):
             if leftcol is not rightcol and should_bind(rightcol, leftcol):
                 col = rightcol
                 binary.right = binds.setdefault(rightcol,
-                        sql.bindparam(bind_label(), None, shortname=rightcol.name, type=binary.left.type))
+                        sql.bindparam(bind_label(), None, shortname=rightcol.name, type=binary.left.type, unique=True))
                 reverse[leftcol] = binds[col]
 
         lazywhere = primaryjoin.copy_container()
index 9c8d5db08d56275516e9aacb2b81ef56e3950071..da1afe7992a19ebb90f372731192b7e6de5e78d1 100644 (file)
@@ -308,7 +308,7 @@ def literal(value, type=None):
     for this literal.
     """
 
-    return _BindParamClause('literal', value, type=type)
+    return _BindParamClause('literal', value, type=type, unique=True)
 
 def label(name, obj):
     """Return a ``_Label`` object for the given selectable, used in
@@ -343,19 +343,30 @@ def table(name, *columns):
 
     return TableClause(name, *columns)
 
-def bindparam(key, value=None, type=None, shortname=None):
+def bindparam(key, value=None, type=None, shortname=None, unique=False):
     """Create a bind parameter clause with the given key.
 
-    An optional default value can be specified by the value parameter,
-    and the optional type parameter is a
-    ``sqlalchemy.types.TypeEngine`` object which indicates
-    bind-parameter and result-set translation for this bind parameter.
+     value
+       a default value for this bind parameter.  a bindparam with a value
+       is called a ``value-based bindparam``.
+     
+     shortname
+        an ``alias`` for this bind parameter.  usually used to alias the ``key`` and 
+       ``label`` of a column, i.e. ``somecolname`` and ``sometable_somecolname``
+       
+     type
+       a sqlalchemy.types.TypeEngine object indicating the type of this bind param, will
+       invoke type-specific bind parameter processing
+     
+     unique
+       if True, bind params sharing the same name will have their underlying ``key`` modified
+       to a uniquely generated name.  mostly useful with value-based bind params.
     """
 
     if isinstance(key, _ColumnClause):
-        return _BindParamClause(key.name, value, type=key.type, shortname=shortname)
+        return _BindParamClause(key.name, value, type=key.type, shortname=shortname, unique=unique)
     else:
-        return _BindParamClause(key, value, type=type, shortname=shortname)
+        return _BindParamClause(key, value, type=type, shortname=shortname, unique=unique)
 
 def text(text, engine=None, *args, **kwargs):
     """Create literal text to be inserted into a query.
@@ -817,7 +828,7 @@ class _CompareMixin(object):
         return self._operate('/', other)
 
     def _bind_param(self, obj):
-        return _BindParamClause('literal', obj, shortname=None, type=self.type)
+        return _BindParamClause('literal', obj, shortname=None, type=self.type, unique=True)
 
     def _check_literal(self, other):
         if _is_literal(other):
@@ -1120,7 +1131,7 @@ class _BindParamClause(ClauseElement, _CompareMixin):
     Public constructor is the ``bindparam()`` function.
     """
 
-    def __init__(self, key, value, shortname=None, type=None):
+    def __init__(self, key, value, shortname=None, type=None, unique=False):
         """Construct a _BindParamClause.
 
         key
@@ -1144,15 +1155,21 @@ class _BindParamClause(ClauseElement, _CompareMixin):
           corresponding ``_BindParamClause`` objects.
 
         type
-
           A ``TypeEngine`` object that will be used to pre-process the
           value corresponding to this ``_BindParamClause`` at
           execution time.
+
+        unique
+          if True, the key name of this BindParamClause will be 
+          modified if another ``_BindParamClause`` of the same
+          name already has been located within the containing 
+          ``ClauseElement``.
         """
 
         self.key = key
         self.value = value
         self.shortname = shortname or key
+        self.unique = unique
         self.type = sqltypes.to_instance(type)
 
     def accept_visitor(self, visitor):
@@ -1162,7 +1179,7 @@ class _BindParamClause(ClauseElement, _CompareMixin):
         return []
 
     def copy_container(self):
-        return _BindParamClause(self.key, self.value, self.shortname, self.type)
+        return _BindParamClause(self.key, self.value, self.shortname, self.type, unique=self.unique)
 
     def typeprocess(self, value, dialect):
         return self.type.dialect_impl(dialect).convert_bind_param(value, dialect)
@@ -1353,7 +1370,7 @@ class _CalculatedClause(ClauseList, ColumnElement):
         visitor.visit_calculatedclause(self)
 
     def _bind_param(self, obj):
-        return _BindParamClause(self.name, obj, type=self.type)
+        return _BindParamClause(self.name, obj, type=self.type, unique=True)
 
     def select(self):
         return select([self])
@@ -1388,7 +1405,7 @@ class _Function(_CalculatedClause, FromClause):
             if clause is None:
                 clause = null()
             else:
-                clause = _BindParamClause(self.name, clause, shortname=self.name, type=None)
+                clause = _BindParamClause(self.name, clause, shortname=self.name, type=None, unique=True)
         self.clauses.append(clause)
 
     def copy_container(self):
@@ -1753,7 +1770,7 @@ class _ColumnClause(ColumnElement):
             return []
 
     def _bind_param(self, obj):
-        return _BindParamClause(self._label, obj, shortname = self.name, type=self.type)
+        return _BindParamClause(self._label, obj, shortname = self.name, type=self.type, unique=True)
 
     def _make_proxy(self, selectable, name = None):
         # propigate the "is_literal" flag only if we are keeping our name,
@@ -2208,7 +2225,7 @@ class _UpdateBase(ClauseElement):
                 else:
                     col = key
                 try:
-                    parameters[key] = bindparam(col, value)
+                    parameters[key] = bindparam(col, value, unique=True)
                 except KeyError:
                     del parameters[key]
         return parameters
index a021bd5b99bfbbba947089ff82eabeab4a2c3dbd..b6f7699597be909225437fed9c5238f9311f2024 100644 (file)
@@ -597,25 +597,87 @@ myothertable.othername != :myothertable_othername AND EXISTS (select yay from fo
         self.runtest(query.select(), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable, myothertable, thirdtable WHERE mytable.myid = myothertable.otherid(+) AND thirdtable.userid(+) = myothertable.otherid", dialect=oracle.dialect(use_ansi=False))    
 
     def testbindparam(self):
-        for stmt, assertion in [
-            (
-                select(
-                    [table1, table2],
-                    and_(table1.c.myid == table2.c.otherid,
-                    table1.c.name == bindparam('mytablename'))),
-                "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid AND mytable.name = :mytablename"
-            )
-       ]:
-
-            self.runtest(stmt, assertion)
-
+        for (
+             stmt,
+             expected_named_stmt,
+             expected_positional_stmt,
+             expected_default_params_dict, 
+             expected_default_params_list,
+             test_param_dict, 
+             expected_test_params_dict,
+             expected_test_params_list
+             ) in [
+              (
+                  select(
+                      [table1, table2],
+                     and_(
+                         table1.c.myid == table2.c.otherid,
+                         table1.c.name == bindparam('mytablename')
+                     )),
+                     """SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid AND mytable.name = :mytablename""",
+                     """SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid AND mytable.name = ?""",
+                 {'mytablename':None}, [None],
+                 {'mytablename':5}, {'mytablename':5}, [5]
+             ),
+             (
+                 select([table1], or_(table1.c.myid==bindparam('myid'), table2.c.otherid==bindparam('myid'))),
+                 "SELECT mytable.myid, mytable.name, mytable.description FROM mytable, myothertable WHERE mytable.myid = :myid OR myothertable.otherid = :myid",
+                 "SELECT mytable.myid, mytable.name, mytable.description FROM mytable, myothertable WHERE mytable.myid = ? OR myothertable.otherid = ?",
+                 {'myid':None}, [None, None],
+                 {'myid':5}, {'myid':5}, [5,5]
+             ),
+             (
+                 text("SELECT mytable.myid, mytable.name, mytable.description FROM mytable, myothertable WHERE mytable.myid = :myid OR myothertable.otherid = :myid"),
+                 "SELECT mytable.myid, mytable.name, mytable.description FROM mytable, myothertable WHERE mytable.myid = :myid OR myothertable.otherid = :myid",
+                 "SELECT mytable.myid, mytable.name, mytable.description FROM mytable, myothertable WHERE mytable.myid = ? OR myothertable.otherid = ?",
+                 {'myid':None}, [None, None],
+                 {'myid':5}, {'myid':5}, [5,5]
+             ),
+             (
+                 select([table1], or_(table1.c.myid==bindparam('myid', unique=True), table2.c.otherid==bindparam('myid', unique=True))),
+                 "SELECT mytable.myid, mytable.name, mytable.description FROM mytable, myothertable WHERE mytable.myid = :myid OR myothertable.otherid = :my_1",
+                 "SELECT mytable.myid, mytable.name, mytable.description FROM mytable, myothertable WHERE mytable.myid = ? OR myothertable.otherid = ?",
+                 {'myid':None, 'my_1':None}, [None, None],
+                 {'myid':5, 'my_1': 6}, {'myid':5, 'my_1':6}, [5,6]
+             ),
+             (
+                 select([table1], or_(table1.c.myid==bindparam('myid', value=7, unique=True), table2.c.otherid==bindparam('myid', value=8, unique=True))),
+                 "SELECT mytable.myid, mytable.name, mytable.description FROM mytable, myothertable WHERE mytable.myid = :myid OR myothertable.otherid = :my_1",
+                 "SELECT mytable.myid, mytable.name, mytable.description FROM mytable, myothertable WHERE mytable.myid = ? OR myothertable.otherid = ?",
+                 {'myid':7, 'my_1':8}, [7,8],
+                 {'myid':5, 'my_1':6}, {'myid':5, 'my_1':6}, [5,6]
+             ),
+             ][2:3]:
+             
+                self.runtest(stmt, expected_named_stmt, params=expected_default_params_dict)
+                self.runtest(stmt, expected_positional_stmt, dialect=sqlite.dialect())
+                nonpositional = stmt.compile()
+                positional = stmt.compile(dialect=sqlite.dialect())
+                assert positional.get_params().get_raw_list() == expected_default_params_list
+                assert nonpositional.get_params(**test_param_dict).get_raw_dict() == expected_test_params_dict, "expected :%s got %s" % (str(expected_test_params_dict), str(nonpositional.get_params(**test_param_dict).get_raw_dict()))
+                assert positional.get_params(**test_param_dict).get_raw_list() == expected_test_params_list
+        
+        # check that conflicts with "unique" params are caught
+        s = select([table1], or_(table1.c.myid==7, table1.c.myid==bindparam('mytable_myid')))
+        try:
+            str(s)
+            assert False
+        except exceptions.CompileError, err:
+            assert str(err) == "Bind parameter 'mytable_myid' conflicts with unique bind parameter of the same name"
+
+        s = select([table1], or_(table1.c.myid==7, table1.c.myid==8, table1.c.myid==bindparam('mytable_my_1')))
+        try:
+            str(s)
+            assert False
+        except exceptions.CompileError, err:
+            assert str(err) == "Bind parameter 'mytable_my_1' conflicts with unique bind parameter of the same name"
+            
         # check that the bind params sent along with a compile() call
         # get preserved when the params are retreived later
         s = select([table1], table1.c.myid == bindparam('test'))
         c = s.compile(parameters = {'test' : 7})
         self.assert_(c.get_params() == {'test' : 7})
 
-
     def testin(self):
         self.runtest(select([table1], table1.c.myid.in_(1, 2, 3)),
         "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:mytable_myid, :mytable_my_1, :mytable_my_2)")