From: Mike Bayer Date: Sun, 28 Nov 2010 16:10:41 +0000 (-0500) Subject: - Query.get() will raise if the number of params X-Git-Tag: rel_0_6_6~31^2~5 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=52167e1c37a6d704f01dc15bcb14ec2489979457;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - Query.get() will raise if the number of params in a composite key is too large, as well as too small. [ticket:1977] - the above change smoked out an old mistake in a unit test. --- diff --git a/CHANGES b/CHANGES index 8e3d066662..12937f5ee4 100644 --- a/CHANGES +++ b/CHANGES @@ -43,6 +43,10 @@ CHANGES a placeholder flag for forwards compatibility, as it will be needed in 0.7 for composites. [ticket:1976] + + - Query.get() will raise if the number of params + in a composite key is too large, as well as too + small. [ticket:1977] - sql - Fixed operator precedence rules for multiple diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 2bccb8f73f..ef75efd767 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1919,6 +1919,12 @@ class Query(object): q = self._clone() if ident is not None: + if len(ident) != len(mapper.primary_key): + raise sa_exc.InvalidRequestError( + "Incorrect number of values in identifier to formulate " + "primary key for query.get(); primary key columns are %s" % + ','.join("'%s'" % c for c in mapper.primary_key)) + (_get_clause, _get_params) = mapper._get_clause # None present in ident - turn those comparisons @@ -1939,12 +1945,6 @@ class Query(object): for id_val, primary_key in zip(ident, mapper.primary_key) ]) - if len(params) != len(mapper.primary_key): - raise sa_exc.InvalidRequestError( - "Incorrect number of values in identifier to formulate " - "primary key for query.get(); primary key columns are %s" % - ','.join("'%s'" % c for c in mapper.primary_key)) - q._params = params if lockmode is not None: diff --git a/test/orm/inheritance/test_basic.py b/test/orm/inheritance/test_basic.py index c6dec16b7f..cc7dcba052 100644 --- a/test/orm/inheritance/test_basic.py +++ b/test/orm/inheritance/test_basic.py @@ -772,7 +772,7 @@ class DistinctPKTest(_base.MappedTest): mapper(Employee, employee_table, inherits=person_mapper, properties={'pid':person_table.c.id, 'eid':employee_table.c.id}) - self._do_test(True) + self._do_test(False) def test_explicit_composite_pk(self): person_mapper = mapper(Person, person_table) diff --git a/test/orm/test_query.py b/test/orm/test_query.py index d96fa73846..aa95db5e78 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -138,15 +138,31 @@ class GetTest(QueryTest): u2 = s.query(User).get(7) assert u is not u2 - def test_get_composite_pk(self): - s = create_session() + def test_get_composite_pk_no_result(self): + s = Session() assert s.query(CompositePk).get((100,100)) is None + + def test_get_composite_pk_result(self): + s = Session() one_two = s.query(CompositePk).get((1,2)) assert one_two.i == 1 assert one_two.j == 2 assert one_two.k == 3 + + def test_get_too_few_params(self): + s = Session() + q = s.query(CompositePk) + assert_raises(sa_exc.InvalidRequestError, q.get, 7) + + def test_get_too_few_params_tuple(self): + s = Session() + q = s.query(CompositePk) + assert_raises(sa_exc.InvalidRequestError, q.get, (7,)) + + def test_get_too_many_params(self): + s = Session() q = s.query(CompositePk) - assert_raises(sa_exc.InvalidRequestError, q.get, 7) + assert_raises(sa_exc.InvalidRequestError, q.get, (7, 10, 100)) def test_get_null_pk(self): """test that a mapping which can have None in a @@ -214,8 +230,9 @@ class GetTest(QueryTest): @testing.requires.unicode_connections def test_unicode(self): - """test that Query.get properly sets up the type for the bind parameter. using unicode would normally fail - on postgresql, mysql and oracle unless it is converted to an encoded string""" + """test that Query.get properly sets up the type for the bind + parameter. using unicode would normally fail on postgresql, mysql and + oracle unless it is converted to an encoded string""" metadata = MetaData(engines.utf8_engine()) table = Table('unicode_data', metadata,