From: Mike Bayer Date: Tue, 24 Apr 2007 21:33:07 +0000 (+0000) Subject: - fix to case() construct to propigate the type of the first X-Git-Tag: rel_0_3_7~36 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=49f633b7d11db5a36fa99d53c09620c323374569;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - fix to case() construct to propigate the type of the first WHEN condition as the return type of the case statement - various unit test tweaks to get oracle working --- diff --git a/CHANGES b/CHANGES index 138277c130..d094071c1b 100644 --- a/CHANGES +++ b/CHANGES @@ -66,6 +66,8 @@ for **kwargs compat - slight tweak to raw execute() change to also support tuples for positional parameters, not just lists [ticket:523] + - fix to case() construct to propigate the type of the first + WHEN condition as the return type of the case statement - orm: - fixed critical issue when, after options(eagerload()) is used, the mapper would then always apply query "wrapping" behavior diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 1aa9ace9b0..f4db66fdf3 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -331,7 +331,11 @@ def case(whens, value=None, else_=None): whenlist = [_CompoundClause(None, 'WHEN', c, 'THEN', r) for (c,r) in whens] if not else_ is None: whenlist.append(_CompoundClause(None, 'ELSE', else_)) - cc = _CalculatedClause(None, 'CASE', value, *whenlist + ['END']) + if len(whenlist): + type = list(whenlist[-1])[-1].type + else: + type = None + cc = _CalculatedClause(None, 'CASE', value, type=type, *whenlist + ['END']) for c in cc.clauses: c.parens = False return cc @@ -1576,6 +1580,13 @@ class _TextClause(ClauseElement): for b in bindparams: self.bindparams[b.key] = b + def _get_type(self): + if self.typemap is not None and len(self.typemap) == 1: + return list(self.typemap)[0] + else: + return None + type = property(_get_type) + columns = property(lambda s:[]) def get_children(self, **kwargs): diff --git a/test/orm/unitofwork.py b/test/orm/unitofwork.py index a54daa41a8..7657426776 100644 --- a/test/orm/unitofwork.py +++ b/test/orm/unitofwork.py @@ -217,7 +217,7 @@ class MutableTypesTest(UnitOfWorkTest): global metadata, table metadata = BoundMetaData(testbase.db) table = Table('mutabletest', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, Sequence('mutableidseq', optional=True), primary_key=True), Column('data', PickleType), Column('value', Unicode(30))) table.create() @@ -1223,7 +1223,7 @@ class ManyToManyTest(UnitOfWorkTest): def tearDown(self): tables.delete() UnitOfWorkTest.tearDown(self) - + def testmanytomany(self): items = orderitems @@ -1348,7 +1348,7 @@ class ManyToManyTest(UnitOfWorkTest): mapper(Keyword, keywords) mapper(Item, orderitems, properties = dict( - keywords = relation(Keyword, secondary=itemkeywords, lazy=False), + keywords = relation(Keyword, secondary=itemkeywords, lazy=False, order_by=keywords.c.name), )) (k1, k2, k3) = (Keyword('keyword 1'), Keyword('keyword 2'), Keyword('keyword 3')) @@ -1457,7 +1457,7 @@ class ManyToManyTest(UnitOfWorkTest): k.keyword_name = 'a keyword' ctx.current.flush() print m.instance_key(k) - id = (k.user_id, k.keyword_id) + id = (k.keyword_id, k.user_id) ctx.current.clear() k = ctx.current.query(KeywordUser).get(id) assert k.user_name == 'keyworduser' diff --git a/test/sql/query.py b/test/sql/query.py index 928112c24d..3ea9ec7ea5 100644 --- a/test/sql/query.py +++ b/test/sql/query.py @@ -136,16 +136,16 @@ class QueryTest(PersistTest): self.users.insert().execute(user_id = 7, user_name = 'jack') self.users.insert().execute(user_id = 8, user_name = 'fred') - u = bindparam('uid') + u = bindparam('userid') s = self.users.select(or_(self.users.c.user_name==u, self.users.c.user_name==u)) - r = s.execute(uid='fred').fetchall() + r = s.execute(userid='fred').fetchall() assert len(r) == 1 def test_bindparam_shortname(self): """test the 'shortname' field on BindParamClause.""" self.users.insert().execute(user_id = 7, user_name = 'jack') self.users.insert().execute(user_id = 8, user_name = 'fred') - u = bindparam('uid', shortname='someshortname') + u = bindparam('userid', shortname='someshortname') s = self.users.select(self.users.c.user_name==u) r = s.execute(someshortname='fred').fetchall() assert len(r) == 1 @@ -223,12 +223,12 @@ class QueryTest(PersistTest): def test_keys(self): self.users.insert().execute(user_id=1, user_name='foo') r = self.users.select().execute().fetchone() - self.assertEqual(r.keys(), ['user_id', 'user_name']) + self.assertEqual([x.lower() for x in r.keys()], ['user_id', 'user_name']) def test_items(self): self.users.insert().execute(user_id=1, user_name='foo') r = self.users.select().execute().fetchone() - self.assertEqual(r.items(), [('user_id', 1), ('user_name', 'foo')]) + self.assertEqual([(x[0].lower(), x[1]) for x in r.items()], [('user_id', 1), ('user_name', 'foo')]) def test_len(self): self.users.insert().execute(user_id=1, user_name='foo') @@ -269,11 +269,11 @@ class QueryTest(PersistTest): and that column-level defaults get overridden""" meta = BoundMetaData(testbase.db) t = Table('t1', meta, - Column('id', Integer, primary_key=True), + Column('id', Integer, Sequence('t1idseq', optional=True), primary_key=True), Column('value', Integer) ) t2 = Table('t2', meta, - Column('id', Integer, primary_key=True), + Column('id', Integer, Sequence('t2idseq', optional=True), primary_key=True), Column('value', Integer, default="7"), Column('stuff', String(20), onupdate="thisisstuff") ) @@ -334,7 +334,7 @@ class QueryTest(PersistTest): r = self.users.select(self.users.c.user_id==1).execute().fetchone() self.assertEqual(r[0], 1) self.assertEqual(r[1], 'foo') - self.assertEqual(r.keys(), ['user_id', 'user_name']) + self.assertEqual([x.lower() for x in r.keys()], ['user_id', 'user_name']) self.assertEqual(r.values(), [1, 'foo']) def test_column_order_with_text_query(self): @@ -343,7 +343,7 @@ class QueryTest(PersistTest): r = testbase.db.execute('select user_name, user_id from query_users', {}).fetchone() self.assertEqual(r[0], 'foo') self.assertEqual(r[1], 1) - self.assertEqual(r.keys(), ['user_name', 'user_id']) + self.assertEqual([x.lower() for x in r.keys()], ['user_name', 'user_id']) self.assertEqual(r.values(), ['foo', 1]) @testbase.unsupported('oracle', 'firebird') @@ -384,18 +384,18 @@ class CompoundTest(PersistTest): global metadata, t1, t2, t3 metadata = BoundMetaData(testbase.db) t1 = Table('t1', metadata, - Column('col1', Integer, primary_key=True), + Column('col1', Integer, Sequence('t1pkseq'), primary_key=True), Column('col2', String(30)), Column('col3', String(40)), Column('col4', String(30)) ) t2 = Table('t2', metadata, - Column('col1', Integer, primary_key=True), + Column('col1', Integer, Sequence('t2pkseq'), primary_key=True), Column('col2', String(30)), Column('col3', String(40)), Column('col4', String(30))) t3 = Table('t3', metadata, - Column('col1', Integer, primary_key=True), + Column('col1', Integer, Sequence('t3pkseq'), primary_key=True), Column('col2', String(30)), Column('col3', String(40)), Column('col4', String(30))) @@ -425,7 +425,7 @@ class CompoundTest(PersistTest): select([t1.c.col3, t1.c.col4], t1.c.col2.in_("t1col2r1", "t1col2r2")), select([t2.c.col3, t2.c.col4], t2.c.col2.in_("t2col2r2", "t2col2r3")) ) - u = union(s1, s2) + u = union(s1, s2, order_by=[s1.c.col3]) assert u.execute().fetchall() == [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')] assert u.alias('bar').select().execute().fetchall() == [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')] @@ -438,7 +438,7 @@ class CompoundTest(PersistTest): assert i.execute().fetchall() == [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')] assert i.alias('bar').select().execute().fetchall() == [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')] - @testbase.unsupported('mysql') + @testbase.unsupported('mysql', 'oracle') def test_except_style1(self): e = except_(union( select([t1.c.col3, t1.c.col4]), @@ -447,7 +447,7 @@ class CompoundTest(PersistTest): parens=True), select([t2.c.col3, t2.c.col4])) assert e.alias('bar').select().execute().fetchall() == [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'), ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')] - @testbase.unsupported('mysql') + @testbase.unsupported('mysql', 'oracle') def test_except_style2(self): e = except_(union( select([t1.c.col3, t1.c.col4]), diff --git a/test/sql/quote.py b/test/sql/quote.py index bd8b206d68..5259437fc7 100644 --- a/test/sql/quote.py +++ b/test/sql/quote.py @@ -49,7 +49,7 @@ class QuoteTest(PersistTest): meta2 = BoundMetaData(testbase.db) t2 = Table('WorstCase2', meta2, autoload=True, quote=True) assert t2.c.has_key('MixedCase') - + def testlabels(self): table1.insert().execute({'lowercase':1,'UPPERCASE':2,'MixedCase':3,'a123':4}, {'lowercase':2,'UPPERCASE':2,'MixedCase':3,'a123':4}, @@ -77,7 +77,8 @@ class QuoteTest(PersistTest): assert lcmetadata.case_sensitive is False assert t1.c.UcCol.case_sensitive is False assert t2.c.normalcol.case_sensitive is False - + + @testbase.unsupported('oracle') def testlabels(self): """test the quoting of labels.