]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- assurances that context.connection is safe to use by column default functions,...
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 31 Jul 2007 17:15:36 +0000 (17:15 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 31 Jul 2007 17:15:36 +0000 (17:15 +0000)
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
test/sql/defaults.py

index 548494ff2cee6ac0681e2e43a454cc5022a5e9ab..a40ed9bdf42ac58438242a1c61d871e52c5dcb50 100644 (file)
@@ -185,9 +185,9 @@ class PGExecutionContext(default.DefaultExecutionContext):
             # use server-side cursors:
             # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html
             ident = "c" + hex(random.randint(0, 65535))[2:]
-            return self.connection.connection.cursor(ident)
+            return self._connection.connection.cursor(ident)
         else:
-            return self.connection.connection.cursor()
+            return self._connection.connection.cursor()
 
     def get_result_proxy(self):
         if self._is_server_side():
index 642eeac627fca16b8892b663bb15bdc19c4b28ba..ff2da1165c39d26acabf4eb1555fd06ea48a85da 100644 (file)
@@ -275,8 +275,14 @@ class ExecutionContext(object):
     ExecutionContext should have these datamembers:
     
         connection
-            Connection object which initiated the call to the
-            dialect to create this ExecutionContext.
+            Connection object which can be freely used by default value generators
+            to execute SQL.  This Connection should reference the same underlying
+            connection/transactional resources of root_connection.
+            
+        root_connection
+            Connection object which is the source of this ExecutionContext.  This
+            Connection may have close_with_result=True set, in which case it can
+            only be used once.
 
         dialect
             dialect which created this ExecutionContext.
@@ -515,12 +521,13 @@ class Connection(Connectable):
     The Connection object is **not** threadsafe.
     """
 
-    def __init__(self, engine, connection=None, close_with_result=False):
+    def __init__(self, engine, connection=None, close_with_result=False, _branch=False):
         self.__engine = engine
         self.__connection = connection or engine.raw_connection()
         self.__transaction = None
         self.__close_with_result = close_with_result
         self.__savepoint_seq = 0
+        self.__branch = _branch
 
     def _get_connection(self):
         try:
@@ -530,9 +537,14 @@ class Connection(Connectable):
 
     def _branch(self):
         """return a new Connection which references this Connection's 
-        engine and connection; but does not have close_with_result enabled."""
+        engine and connection; but does not have close_with_result enabled,
+        and also whose close() method does nothing.
+        
+        This is used to execute "sub" statements within a single execution,
+        usually an INSERT statement.
+        """
         
-        return Connection(self.__engine, self.__connection)
+        return Connection(self.__engine, self.__connection, _branch=True)
         
     engine = property(lambda s:s.__engine, doc="The Engine with which this Connection is associated.")
     dialect = property(lambda s:s.__engine.dialect, doc="Dialect used by this Connection.")
@@ -686,7 +698,8 @@ class Connection(Connectable):
             c = self.__connection
         except AttributeError:
             return
-        self.__connection.close()
+        if not self.__branch:
+            self.__connection.close()
         self.__connection = None
         del self.__connection
 
@@ -757,7 +770,7 @@ class Connection(Connectable):
         else:
             self.__execute(context)
         self._autocommit(context.statement)
-
+        
     def __execute(self, context):
         if context.parameters is None:
             if context.dialect.positional:
@@ -1124,7 +1137,8 @@ class ResultProxy(object):
             self._rowcount = context.get_rowcount()
             self.close()
             
-    connection = property(lambda self:self.context.connection)
+    connection = property(lambda self:self.context.root_connection)
+    
     def _get_rowcount(self):
         if self._rowcount is not None:
             return self._rowcount
@@ -1510,9 +1524,7 @@ class DefaultRunner(schema.SchemaVisitor):
 
     def __init__(self, context):
         self.context = context
-        # branch the connection so it doesnt close after result
-        self.connection = context.connection._branch()
-        
+        self.connection = self.context._connection._branch()
     dialect = property(lambda self:self.context.dialect)
     
     def get_column_default(self, column):
index a2e159639dcd9c4cd96c54960f9639427d0c4a5e..185387177d3441c762dc0b49960f5f677b51ea1b 100644 (file)
@@ -145,7 +145,7 @@ class DefaultDialect(base.Dialect):
 class DefaultExecutionContext(base.ExecutionContext):
     def __init__(self, dialect, connection, compiled=None, statement=None, parameters=None):
         self.dialect = dialect
-        self.connection = connection
+        self._connection = connection
         self.compiled = compiled
         self._postfetch_cols = util.Set()
         
@@ -172,11 +172,15 @@ class DefaultExecutionContext(base.ExecutionContext):
             self.statement = self.statement.encode(self.dialect.encoding)
             
         self.cursor = self.create_cursor()
-        
+    
     engine = property(lambda s:s.connection.engine)
     isinsert = property(lambda s:s.compiled and s.compiled.isinsert)
     isupdate = property(lambda s:s.compiled and s.compiled.isupdate)
     
+    connection = property(lambda s:s._connection._branch())
+    
+    root_connection = property(lambda s:s._connection)
+    
     def __encode_param_keys(self, params):
         """apply string encoding to the keys of dictionary-based bind parameters"""
         if self.dialect.positional or self.dialect.supports_unicode_statements():
@@ -218,7 +222,7 @@ class DefaultExecutionContext(base.ExecutionContext):
         return re.match(r'SELECT', self.statement.lstrip(), re.I) is not None
 
     def create_cursor(self):
-        return self.connection.connection.cursor()
+        return self._connection.connection.cursor()
 
     def pre_execution(self):
         self.pre_exec()
index 6c200232f229708cbf1cfb178e8e6007b52126c8..a3fe8c07a5bc405e887d44499511763d114e9392 100644 (file)
@@ -18,11 +18,20 @@ class DefaultTest(PersistTest):
             x['x'] += 1
             return x['x']
 
-        def mydefault_with_ctx(ctx):
-            return ctx.compiled_parameters['col1'] + 10
-
         def myupdate_with_ctx(ctx):
             return len(ctx.compiled_parameters['col2'])
+        
+        def mydefault_using_connection(ctx):
+            conn = ctx.connection
+            try:
+                if db.engine.name == 'oracle':
+                    return conn.execute("select 12 from dual").scalar()
+                else:
+                    return conn.execute("select 12").scalar()
+            finally:
+                # ensure a "close()" on this connection does nothing,
+                # since its a "branched" connection
+                conn.close()
             
         use_function_defaults = db.engine.name == 'postgres' or db.engine.name == 'oracle'
         is_oracle = db.engine.name == 'oracle'
@@ -76,7 +85,7 @@ class DefaultTest(PersistTest):
             Column('boolcol2', Boolean, default=False),
             
             # python function which uses ExecutionContext
-            Column('col7', Integer, default=mydefault_with_ctx, onupdate=myupdate_with_ctx),
+            Column('col7', Integer, default=mydefault_using_connection, onupdate=myupdate_with_ctx),
             
             # python builtin
             Column('col8', Date, default=datetime.date.today, onupdate=datetime.date.today)
@@ -119,7 +128,7 @@ class DefaultTest(PersistTest):
         print "Currenttime "+ repr(ctexec)
         l = t.select().execute()
         today = datetime.date.today()
-        self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts, ctexec, True, False, 61, today), (52, 'imthedefault', f, ts, ts, ctexec, True, False, 62, today), (53, 'imthedefault', f, ts, ts, ctexec, True, False, 63, today)])
+        self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts, ctexec, True, False, 12, today), (52, 'imthedefault', f, ts, ts, ctexec, True, False, 12, today), (53, 'imthedefault', f, ts, ts, ctexec, True, False, 12, today)])
 
     def testinsertvalues(self):
         t.insert(values={'col3':50}).execute()
@@ -181,7 +190,7 @@ class AutoIncrementTest(PersistTest):
         nonai_table = Table("aitest", meta, 
             Column('id', Integer, autoincrement=False, primary_key=True),
             Column('data', String(20)))
-        nonai_table.create()
+        nonai_table.create(checkfirst=True)
         try:
             try:
                 # postgres will fail on first row, mysql fails on second row
@@ -201,7 +210,7 @@ class AutoIncrementTest(PersistTest):
         table = Table("aitest", meta, 
             Column('id', Integer, primary_key=True),
             Column('data', String(20)))
-        table.create()
+        table.create(checkfirst=True)
         try:
             table.insert().execute(data='row 1')
             table.insert().execute(data='row 2')
@@ -216,7 +225,7 @@ class AutoIncrementTest(PersistTest):
         table = Table("aitest", meta, 
             Column('id', Integer, primary_key=True),
             Column('data', String(20)))
-        table.create()
+        table.create(checkfirst=True)
 
         try:
             # simulate working on a table that doesn't already exist