]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- fix to case() construct to propigate the type of the first
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 24 Apr 2007 21:33:07 +0000 (21:33 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 24 Apr 2007 21:33:07 +0000 (21:33 +0000)
WHEN condition as the return type of the case statement
- various unit test tweaks to get oracle working

CHANGES
lib/sqlalchemy/sql.py
test/orm/unitofwork.py
test/sql/query.py
test/sql/quote.py

diff --git a/CHANGES b/CHANGES
index 138277c130fa7b6dc4bf08c39760cae13217460b..d094071c1bb5dd7845020a231bf8073877fb0e2b 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -66,6 +66,8 @@
       for **kwargs compat
     - slight tweak to raw execute() change to also support tuples
       for positional parameters, not just lists [ticket:523]
+    - fix to case() construct to propigate the type of the first
+      WHEN condition as the return type of the case statement
 - orm:
     - fixed critical issue when, after options(eagerload()) is used,
       the mapper would then always apply query "wrapping" behavior
index 1aa9ace9b06ba1ec82ca3b1f2aff9603eae69221..f4db66fdf3de2f9c97c91a5e1ed5be2f3d92cb83 100644 (file)
@@ -331,7 +331,11 @@ def case(whens, value=None, else_=None):
     whenlist = [_CompoundClause(None, 'WHEN', c, 'THEN', r) for (c,r) in whens]
     if not else_ is None:
         whenlist.append(_CompoundClause(None, 'ELSE', else_))
-    cc = _CalculatedClause(None, 'CASE', value, *whenlist + ['END'])
+    if len(whenlist):
+        type = list(whenlist[-1])[-1].type
+    else:
+        type = None
+    cc = _CalculatedClause(None, 'CASE', value, type=type, *whenlist + ['END'])
     for c in cc.clauses:
         c.parens = False
     return cc
@@ -1576,6 +1580,13 @@ class _TextClause(ClauseElement):
             for b in bindparams:
                 self.bindparams[b.key] = b
 
+    def _get_type(self):
+        if self.typemap is not None and len(self.typemap) == 1:
+            return list(self.typemap)[0]
+        else:
+            return None
+    type = property(_get_type)
+
     columns = property(lambda s:[])
 
     def get_children(self, **kwargs):
index a54daa41a8796206d70b269807a65f9d78940855..7657426776c5557567b6e58d60583df4fae725cb 100644 (file)
@@ -217,7 +217,7 @@ class MutableTypesTest(UnitOfWorkTest):
         global metadata, table
         metadata = BoundMetaData(testbase.db)
         table = Table('mutabletest', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, Sequence('mutableidseq', optional=True), primary_key=True),
             Column('data', PickleType),
             Column('value', Unicode(30)))
         table.create()
@@ -1223,7 +1223,7 @@ class ManyToManyTest(UnitOfWorkTest):
     def tearDown(self):
         tables.delete()
         UnitOfWorkTest.tearDown(self)
-        
+
     def testmanytomany(self):
         items = orderitems
 
@@ -1348,7 +1348,7 @@ class ManyToManyTest(UnitOfWorkTest):
                 
         mapper(Keyword, keywords)
         mapper(Item, orderitems, properties = dict(
-                keywords = relation(Keyword, secondary=itemkeywords, lazy=False),
+                keywords = relation(Keyword, secondary=itemkeywords, lazy=False, order_by=keywords.c.name),
             ))
 
         (k1, k2, k3) = (Keyword('keyword 1'), Keyword('keyword 2'), Keyword('keyword 3'))
@@ -1457,7 +1457,7 @@ class ManyToManyTest(UnitOfWorkTest):
         k.keyword_name = 'a keyword'
         ctx.current.flush()
         print m.instance_key(k)
-        id = (k.user_id, k.keyword_id)
+        id = (k.keyword_id, k.user_id)
         ctx.current.clear()
         k = ctx.current.query(KeywordUser).get(id)
         assert k.user_name == 'keyworduser'
index 928112c24d3e4e1c2474775c387b5ce95202f8d9..3ea9ec7ea594367d69a6f5b60928bca8abcdeff7 100644 (file)
@@ -136,16 +136,16 @@ class QueryTest(PersistTest):
         self.users.insert().execute(user_id = 7, user_name = 'jack')
         self.users.insert().execute(user_id = 8, user_name = 'fred')
 
-        u = bindparam('uid')
+        u = bindparam('userid')
         s = self.users.select(or_(self.users.c.user_name==u, self.users.c.user_name==u))
-        r = s.execute(uid='fred').fetchall()
+        r = s.execute(userid='fred').fetchall()
         assert len(r) == 1
     
     def test_bindparam_shortname(self):
         """test the 'shortname' field on BindParamClause."""
         self.users.insert().execute(user_id = 7, user_name = 'jack')
         self.users.insert().execute(user_id = 8, user_name = 'fred')
-        u = bindparam('uid', shortname='someshortname')
+        u = bindparam('userid', shortname='someshortname')
         s = self.users.select(self.users.c.user_name==u)
         r = s.execute(someshortname='fred').fetchall()
         assert len(r) == 1
@@ -223,12 +223,12 @@ class QueryTest(PersistTest):
     def test_keys(self):
         self.users.insert().execute(user_id=1, user_name='foo')
         r = self.users.select().execute().fetchone()
-        self.assertEqual(r.keys(), ['user_id', 'user_name'])
+        self.assertEqual([x.lower() for x in r.keys()], ['user_id', 'user_name'])
 
     def test_items(self):
         self.users.insert().execute(user_id=1, user_name='foo')
         r = self.users.select().execute().fetchone()
-        self.assertEqual(r.items(), [('user_id', 1), ('user_name', 'foo')])
+        self.assertEqual([(x[0].lower(), x[1]) for x in r.items()], [('user_id', 1), ('user_name', 'foo')])
 
     def test_len(self):
         self.users.insert().execute(user_id=1, user_name='foo')
@@ -269,11 +269,11 @@ class QueryTest(PersistTest):
         and that column-level defaults get overridden"""
         meta = BoundMetaData(testbase.db)
         t = Table('t1', meta,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, Sequence('t1idseq', optional=True), primary_key=True),
             Column('value', Integer)
         )
         t2 = Table('t2', meta,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, Sequence('t2idseq', optional=True), primary_key=True),
             Column('value', Integer, default="7"),
             Column('stuff', String(20), onupdate="thisisstuff")
         )
@@ -334,7 +334,7 @@ class QueryTest(PersistTest):
         r = self.users.select(self.users.c.user_id==1).execute().fetchone()
         self.assertEqual(r[0], 1)
         self.assertEqual(r[1], 'foo')
-        self.assertEqual(r.keys(), ['user_id', 'user_name'])
+        self.assertEqual([x.lower() for x in r.keys()], ['user_id', 'user_name'])
         self.assertEqual(r.values(), [1, 'foo'])
         
     def test_column_order_with_text_query(self):
@@ -343,7 +343,7 @@ class QueryTest(PersistTest):
         r = testbase.db.execute('select user_name, user_id from query_users', {}).fetchone()
         self.assertEqual(r[0], 'foo')
         self.assertEqual(r[1], 1)
-        self.assertEqual(r.keys(), ['user_name', 'user_id'])
+        self.assertEqual([x.lower() for x in r.keys()], ['user_name', 'user_id'])
         self.assertEqual(r.values(), ['foo', 1])
     
     @testbase.unsupported('oracle', 'firebird') 
@@ -384,18 +384,18 @@ class CompoundTest(PersistTest):
         global metadata, t1, t2, t3
         metadata = BoundMetaData(testbase.db)
         t1 = Table('t1', metadata, 
-            Column('col1', Integer, primary_key=True),
+            Column('col1', Integer, Sequence('t1pkseq'), primary_key=True),
             Column('col2', String(30)),
             Column('col3', String(40)),
             Column('col4', String(30))
             )
         t2 = Table('t2', metadata, 
-            Column('col1', Integer, primary_key=True),
+            Column('col1', Integer, Sequence('t2pkseq'), primary_key=True),
             Column('col2', String(30)),
             Column('col3', String(40)),
             Column('col4', String(30)))
         t3 = Table('t3', metadata, 
-            Column('col1', Integer, primary_key=True),
+            Column('col1', Integer, Sequence('t3pkseq'), primary_key=True),
             Column('col2', String(30)),
             Column('col3', String(40)),
             Column('col4', String(30)))
@@ -425,7 +425,7 @@ class CompoundTest(PersistTest):
                     select([t1.c.col3, t1.c.col4], t1.c.col2.in_("t1col2r1", "t1col2r2")),
             select([t2.c.col3, t2.c.col4], t2.c.col2.in_("t2col2r2", "t2col2r3"))
         )        
-        u = union(s1, s2)
+        u = union(s1, s2, order_by=[s1.c.col3])
         assert u.execute().fetchall() == [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
         assert u.alias('bar').select().execute().fetchall() == [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
         
@@ -438,7 +438,7 @@ class CompoundTest(PersistTest):
         assert i.execute().fetchall() == [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
         assert i.alias('bar').select().execute().fetchall() == [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
 
-    @testbase.unsupported('mysql')
+    @testbase.unsupported('mysql', 'oracle')
     def test_except_style1(self):
         e = except_(union(
             select([t1.c.col3, t1.c.col4]),
@@ -447,7 +447,7 @@ class CompoundTest(PersistTest):
         parens=True), select([t2.c.col3, t2.c.col4]))
         assert e.alias('bar').select().execute().fetchall() == [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'), ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')]
 
-    @testbase.unsupported('mysql')
+    @testbase.unsupported('mysql', 'oracle')
     def test_except_style2(self):
         e = except_(union(
             select([t1.c.col3, t1.c.col4]),
index bd8b206d683d5defb09cb17d6e499f97c3e96a8e..5259437fc748a03994268f59ce0fbd2b4a12a7db 100644 (file)
@@ -49,7 +49,7 @@ class QuoteTest(PersistTest):
         meta2 = BoundMetaData(testbase.db)
         t2 = Table('WorstCase2', meta2, autoload=True, quote=True)
         assert t2.c.has_key('MixedCase')
-    
+
     def testlabels(self):
         table1.insert().execute({'lowercase':1,'UPPERCASE':2,'MixedCase':3,'a123':4},
                 {'lowercase':2,'UPPERCASE':2,'MixedCase':3,'a123':4},
@@ -77,7 +77,8 @@ class QuoteTest(PersistTest):
         assert lcmetadata.case_sensitive is False
         assert t1.c.UcCol.case_sensitive is False
         assert t2.c.normalcol.case_sensitive is False
-    
+   
+    @testbase.unsupported('oracle') 
     def testlabels(self):
         """test the quoting of labels.