]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- The visit_pool() method of Dialect is removed, and replaced with
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 15 Mar 2010 17:08:31 +0000 (13:08 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 15 Mar 2010 17:08:31 +0000 (13:08 -0400)
on_connect().  This method returns a callable which receives
the raw DBAPI connection after each one is created.   The callable
is assembled into a first_connect/connect pool listener by the
connection strategy if non-None.   Provides a simpler interface
for dialects.

CHANGES
lib/sqlalchemy/connectors/mxodbc.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/postgresql/psycopg2.py
lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/engine/strategies.py
lib/sqlalchemy/test/engines.py
test/aaa_profiling/test_zoomark.py
test/aaa_profiling/test_zoomark_orm.py

diff --git a/CHANGES b/CHANGES
index 5a8cf7aac89369f2046c074c29d9520e1f062dc9..d2ed4672e11d5b83a1401ee2f20df79960c3f187 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -264,6 +264,12 @@ CHANGES
     within the "name" field of logging messages instead of the default
     hex identifier string.  [ticket:1555]
     
+  - The visit_pool() method of Dialect is removed, and replaced with
+    on_connect().  This method returns a callable which receives
+    the raw DBAPI connection after each one is created.   The callable
+    is assembled into a first_connect/connect pool listener by the 
+    connection strategy if non-None.   Provides a simpler interface 
+    for dialects.
         
 - metadata
   - Added the ability to strip schema information when using
index 49c5a732937721769048640996895d11028093db..a646473fb74060d3f691c37fc2ae940a9bf3a71c 100644 (file)
@@ -26,14 +26,13 @@ class MxODBCConnector(Connector):
             raise ImportError, "Unrecognized platform for mxODBC import"
         return module
 
-    def visit_pool(self, pool):
-        def connect(conn, rec):
+    def on_connect(self):
+        def connect(conn):
             conn.stringformat = self.dbapi.MIXED_STRINGFORMAT
             conn.datetimeformat = self.dbapi.PYDATETIME_DATETIMEFORMAT
             #conn.bindmethod = self.dbapi.BIND_USING_PYTHONTYPE
             #conn.bindmethod = self.dbapi.BIND_USING_SQLTYPE
-
-        pool.add_listener({'connect':connect})
+        return connect
 
     def create_connect_args(self, url):
         """ Return a tuple of *args,**kwargs for creating a connection.
index 7d4cbbbd808b503ff840c74757595c1ee9874fc8..cbd92ccfecddd2195118dd0fee20e29aad4359fa 100644 (file)
@@ -600,21 +600,19 @@ class PGDialect(default.DefaultDialect):
         if not self.supports_native_enum:
             self.colspecs = self.colspecs.copy()
             del self.colspecs[ENUM]
-            
-    def visit_pool(self, pool):
-        if self.isolation_level is not None:
-            class SetIsolationLevel(object):
-                def __init__(self, isolation_level):
-                    self.isolation_level = isolation_level
-
-                def connect(self, conn, rec):
-                    cursor = conn.cursor()
-                    cursor.execute("SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL %s"
-                                   % self.isolation_level)
-                    cursor.execute("COMMIT")
-                    cursor.close()
-            pool.add_listener(SetIsolationLevel(self.isolation_level))
 
+    def on_connect(self):
+        if self.isolation_level is not None:
+            def connect(conn):
+                cursor = conn.cursor()
+                cursor.execute("SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL %s"
+                               % self.isolation_level)
+                cursor.execute("COMMIT")
+                cursor.close()
+            return connect
+        else:
+            return None
+            
     def do_begin_twophase(self, connection, xid):
         self.do_begin(connection.connection)
 
index 48349834d225d5eac75b9599ae135910dd48b55b..c239a3ee061df9dd11a38232969b4a0637452c62 100644 (file)
@@ -179,20 +179,18 @@ class PGDialect_psycopg2(PGDialect):
         psycopg = __import__('psycopg2')
         return psycopg
     
-    _unwrap_connection = None
-    
-    def visit_pool(self, pool):
+    def on_connect(self):
+        base_on_connect = super(PGDialect_psycopg2, self).on_connect()
         if self.dbapi and self.use_native_unicode:
             extensions = __import__('psycopg2.extensions').extensions
-            def connect(conn, rec):
-                if self._unwrap_connection:
-                    conn = self._unwrap_connection(conn)
-                    if conn is None:
-                        return
+            def connect(conn):
                 extensions.register_type(extensions.UNICODE, conn)
-            pool.add_listener({'first_connect': connect, 'connect':connect})
-        super(PGDialect_psycopg2, self).visit_pool(pool)
-        
+                if base_on_connect:
+                    base_on_connect(conn)
+            return connect
+        else:
+            return base_on_connect
+            
     def create_connect_args(self, url):
         opts = url.translate_connect_args(username='user')
         if 'port' in opts:
index 5bcf901519a5bc1a927f35b8d2ec8b2d083a4e9e..dfc09f025799d6750110249e790b6de2c8988b9c 100644 (file)
@@ -360,21 +360,21 @@ class SQLiteDialect(default.DefaultDialect):
         # hypothetical driver ?)
         self.native_datetime = native_datetime
         
-    def visit_pool(self, pool):
+    def on_connect(self):
         if self.isolation_level is not None:
-            class SetIsolationLevel(object):
-                def __init__(self, isolation_level):
-                    if isolation_level == 'READ UNCOMMITTED':
-                        self.isolation_level = 1
-                    else:
-                        self.isolation_level = 0
-
-                def connect(self, conn, rec):
-                    cursor = conn.cursor()
-                    cursor.execute("PRAGMA read_uncommitted = %d" % self.isolation_level)
-                    cursor.close()
-            pool.add_listener(SetIsolationLevel(self.isolation_level))
-
+            if self.isolation_level == 'READ UNCOMMITTED':
+                isolation_level = 1
+            else:
+                isolation_level = 0
+                
+            def connect(conn):
+                cursor = conn.cursor()
+                cursor.execute("PRAGMA read_uncommitted = %d" % isolation_level)
+                cursor.close()
+            return connect
+        else:
+            return None
+    
     def table_names(self, connection, schema):
         if schema is not None:
             qschema = self.identifier_preparer.quote_identifier(schema)
index fa0059130fbbb9c525b59ccf3a90a3cf144f6c56..095f7a960ec9811c85619c9dbc6d1579d4fa82a9 100644 (file)
@@ -169,6 +169,7 @@ class Dialect(object):
         Given a :class:`~sqlalchemy.engine.url.URL` object, returns a tuple
         consisting of a `*args`/`**kwargs` suitable to send directly
         to the dbapi's connect function.
+        
         """
 
         raise NotImplementedError()
@@ -183,6 +184,7 @@ class Dialect(object):
 
         The returned result is cached *per dialect class* so can
         contain no dialect-instance state.
+        
         """
 
         raise NotImplementedError()
@@ -192,6 +194,13 @@ class Dialect(object):
 
         Allows dialects to configure options based on server version info or
         other properties.
+        
+        The connection passed here is a SQLAlchemy Connection object, 
+        with full capabilities.
+        
+        The initalize() method of the base dialect should be called via
+        super().
+        
         """
 
         pass
@@ -204,6 +213,12 @@ class Dialect(object):
         properties from the database.  If include_columns (a list or
         set) is specified, limit the autoload to the given column
         names.
+        
+        The default implementation uses the 
+        :class:`~sqlalchemy.engine.reflection.Inspector` interface to 
+        provide the output, building upon the granular table/column/
+        constraint etc. methods of :class:`Dialect`.
+        
         """
 
         raise NotImplementedError()
@@ -458,8 +473,22 @@ class Dialect(object):
 
         raise NotImplementedError()
 
-    def visit_pool(self, pool):
-        """Executed after a pool is created."""
+    def on_connect(self):
+        """return a callable which sets up a newly created DBAPI connection.
+
+        The callable accepts a single argument "conn" which is the 
+        DBAPI connection itself.  It has no return value.
+        
+        This is used to set dialect-wide per-connection options such as isolation
+        modes, unicode modes, etc.
+
+        If a callable is returned, it will be assembled into a pool listener
+        that receives the direct DBAPI connection, with all wrappers removed.
+
+        If None is returned, no listener will be generated.
+
+        """
+        return None
 
 
 class ExecutionContext(object):
index 7df858a64409c82bedf870c6c4d8bc7c5919d795..ce24a9ae434ccc9b5c375f31dde3f6bd88fd48a2 100644 (file)
@@ -138,6 +138,20 @@ class DefaultDialect(base.Dialect):
    
         self.do_rollback(connection.connection)
  
+    def on_connect(self):
+        """return a callable which sets up a newly created DBAPI connection.
+        
+        This is used to set dialect-wide per-connection options such as isolation
+        modes, unicode modes, etc.
+        
+        If a callable is returned, it will be assembled into a pool listener
+        that receives the direct DBAPI connection, with all wrappers removed.
+        
+        If None is returned, no listener will be generated.
+        
+        """
+        return None
+        
     def _check_unicode_returns(self, connection):
         # Py2K
         if self.supports_unicode_statements:
index 7c434105c30e5928add59db94312071cc86a48ad..7fc39b91a90dac053b0f86714e4776dfdf3a2f72 100644 (file)
@@ -130,8 +130,16 @@ class DefaultEngineStrategy(EngineStrategy):
         engine = engineclass(pool, dialect, u, **engine_args)
 
         if _initialize:
-            dialect.visit_pool(pool)
-
+            do_on_connect = dialect.on_connect()
+            if do_on_connect:
+                def on_connect(conn, rec):
+                    conn = getattr(conn, '_sqla_unwrap', conn)
+                    if conn is None:
+                        return
+                    do_on_connect(conn)
+                    
+                pool.add_listener({'first_connect': on_connect, 'connect':on_connect})
+                    
             def first_connect(conn, rec):
                 c = base.Connection(engine, connection=conn)
                 dialect.initialize(c)
index 58bfe2b3c1b8b3c01bad2866d18d255a6183f18f..0cfd58d20704ba413d2bc1e5f269127c4f0a667d 100644 (file)
@@ -242,7 +242,11 @@ class ReplayableSession(object):
             else:
                 buffer.append(result)
                 return result
-
+        
+        @property
+        def _sqla_unwrap(self):
+            return self._subject
+            
         def __getattribute__(self, key):
             try:
                 return object.__getattribute__(self, key)
@@ -275,7 +279,11 @@ class ReplayableSession(object):
                 return self
             else:
                 return result
-
+        
+        @property
+        def _sqla_unwrap(self):
+            return None
+            
         def __getattribute__(self, key):
             try:
                 return object.__getattribute__(self, key)
@@ -290,10 +298,3 @@ class ReplayableSession(object):
             else:
                 return result
 
-def unwrap_connection(conn):
-    if conn.__class__.__name__ == 'Recorder':
-        return conn._subject
-    elif conn.__class__.__name__ == 'Player':
-        return None
-    else:
-        return conn
index b7b77bd02bfb1fb90c6fdabc61b738d5aaae4ba3..0c090acb7c8ad330db09901eba2473db11eca592 100644 (file)
@@ -37,7 +37,6 @@ class ZooMarkTest(TestBase):
         
         recorder = lambda: dbapi_session.recorder(creator())
         engine = engines.testing_engine(options={'creator':recorder})
-        engine.dialect._unwrap_connection = engines.unwrap_connection
         metadata = MetaData(engine)
         engine.connect()
         
@@ -321,7 +320,6 @@ class ZooMarkTest(TestBase):
 
         player = lambda: dbapi_session.player()
         engine = create_engine('postgresql:///', creator=player)
-        engine.dialect._unwrap_connection = engines.unwrap_connection
         metadata = MetaData(engine)
         engine.connect()
         
index 62a27eca441fb5ab7662eefa02274be4e800ff57..8304c93839fb5dae4e1bbb2cf7ef38ea84e3d7d1 100644 (file)
@@ -36,7 +36,6 @@ class ZooMarkTest(TestBase):
         creator = testing.db.pool._creator
         recorder = lambda: dbapi_session.recorder(creator())
         engine = engines.testing_engine(options={'creator':recorder})
-        engine.dialect._unwrap_connection = engines.unwrap_connection
         metadata = MetaData(engine)
         session = sessionmaker()()
         engine.connect()
@@ -284,7 +283,6 @@ class ZooMarkTest(TestBase):
 
         player = lambda: dbapi_session.player()
         engine = create_engine('postgresql:///', creator=player)
-        engine.dialect._unwrap_connection = engines.unwrap_connection
         metadata = MetaData(engine)
         session = sessionmaker()()
         engine.connect()