]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
added distinct positional dictionary arg to query.params(), fixes [ticket:690]
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 27 Jul 2007 21:10:12 +0000 (21:10 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 27 Jul 2007 21:10:12 +0000 (21:10 +0000)
lib/sqlalchemy/orm/query.py
test/orm/query.py
test/sql/unicode.py

index 4e53270b2dfa9a0bdb4f7c5f5c10b26f1636ca47..9040655b286f2e51c254fb370b7c5866b39c36dc 100644 (file)
@@ -249,10 +249,21 @@ class Query(object):
         q._lockmode = mode
         return q
 
-    def params(self, **kwargs):
-        """add values for bind parameters which may have been specified in filter()."""
+    def params(self, *args, **kwargs):
+        """add values for bind parameters which may have been specified in filter().
+        
+        parameters may be specified using \**kwargs, or optionally a single dictionary
+        as the first positional argument.  The reason for both is that \**kwargs is 
+        convenient, however some parameter dictionaries contain unicode keys in which case
+        \**kwargs cannot be used.
+        """
         
         q = self._clone()
+        if len(args) == 1:
+            d = args[0]
+            kwargs.update(d)
+        elif len(args) > 0:
+            raise exceptions.ArgumentError("params() takes zero or one positional argument, which is a dictionary.")
         q._params = q._params.copy()
         q._params.update(kwargs)
         return q
@@ -714,7 +725,7 @@ class Query(object):
             if lockmode is not None:
                 q = q.with_lockmode(lockmode)
             q = q.filter(self.select_mapper._get_clause)
-            q = q.params(**params)._select_context_options(populate_existing=reload, version_check=(lockmode is not None))
+            q = q.params(params)._select_context_options(populate_existing=reload, version_check=(lockmode is not None))
             return q.first()
         except IndexError:
             return None
@@ -747,7 +758,7 @@ class Query(object):
         if whereclause is not None:
             q = q.filter(whereclause)
         if params is not None:
-            q = q.params(**params)
+            q = q.params(params)
         q = q._legacy_select_kwargs(**kwargs)
         return q._count()
 
@@ -950,7 +961,7 @@ class Query(object):
 
         q = self.filter(whereclause)._legacy_select_kwargs(**kwargs)
         if params is not None:
-            q = q.params(**params)
+            q = q.params(params)
         return list(q)
         
     def _legacy_select_kwargs(self, **kwargs): #pragma: no cover
@@ -1037,7 +1048,7 @@ class Query(object):
     def _select_statement(self, statement, params=None, **kwargs): #pragma: no cover
         q = self.from_statement(statement)
         if params is not None:
-            q = q.params(**params)
+            q = q.params(params)
         q._select_context_options(**kwargs)
         return list(q)
 
index 3783e1fa0c2aa47390c3d9b2f6be4833970e12e5..5d17d7d817e03f5ba4dcbdcb2073a94739cbb802 100644 (file)
@@ -39,6 +39,24 @@ class QueryTest(ORMTest):
         })
         mapper(Keyword, keywords)
 
+class UnicodeSchemaTest(QueryTest):
+    keep_mappers = False
+    
+    def setup_mappers(self):
+        pass
+        
+    def define_tables(self, metadata):
+        super(UnicodeSchemaTest, self).define_tables(metadata)
+        global uni_meta, uni_users
+        uni_meta = MetaData()
+        uni_users = Table(u'users', uni_meta,
+            Column(u'id', Integer, primary_key=True),
+            Column(u'name', String(30), nullable=False))
+            
+    def test_get(self):
+        mapper(User, uni_users)
+        assert User(id=7) == create_session(bind=testbase.db).query(User).get(7)
+        
 class GetTest(QueryTest):
     def test_get(self):
         s = create_session()
index f882c2a5f8a9f2c5671ec2cf437cd5705db27d11..34e3c19f849bc06682035a68a79b60e3f6bb7059 100644 (file)
@@ -101,5 +101,6 @@ class UnicodeSchemaTest(PersistTest):
         assert new_a1.a == a1.a
         assert new_a1.t2s[0].a == b1.a
         
+        
 if __name__ == '__main__':
     testbase.main()