]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
sqlite detects version and disables CAST if version < 3.2.3
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 11 Jul 2006 00:36:32 +0000 (00:36 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 11 Jul 2006 00:36:32 +0000 (00:36 +0000)
fixes to unittests, mapper extension to work better with setting/unsetting extensions
objectstore objects get 'session' attribute

lib/sqlalchemy/databases/sqlite.py
lib/sqlalchemy/ext/activemapper.py
lib/sqlalchemy/mods/threadlocal.py
lib/sqlalchemy/orm/mapper.py
test/ext/activemapper.py
test/orm/manytomany.py

index d96c0eda1a25d39985c1b9e0624f1adbd1504224..da43ad1ea504b5dcb3428346832fdaf0355b86dc 100644 (file)
@@ -126,6 +126,9 @@ class SQLiteExecutionContext(default.DefaultExecutionContext):
             self._last_inserted_ids = [proxy().lastrowid]
     
 class SQLiteDialect(ansisql.ANSIDialect):
+    def __init__(self, **kwargs):
+        self.supports_cast = (sqlite.sqlite_version >= "3.2.3")
+        ansisql.ANSIDialect.__init__(self, **kwargs)
     def compiler(self, statement, bindparams, **kwargs):
         return SQLiteCompiler(self, statement, bindparams, **kwargs)
     def schemagenerator(self, *args, **kwargs):
@@ -222,6 +225,14 @@ class SQLiteDialect(ansisql.ANSIDialect):
                 table.columns[col]._set_primary_key()
                     
 class SQLiteCompiler(ansisql.ANSICompiler):
+    def visit_cast(self, cast):
+        if self.dialect.supports_cast:
+            super(SQLiteCompiler, self).visit_cast(cast)
+        else:
+            if len(self.select_stack):
+                # not sure if we want to set the typemap here...
+                self.typemap.setdefault("CAST", cast.type)
+            self.strings[cast] = self.strings[cast.clause]
     def limit_clause(self, select):
         text = ""
         if select.limit is not None:
index d21332e3ab3bcf82254394d8088d531286bda43f..957496e8e206fe3703a6ac61dc76c243cf0cf103 100644 (file)
@@ -25,6 +25,7 @@ except AttributeError:
             self.context = SessionContext(*args, **kwargs)
         def __getattr__(self, name):
             return getattr(self.context.current, name)
+        session = property(lambda s:s.context.current)
     objectstore = Objectstore(create_session)
 
 
index 760a37e810e6cc4934e5c5d9b00b82eb6e617223..87d7f271ce88e9c48fb877f01b5571b68af243a2 100644 (file)
@@ -27,7 +27,8 @@ class Objectstore(object):
         self.context = SessionContext(*args, **kwargs)
     def __getattr__(self, name):
         return getattr(self.context.current, name)
-
+    session = property(lambda s:s.context.current)
+    
 def assign_mapper(class_, *args, **kwargs):
     assignmapper.assign_mapper(objectstore.context, class_, *args, **kwargs)
 
index 92d10ad1039460784f39e0805bc37941dec46cf1..beccb16afae2db9b97761f780b81ed46a26cf299 100644 (file)
@@ -201,6 +201,7 @@ class Mapper(object):
         self.extension = None
         previous = None
         for ext in extlist:
+            ext.unchain()
             if self.extension is None:
                 self.extension = ext
             if previous is not None:
@@ -1164,6 +1165,8 @@ class MapperExtension(object):
     def chain(self, ext):
         self.next = ext
         return self
+    def unchain(self):
+        self.next = None
     def get_session(self):
         """called to retrieve a contextual Session instance with which to
         register a new object. Note: this is not called if a session is 
index 2a44f8e5bf8a1533097dcf5ffd0eba2500dbbe40..85466e29b6a750744ee67d00999bad4e9a945b1c 100644 (file)
@@ -61,7 +61,7 @@ class testcase(testbase.PersistTest):
 
         activemapper.metadata.connect(testbase.db)
         activemapper.create_tables()
-
+    
     def tearDownAll(self):
         clear_mappers()
         activemapper.drop_tables()
index 577903d472dd770eb1fa69d5ae099aa0302d3c88..7f1aa7ef7877a69829571f41ab4837d9c6c3a62b 100644 (file)
@@ -83,6 +83,8 @@ class M2MTest(testbase.AssertMixin):
         place_thingy.drop()
         place.drop()
         transition.drop()
+        objectstore.clear()
+        clear_mappers()
         #testbase.db.tables.clear()
         self.uninstall_threadlocal()
         
@@ -234,6 +236,8 @@ class M2MTest2(testbase.AssertMixin):
         enrolTbl.drop()
         studentTbl.drop()
         courseTbl.drop()
+        objectstore.clear()
+        clear_mappers()
         #testbase.db.tables.clear()
         self.uninstall_threadlocal()
         
@@ -311,6 +315,8 @@ class M2MTest3(testbase.AssertMixin):
         c2a1.drop()
         a.drop()
         c.drop()
+        objectstore.clear()
+        clear_mappers()
         #testbase.db.tables.clear()
         self.uninstall_threadlocal()