]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Query.get() will raise if the number of params
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 28 Nov 2010 16:10:41 +0000 (11:10 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 28 Nov 2010 16:10:41 +0000 (11:10 -0500)
in a composite key is too large, as well as too
small. [ticket:1977]
- the above change smoked out an old mistake in a unit test.

CHANGES
lib/sqlalchemy/orm/query.py
test/orm/inheritance/test_basic.py
test/orm/test_query.py

diff --git a/CHANGES b/CHANGES
index 8e3d0666629a0286096b7843909c9a8689e1f5d0..12937f5ee4ca8f2c2639026dbd20704c112a1fe3 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -43,6 +43,10 @@ CHANGES
     a placeholder flag for forwards compatibility,
     as it will be needed in 0.7 for composites.
     [ticket:1976]
+
+  - Query.get() will raise if the number of params
+    in a composite key is too large, as well as too 
+    small. [ticket:1977]
     
 - sql
   - Fixed operator precedence rules for multiple
index 2bccb8f73fadc45e6b8c4b16fd3007f482a5b5a8..ef75efd767cefacd8eadb34c6ef343568ea1360b 100644 (file)
@@ -1919,6 +1919,12 @@ class Query(object):
             q = self._clone()
 
         if ident is not None:
+            if len(ident) != len(mapper.primary_key):
+                raise sa_exc.InvalidRequestError(
+                "Incorrect number of values in identifier to formulate "
+                "primary key for query.get(); primary key columns are %s" %
+                ','.join("'%s'" % c for c in mapper.primary_key))
+
             (_get_clause, _get_params) = mapper._get_clause
             
             # None present in ident - turn those comparisons
@@ -1939,12 +1945,6 @@ class Query(object):
                 for id_val, primary_key in zip(ident, mapper.primary_key)
             ])
 
-            if len(params) != len(mapper.primary_key):
-                raise sa_exc.InvalidRequestError(
-                "Incorrect number of values in identifier to formulate "
-                "primary key for query.get(); primary key columns are %s" %
-                ','.join("'%s'" % c for c in mapper.primary_key))
-                        
             q._params = params
 
         if lockmode is not None:
index c6dec16b7f0e5e565ee99295ab157358b8b4b660..cc7dcba052f2fe627f16b85eaa713b3bb8563808 100644 (file)
@@ -772,7 +772,7 @@ class DistinctPKTest(_base.MappedTest):
         mapper(Employee, employee_table, inherits=person_mapper,
                         properties={'pid':person_table.c.id, 
                                     'eid':employee_table.c.id})
-        self._do_test(True)
+        self._do_test(False)
 
     def test_explicit_composite_pk(self):
         person_mapper = mapper(Person, person_table)
index d96fa7384615b8630e567d39aa75873065b30698..aa95db5e780199b09609bf58ab04aa03f8defea3 100644 (file)
@@ -138,15 +138,31 @@ class GetTest(QueryTest):
         u2 = s.query(User).get(7)
         assert u is not u2
 
-    def test_get_composite_pk(self):
-        s = create_session()
+    def test_get_composite_pk_no_result(self):
+        s = Session()
         assert s.query(CompositePk).get((100,100)) is None
+        
+    def test_get_composite_pk_result(self):
+        s = Session()
         one_two = s.query(CompositePk).get((1,2))
         assert one_two.i == 1
         assert one_two.j == 2
         assert one_two.k == 3
+    
+    def test_get_too_few_params(self):
+        s = Session()
+        q = s.query(CompositePk)
+        assert_raises(sa_exc.InvalidRequestError, q.get, 7)
+
+    def test_get_too_few_params_tuple(self):
+        s = Session()
+        q = s.query(CompositePk)
+        assert_raises(sa_exc.InvalidRequestError, q.get, (7,))
+
+    def test_get_too_many_params(self):
+        s = Session()
         q = s.query(CompositePk)
-        assert_raises(sa_exc.InvalidRequestError, q.get, 7)        
+        assert_raises(sa_exc.InvalidRequestError, q.get, (7, 10, 100))
     
     def test_get_null_pk(self):
         """test that a mapping which can have None in a 
@@ -214,8 +230,9 @@ class GetTest(QueryTest):
 
     @testing.requires.unicode_connections
     def test_unicode(self):
-        """test that Query.get properly sets up the type for the bind parameter.  using unicode would normally fail
-        on postgresql, mysql and oracle unless it is converted to an encoded string"""
+        """test that Query.get properly sets up the type for the bind
+        parameter. using unicode would normally fail on postgresql, mysql and
+        oracle unless it is converted to an encoded string"""
 
         metadata = MetaData(engines.utf8_engine())
         table = Table('unicode_data', metadata,