]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
The :meth:`.Connection.connect` and :meth:`.Connection.contextual_connect`
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 14 Nov 2012 06:18:58 +0000 (01:18 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 14 Nov 2012 06:18:58 +0000 (01:18 -0500)
methods now return a "branched" version so that the :meth:`.Connection.close`
method can be called on the returned connection without affecting the
original.   Allows symmetry when using :class:`.Engine` and
:class:`.Connection` objects as context managers.

doc/build/changelog/changelog_08.rst
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/schema.py
test/engine/test_bind.py

index c34aadb8a594e6f870f25170ce3a1df0117b33f7..a5bf6d4aac3a58fc0fd94371fd0d551be80fee56 100644 (file)
@@ -6,6 +6,21 @@
 .. changelog::
     :version: 0.8.0b2
 
+    .. change::
+        :tags: engine, feature
+
+      The :meth:`.Connection.connect` and :meth:`.Connection.contextual_connect`
+      methods now return a "branched" version so that the :meth:`.Connection.close`
+      method can be called on the returned connection without affecting the
+      original.   Allows symmetry when using :class:`.Engine` and
+      :class:`.Connection` objects as context managers::
+
+        with conn.connect() as c: # leaves the Connection open
+          c.execute("...")
+
+        with engine.connect() as c:  # closes the Connection
+          c.execute("...")
+
     .. change::
         :tags: engine
 
index 797158b3b71e14747aa619a3f16a056146257eae..062c20d84c973d9d1eaebf29a35a2351ead22d0e 100644 (file)
@@ -9,13 +9,13 @@
 
 """
 
-
+from __future__ import with_statement
 import sys
-from itertools import chain
 from .. import exc, schema, util, log, interfaces
 from ..sql import expression, util as sql_util
 from .interfaces import Connectable, Compiled
 from .util import _distill_params
+import contextlib
 
 class Connection(Connectable):
     """Provides high-level functionality for a wrapped DB-API connection.
@@ -270,24 +270,34 @@ class Connection(Connectable):
         return self.connection.info
 
     def connect(self):
-        """Returns self.
+        """Returns a branched version of this :class:`.Connection`.
+
+        The :meth:`.Connection.close` method on the returned
+        :class:`.Connection` can be called and this
+        :class:`.Connection` will remain open.
+
+        This method provides usage symmetry with
+        :meth:`.Engine.connect`, including for usage
+        with context managers.
 
-        This ``Connectable`` interface method returns self, allowing
-        Connections to be used interchangeably with Engines in most
-        situations that require a bind.
         """
 
-        return self
+        return self._branch()
 
     def contextual_connect(self, **kwargs):
-        """Returns self.
+        """Returns a branched version of this :class:`.Connection`.
+
+        The :meth:`.Connection.close` method on the returned
+        :class:`.Connection` can be called and this
+        :class:`.Connection` will remain open.
+
+        This method provides usage symmetry with
+        :meth:`.Engine.contextual_connect`, including for usage
+        with context managers.
 
-        This ``Connectable`` interface method returns self, allowing
-        Connections to be used interchangeably with Engines in most
-        situations that require a bind.
         """
 
-        return self
+        return self._branch()
 
     def invalidate(self, exception=None):
         """Invalidate the underlying DBAPI connection associated with
@@ -1459,24 +1469,21 @@ class Engine(Connectable, log.Identified):
 
 
     def _execute_default(self, default):
-        connection = self.contextual_connect()
-        try:
-            return connection._execute_default(default, (), {})
-        finally:
-            connection.close()
+        with self.contextual_connect() as conn:
+            return conn._execute_default(default, (), {})
 
+    @contextlib.contextmanager
+    def _optional_conn_ctx_manager(self, connection=None):
+        if connection is None:
+            with self.contextual_connect() as conn:
+                yield conn
+        else:
+            yield connection
 
     def _run_visitor(self, visitorcallable, element,
                                     connection=None, **kwargs):
-        if connection is None:
-            conn = self.contextual_connect(close_with_result=False)
-        else:
-            conn = connection
-        try:
+        with self._optional_conn_ctx_manager(connection) as conn:
             conn._run_visitor(visitorcallable, element, **kwargs)
-        finally:
-            if connection is None:
-                conn.close()
 
     class _trans_ctx(object):
         def __init__(self, conn, transaction, close_with_result):
@@ -1495,6 +1502,7 @@ class Engine(Connectable, log.Identified):
             if not self.close_with_result:
                 self.conn.close()
 
+
     def begin(self, close_with_result=False):
         """Return a context manager delivering a :class:`.Connection`
         with a :class:`.Transaction` established.
@@ -1575,11 +1583,8 @@ class Engine(Connectable, log.Identified):
 
         """
 
-        conn = self.contextual_connect()
-        try:
+        with self.contextual_connect() as conn:
             return conn.transaction(callable_, *args, **kwargs)
-        finally:
-            conn.close()
 
     def run_callable(self, callable_, *args, **kwargs):
         """Given a callable object or function, execute it, passing
@@ -1594,11 +1599,8 @@ class Engine(Connectable, log.Identified):
         which one is being dealt with.
 
         """
-        conn = self.contextual_connect()
-        try:
+        with self.contextual_connect() as conn:
             return conn.run_callable(callable_, *args, **kwargs)
-        finally:
-            conn.close()
 
     def execute(self, statement, *multiparams, **params):
         """Executes the given construct and returns a :class:`.ResultProxy`.
@@ -1673,17 +1675,10 @@ class Engine(Connectable, log.Identified):
           the ``contextual_connect`` for this ``Engine``.
         """
 
-        if connection is None:
-            conn = self.contextual_connect()
-        else:
-            conn = connection
-        if not schema:
-            schema = self.dialect.default_schema_name
-        try:
+        with self._optional_conn_ctx_manager(connection) as conn:
+            if not schema:
+                schema = self.dialect.default_schema_name
             return self.dialect.get_table_names(conn, schema)
-        finally:
-            if connection is None:
-                conn.close()
 
     def has_table(self, table_name, schema=None):
         return self.run_callable(self.dialect.has_table, table_name, schema)
index ac8be377c2ccaf6dc116fcb15276cd72912b13eb..9aa74217768950ad8c072b55f170f0966d30aa6b 100644 (file)
@@ -27,6 +27,7 @@ Since these objects are part of the SQL expression language, they are usable
 as components in SQL expressions.
 
 """
+from __future__ import with_statement
 import re
 import inspect
 from . import exc, util, dialects, event, events, inspection
@@ -2598,25 +2599,19 @@ class MetaData(SchemaItem):
         if bind is None:
             bind = _bind_or_error(self)
 
-        if bind.engine is not bind:
-            conn = bind
-            close = False
-        else:
-            conn = bind.contextual_connect()
-            close = True
+        with bind.connect() as conn:
 
-        reflect_opts = {
-            'autoload': True,
-            'autoload_with': bind
-        }
+            reflect_opts = {
+                'autoload': True,
+                'autoload_with': conn
+            }
 
-        if schema is None:
-            schema = self.schema
+            if schema is None:
+                schema = self.schema
 
-        if schema is not None:
-            reflect_opts['schema'] = schema
+            if schema is not None:
+                reflect_opts['schema'] = schema
 
-        try:
             available = util.OrderedSet(bind.engine.table_names(schema,
                                                             connection=conn))
             if views:
@@ -2643,9 +2638,6 @@ class MetaData(SchemaItem):
 
             for name in load:
                 Table(name, self, **reflect_opts)
-        finally:
-            if close:
-                conn.close()
 
     def append_ddl_listener(self, event_name, listener):
         """Append a DDL event listener to this ``MetaData``.
index 30ee43b3bb81b745f1ea467d3b351f68b2202d3f..f76350fcca019968303e68e7a71068b7efbd66b8 100644 (file)
@@ -1,6 +1,6 @@
 """tests the "bind" attribute/argument across schema and SQL,
 including the deprecated versions of these arguments"""
-
+from __future__ import with_statement
 from sqlalchemy.testing import eq_, assert_raises
 from sqlalchemy import engine, exc
 from sqlalchemy import MetaData, ThreadLocalMetaData
@@ -12,6 +12,29 @@ from sqlalchemy import testing
 from sqlalchemy.testing import fixtures
 
 class BindTest(fixtures.TestBase):
+    def test_bind_close_engine(self):
+        e = testing.db
+        with e.connect() as conn:
+            assert not conn.closed
+        assert conn.closed
+
+        with e.contextual_connect() as conn:
+            assert not conn.closed
+        assert conn.closed
+
+    def test_bind_close_conn(self):
+        e = testing.db
+        conn = e.connect()
+        with conn.connect() as c2:
+            assert not c2.closed
+        assert not conn.closed
+        assert c2.closed
+
+        with conn.contextual_connect() as c2:
+            assert not c2.closed
+        assert not conn.closed
+        assert c2.closed
+
     def test_create_drop_explicit(self):
         metadata = MetaData()
         table = Table('test_table', metadata,