]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- turned on auto-returning for oracle, some errors
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 8 Aug 2009 15:26:43 +0000 (15:26 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 8 Aug 2009 15:26:43 +0000 (15:26 +0000)
- added make_transient() [ticket:1052]
- ongoing refactor of compiler _get_colparams()  (more to come)

06CHANGES
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/orm/__init__.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/sql/compiler.py
test/orm/test_session.py

index 476867ad08d19169ab742fd45ef4fd4ffa03c025..9ea5ec9f22bbb64a0674e8d419606c1116a6c154 100644 (file)
--- a/06CHANGES
+++ b/06CHANGES
       "load=False".
     - many-to-one relations now fire off a lazyload in fewer cases, including
       in most cases will not fetch the "old" value when a new one is replaced.
-
+    - added "make_transient()" helper function which transforms a persistent/
+      detached instance into a transient one (i.e. deletes the instance_key
+      and removes from any session.) [ticket:1052]
+      
 - sql
     - returning() support is native to insert(), update(), delete().  Implementations
       of varying levels of functionality exist for Postgresql, Firebird, MSSQL and
index 419ccedb1644622aedccb0de4b6892679087b44f..9ba01610157ac35731d737932b2c68d90beb8a21 100644 (file)
@@ -492,11 +492,10 @@ class OracleDialect(default.DefaultDialect):
         self.use_ansi = use_ansi
         self.optimize_limits = optimize_limits
 
-# TODO: implement server_version_info for oracle
-#    def initialize(self, connection):
-#        super(OracleDialect, self).initialize(connection)
-#        self.implicit_returning = self.server_version_info > (10, ) and \
-#                                        self.__dict__.get('implicit_returning', True)
+    def initialize(self, connection):
+        super(OracleDialect, self).initialize(connection)
+        self.implicit_returning = self.server_version_info > (10, ) and \
+                                        self.__dict__.get('implicit_returning', True)
 
     def do_release_savepoint(self, connection, name):
         # Oracle does not support RELEASE SAVEPOINT
index 3c39316da39a3abcd38924460187cd75a76dad3c..eeb2e1fd98b2e48037fdc910a08b33695c595ece 100644 (file)
@@ -52,7 +52,7 @@ from sqlalchemy.orm import strategies
 from sqlalchemy.orm.query import AliasOption, Query
 from sqlalchemy.sql import util as sql_util
 from sqlalchemy.orm.session import Session as _Session
-from sqlalchemy.orm.session import object_session, sessionmaker
+from sqlalchemy.orm.session import object_session, sessionmaker, make_transient
 from sqlalchemy.orm.scoping import ScopedSession
 from sqlalchemy import util as sa_util
 
@@ -86,6 +86,7 @@ __all__ = (
     'join',
     'lazyload',
     'mapper',
+    'make_transient',
     'noload',
     'object_mapper',
     'object_session',
index d3d653de4f6c051e1c158a0813fa465451c7917e..fdf679a4246592ca7f0522c2a036ea017e7ac369 100644 (file)
@@ -1593,6 +1593,22 @@ def _state_for_unknown_persistence_instance(instance):
 
     return state
 
+def make_transient(instance):
+    """Make the given instance 'transient'.
+    
+    This will remove its association with any 
+    session and additionally will remove its "identity key",
+    such that it's as though the object were newly constructed,
+    except retaining its values.
+    
+    """
+    state = attributes.instance_state(instance)
+    s = _state_session(state)
+    if s:
+        s._expunge_state(state)
+    del state.key
+    
+    
 def object_session(instance):
     """Return the ``Session`` to which instance belongs, or None."""
 
index b6d717356816eef2746f4dc6268b9ace19b31b5f..6935e31e5d7f8d6d58ef1f03c7f0c7a536fec2ff 100644 (file)
@@ -746,17 +746,17 @@ class SQLCompiler(engine.Compiled):
 
         return text
 
+    def _create_crud_bind_param(self, col, value):
+        bindparam = sql.bindparam(col.key, value, type_=col.type)
+        self.binds[col.key] = bindparam
+        return self.bindparam_string(self._truncate_bindparam(bindparam))
+        
     def _get_colparams(self, stmt):
         """create a set of tuples representing column/string pairs for use
         in an INSERT or UPDATE statement.
 
         """
 
-        def create_bind_param(col, value):
-            bindparam = sql.bindparam(col.key, value, type_=col.type)
-            self.binds[col.key] = bindparam
-            return self.bindparam_string(self._truncate_bindparam(bindparam))
-
         self.postfetch = []
         self.prefetch = []
         self.returning = []
@@ -764,7 +764,7 @@ class SQLCompiler(engine.Compiled):
         # no parameters in the statement, no parameters in the
         # compiled params - return binds for all columns
         if self.column_keys is None and stmt.parameters is None:
-            return [(c, create_bind_param(c, None)) for c in stmt.table.columns]
+            return [(c, self._create_crud_bind_param(c, None)) for c in stmt.table.columns]
 
         # if we have statement parameters - set defaults in the
         # compiled params
@@ -793,7 +793,7 @@ class SQLCompiler(engine.Compiled):
             if c.key in parameters:
                 value = parameters[c.key]
                 if sql._is_literal(value):
-                    value = create_bind_param(c, value)
+                    value = self._create_crud_bind_param(c, value)
                 else:
                     self.postfetch.append(c)
                     value = self.process(value.self_group())
@@ -819,7 +819,7 @@ class SQLCompiler(engine.Compiled):
                                 values.append((c, self.process(c.default.arg.self_group())))
                                 self.returning.append(c)
                             elif c.default is not None:
-                                values.append((c, create_bind_param(c, None)))
+                                values.append((c, self._create_crud_bind_param(c, None)))
                                 self.prefetch.append(c)
                             else:
                                 self.returning.append(c)
@@ -833,7 +833,7 @@ class SQLCompiler(engine.Compiled):
                                 ) or \
                                 self.dialect.preexecute_autoincrement_sequences:
 
-                                values.append((c, create_bind_param(c, None)))
+                                values.append((c, self._create_crud_bind_param(c, None)))
                                 self.prefetch.append(c)
                                 
                     elif isinstance(c.default, schema.ColumnDefault):
@@ -844,7 +844,7 @@ class SQLCompiler(engine.Compiled):
                                 # dont add primary key column to postfetch
                                 self.postfetch.append(c)
                         else:
-                            values.append((c, create_bind_param(c, None)))
+                            values.append((c, self._create_crud_bind_param(c, None)))
                             self.prefetch.append(c)
                     elif c.server_default is not None:
                         if not c.primary_key:
@@ -861,7 +861,7 @@ class SQLCompiler(engine.Compiled):
                             values.append((c, self.process(c.onupdate.arg.self_group())))
                             self.postfetch.append(c)
                         else:
-                            values.append((c, create_bind_param(c, None)))
+                            values.append((c, self._create_crud_bind_param(c, None)))
                             self.prefetch.append(c)
                     elif c.server_onupdate is not None:
                         self.postfetch.append(c)
index 2d99e20630ac88a729a9dfb2ae4663c2850a2a79..8562366c59da8bad790157a80b05a3cf47a03229 100644 (file)
@@ -2,7 +2,7 @@ from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
 from sqlalchemy.test.util import gc_collect
 import inspect
 import pickle
-from sqlalchemy.orm import create_session, sessionmaker, attributes
+from sqlalchemy.orm import create_session, sessionmaker, attributes, make_transient
 import sqlalchemy as sa
 from sqlalchemy.test import engines, testing, config
 from sqlalchemy import Integer, String, Sequence
@@ -205,6 +205,25 @@ class SessionTest(_fixtures.FixtureTest):
         eq_(bind.connect().execute("select count(1) from users").scalar(), 1)
         sess.close()
 
+    @testing.resolve_artifact_names
+    def test_make_transient(self):
+        mapper(User, users)
+        sess = create_session()
+        sess.add(User(name='test'))
+        sess.flush()
+        
+        u1 = sess.query(User).first()
+        make_transient(u1)
+        assert u1 not in sess
+        sess.add(u1)
+        assert u1 in sess.new
+
+        u1 = sess.query(User).first()
+        sess.expunge(u1)
+        make_transient(u1)
+        sess.add(u1)
+        assert u1 in sess.new
+        
     @testing.resolve_artifact_names
     def test_autoflush_expressions(self):
         """test that an expression which is dependent on object state is