]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
moved the metadata step of ResultProxy into a ResultMetaData object. this also repla...
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 20 Jan 2010 17:48:43 +0000 (17:48 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 20 Jan 2010 17:48:43 +0000 (17:48 +0000)
Allows RowProxy objects to reference just the metadata they need and provides the "core" of ResultProxy
detached from the object itself, allowing ResultProxy implementations to vary more easily.  will also
enable [ticket:1635]

CHANGES
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/engine/base.py
test/sql/test_query.py

diff --git a/CHANGES b/CHANGES
index 84127d71f87cb0238494e65c4a32c16d4f57c6cb..2c1db03ccbe1b64112577c0bccda9f8642184065 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -333,6 +333,10 @@ CHANGES
   - RowProxy objects are now pickleable, i.e. the object returned
     by result.fetchone(), result.fetchall() etc.
 
+  - RowProxy no longer has a close() method, as the row no longer
+    maintains a reference to the parent.  Call close() on 
+    the parent ResultProxy instead, or use autoclose.
+    
   - ResultProxy internals have been overhauled to greatly reduce
     method call counts when fetching columns that have no 
     type-level processing applied.   Provides a 100% speed
index c27dae26b87ba27f4a10e1cead3087f8e2dda573..e3235783c61d5261e5fdb35e160643f77adcaead 100644 (file)
@@ -1727,6 +1727,11 @@ class MySQLDialect(default.DefaultDialect):
 
         return _DecodingRowProxy(rp.fetchone(), charset)
 
+    def _compat_first(self, rp, charset=None):
+        """Proxy a result row to smooth over MySQL-Python driver inconsistencies."""
+
+        return _DecodingRowProxy(rp.first(), charset)
+
     def _extract_error_code(self, exception):
         raise NotImplementedError()
     
@@ -1975,7 +1980,7 @@ class MySQLDialect(default.DefaultDialect):
         # http://dev.mysql.com/doc/refman/5.0/en/name-case-sensitivity.html
 
         charset = self._connection_charset
-        row = self._compat_fetchone(connection.execute(
+        row = self._compat_first(connection.execute(
             "SHOW VARIABLES LIKE 'lower_case_table_names'"),
                                charset=charset)
         if not row:
@@ -1989,7 +1994,6 @@ class MySQLDialect(default.DefaultDialect):
                 cs = 1
             else:
                 cs = int(row[1])
-            row.close()
         return cs
 
     def _detect_collations(self, connection):
@@ -2011,7 +2015,7 @@ class MySQLDialect(default.DefaultDialect):
     def _detect_ansiquotes(self, connection):
         """Detect and adjust for the ANSI_QUOTES sql mode."""
 
-        row = self._compat_fetchone(
+        row = self._compat_first(
             connection.execute("SHOW VARIABLES LIKE 'sql_mode'"),
                                charset=self._connection_charset)
 
@@ -2036,20 +2040,16 @@ class MySQLDialect(default.DefaultDialect):
 
         rp = None
         try:
-            try:
-                rp = connection.execute(st)
-            except exc.SQLError, e:
-                if self._extract_error_code(e) == 1146:
-                    raise exc.NoSuchTableError(full_name)
-                else:
-                    raise
-            row = self._compat_fetchone(rp, charset=charset)
-            if not row:
+            rp = connection.execute(st)
+        except exc.SQLError, e:
+            if self._extract_error_code(e) == 1146:
                 raise exc.NoSuchTableError(full_name)
-            return row[1].strip()
-        finally:
-            if rp:
-                rp.close()
+            else:
+                raise
+        row = self._compat_first(rp, charset=charset)
+        if not row:
+            raise exc.NoSuchTableError(full_name)
+        return row[1].strip()
 
         return sql
 
index 2210ba40837b6d876ba27f6a8d60411a768059c2..2789ea4d9cdccc82acf80a1dbe769baca7eeb7ea 100644 (file)
@@ -1570,7 +1570,7 @@ def _proxy_connection_cls(cls, proxy):
 
 
 class RowProxy(object):
-    """Proxy a single cursor row for a parent ResultProxy.
+    """Proxy values from a single cursor row.
 
     Mostly follows "ordered dictionary" behavior, mapping result
     values to the string-based column name, the integer position of
@@ -1582,19 +1582,13 @@ class RowProxy(object):
     __slots__ = ['__parent', '__row', '__colfuncs']
 
     def __init__(self, parent, row):
-        """RowProxy objects are constructed by ResultProxy objects."""
 
         self.__parent = parent
         self.__row = row
         self.__colfuncs = parent._colfuncs
         if self.__parent._echo:
-            self.__parent.context.engine.logger.debug("Row %r", row)
+            self.__parent.logger.debug("Row %r", row)
         
-    def close(self):
-        """Close the parent ResultProxy."""
-
-        self.__parent.close()
-
     def __contains__(self, key):
         return self.__parent._has_key(self.__row, key)
 
@@ -1604,7 +1598,7 @@ class RowProxy(object):
     def __getstate__(self):
         return {
             '__row':[self.__colfuncs[i][0](self.__row) for i in xrange(len(self.__row))],
-            '__parent':PickledResultProxy(self.__parent)
+            '__parent':self.__parent
         }
     
     def __setstate__(self, d):
@@ -1680,62 +1674,140 @@ class RowProxy(object):
     def itervalues(self):
         return iter(self)
 
-class PickledResultProxy(object):
-    """a 'mock' ResultProxy used by a RowProxy being pickled."""
-    
-    _echo = False
+class ResultMetaData(object):
+    """Handle cursor.description, applying additional info from an execution context."""
     
-    def __init__(self, resultproxy):
-        self._pickled_colfuncs = \
-                    dict(
-                        (key, (i, type_)) 
-                        for key, (fn, i, type_) in resultproxy._colfuncs.iteritems() 
-                        if isinstance(key, (basestring, int))
-                    )
-        self._keys = resultproxy.keys
-    
-    @util.memoized_property
-    def _colfuncs(self):
-        d = {}
-        for key, (index, type_) in self._pickled_colfuncs.iteritems():
-            if type_ == 'ambiguous':
-                d[key] = (ResultProxy._ambiguous_processor(key), index, type_)
+    def __init__(self, parent, metadata):
+        self._colfuncs = colfuncs = {}
+        self.keys = []
+        self._echo = parent._echo
+        context = parent.context
+        dialect = context.dialect
+        typemap = dialect.dbapi_type_map
+
+        for i, (colname, coltype) in enumerate(m[0:2] for m in metadata):
+            if dialect.description_encoding:
+                colname = colname.decode(dialect.description_encoding)
+
+            if '.' in colname:
+                # sqlite will in some circumstances prepend table name to
+                # colnames, so strip
+                origname = colname
+                colname = colname.split('.')[-1]
             else:
-                d[key] = (operator.itemgetter(index), index, "itemgetter")
-        return d
-        
+                origname = None
+
+            if context.result_map:
+                try:
+                    name, obj, type_ = context.result_map[colname.lower()]
+                except KeyError:
+                    name, obj, type_ = \
+                        colname, None, typemap.get(coltype, types.NULLTYPE)
+            else:
+                name, obj, type_ = (colname, None, typemap.get(coltype, types.NULLTYPE))
+
+            processor = type_.dialect_impl(dialect).\
+                            result_processor(dialect, coltype)
+            
+            if processor:
+                def make_colfunc(processor, index):
+                    def getcol(row):
+                        return processor(row[index])
+                    return getcol
+                rec = (make_colfunc(processor, i), i, "colfunc")
+            else:
+                rec = (operator.itemgetter(i), i, "itemgetter")
+
+            # indexes as keys
+            colfuncs[i] = rec
+            
+            # Column names as keys 
+            if colfuncs.setdefault(name.lower(), rec) is not rec: 
+                #XXX: why not raise directly? because several columns colliding 
+                #by name is not a problem as long as the user don't use them (ie 
+                #use the more precise ColumnElement 
+                colfuncs[name.lower()] = (self._ambiguous_processor(name), i, "ambiguous")
+            
+            # store the "origname" if we truncated (sqlite only)
+            if origname and \
+                    colfuncs.setdefault(origname.lower(), rec) is not rec:
+                colfuncs[name.lower()] = (self._ambiguous_processor(origname), i, "ambiguous")
+            
+            if dialect.requires_name_normalize:
+                colname = dialect.normalize_name(colname)
+                
+            self.keys.append(colname)
+            if obj:
+                for o in obj:
+                    colfuncs[o] = rec
+
+        if self._echo:
+            self.logger = context.engine.logger
+            self.logger.debug(
+                "Col %r", tuple(x[0] for x in metadata))
+
     @util.memoized_property
     def _colfunc_list(self):
         funcs = self._colfuncs
         return [funcs[i][0] for i in xrange(len(self.keys))]
 
     def _key_fallback(self, key):
-        if key in self._colfuncs:
-            return self._colfuncs[key]
-            
+        funcs = self._colfuncs
+
         if isinstance(key, basestring):
             key = key.lower()
-            if key in self._colfuncs:
-                return self._colfuncs[key]
+            if key in funcs:
+                return funcs[key]
 
+        # fallback for targeting a ColumnElement to a textual expression
+        # this is a rare use case which only occurs when matching text()
+        # constructs to ColumnElements
         if isinstance(key, expression.ColumnElement):
-            if key._label and key._label.lower() in self._colfuncs:
-                return self._colfuncs[key._label.lower()]
-            elif hasattr(key, 'name') and key.name.lower() in self._colfuncs:
-                return self._colfuncs[key.name.lower()]
-        
+            if key._label and key._label.lower() in funcs:
+                return funcs[key._label.lower()]
+            elif hasattr(key, 'name') and key.name.lower() in funcs:
+                return funcs[key.name.lower()]
+
         return None
-        
-    def close(self):
-        pass
-        
+
     def _has_key(self, row, key):
-        return self._key_fallback(key) is not None
-        
-    @property
-    def keys(self):
-        return self._keys
-        
+        if key in self._colfuncs:
+            return True
+        else:
+            key = self._key_fallback(key)
+            return key is not None
+
+    @classmethod
+    def _ambiguous_processor(cls, colname):
+        def process(value):
+            raise exc.InvalidRequestError(
+                    "Ambiguous column name '%s' in result set! "
+                    "try 'use_labels' option on select statement." % colname)
+        return process
+    
+    def __len__(self):
+        return len(self.keys)
+
+    def __getstate__(self):
+        return {
+            '_pickled_colfuncs':dict(
+                (key, (i, type_)) 
+                for key, (fn, i, type_) in self._colfuncs.iteritems() 
+                if isinstance(key, (basestring, int))
+            ),
+            'keys':self.keys
+        }
+    
+    def __setstate__(self, state):
+        pickled_colfuncs = state['_pickled_colfuncs']
+        self._colfuncs = d = {}
+        for key, (index, type_) in pickled_colfuncs.iteritems():
+            if type_ == 'ambiguous':
+                d[key] = (self._ambiguous_processor(key), index, type_)
+            else:
+                d[key] = (operator.itemgetter(index), index, "itemgetter")
+        self.keys = state['keys']
+        self._echo = False
         
 class ResultProxy(object):
     """Wraps a DB-API cursor object to provide easier access to row columns.
@@ -1752,11 +1824,10 @@ class ResultProxy(object):
 
       col3 = row[mytable.c.mycol] # access via Column object.
 
-    ResultProxy also contains a map of TypeEngine objects and will
-    invoke the appropriate ``result_processor()`` method before
-    returning columns, as well as the ExecutionContext corresponding
-    to the statement execution.  It provides several methods for which
-    to obtain information from the underlying ExecutionContext.
+    ``ResultProxy`` also handles post-processing of result column
+    data using ``TypeEngine`` objects, which are referenced from 
+    the originating SQL statement that produced this result set.
+
     """
 
     _process_row = RowProxy
@@ -1770,7 +1841,14 @@ class ResultProxy(object):
         self.connection = context.root_connection
         self._echo = context.engine._should_log_info
         self._init_metadata()
-            
+
+    def _init_metadata(self):
+        metadata = self._cursor_description()
+        if metadata is None:
+            self._metadata = None
+        else:
+            self._metadata = ResultMetaData(self, metadata)
+
     @util.memoized_property
     def rowcount(self):
         """Return the 'rowcount' for this result.
@@ -1809,6 +1887,8 @@ class ResultProxy(object):
         return self.cursor.lastrowid
     
     def _cursor_description(self):
+        """May be overridden by subclasses."""
+        
         return self.cursor.description
             
     def _autoclose(self):
@@ -1825,110 +1905,7 @@ class ResultProxy(object):
             self.close() # autoclose
             
         return self
-    
             
-    def _init_metadata(self):
-        self._metadata = metadata = self._cursor_description()
-        if metadata is None:
-            return
-        
-        self._colfuncs = colfuncs = {}
-        self.keys = []
-
-        typemap = self.dialect.dbapi_type_map
-
-        for i, (colname, coltype) in enumerate(m[0:2] for m in metadata):
-            if self.dialect.description_encoding:
-                colname = colname.decode(self.dialect.description_encoding)
-
-            if '.' in colname:
-                # sqlite will in some circumstances prepend table name to
-                # colnames, so strip
-                origname = colname
-                colname = colname.split('.')[-1]
-            else:
-                origname = None
-
-            if self.context.result_map:
-                try:
-                    name, obj, type_ = self.context.result_map[colname.lower()]
-                except KeyError:
-                    name, obj, type_ = \
-                        colname, None, typemap.get(coltype, types.NULLTYPE)
-            else:
-                name, obj, type_ = (colname, None, typemap.get(coltype, types.NULLTYPE))
-
-            processor = type_.dialect_impl(self.dialect).\
-                            result_processor(self.dialect, coltype)
-            
-            if processor:
-                def make_colfunc(processor, index):
-                    def getcol(row):
-                        return processor(row[index])
-                    return getcol
-                rec = (make_colfunc(processor, i), i, "colfunc")
-            else:
-                rec = (operator.itemgetter(i), i, "itemgetter")
-
-            # indexes as keys
-            colfuncs[i] = rec
-            
-            # Column names as keys 
-            if colfuncs.setdefault(name.lower(), rec) is not rec: 
-                #XXX: why not raise directly? because several columns colliding 
-                #by name is not a problem as long as the user don't use them (ie 
-                #use the more precise ColumnElement 
-                colfuncs[name.lower()] = (self._ambiguous_processor(name), i, "ambiguous")
-            
-            # store the "origname" if we truncated (sqlite only)
-            if origname and \
-                    colfuncs.setdefault(origname.lower(), rec) is not rec:
-                colfuncs[name.lower()] = (self._ambiguous_processor(origname), i, "ambiguous")
-            
-            if self.dialect.requires_name_normalize:
-                colname = self.dialect.normalize_name(colname)
-                
-            self.keys.append(colname)
-            if obj:
-                for o in obj:
-                    colfuncs[o] = rec
-
-        if self._echo:
-            self.context.engine.logger.debug(
-                "Col %r", tuple(x[0] for x in metadata))
-
-    @util.memoized_property
-    def _colfunc_list(self):
-        funcs = self._colfuncs
-        return [funcs[i][0] for i in xrange(len(self._metadata))]
-
-    def _key_fallback(self, key):
-        funcs = self._colfuncs
-
-        if isinstance(key, basestring):
-            key = key.lower()
-            if key in funcs:
-                return funcs[key]
-
-        # fallback for targeting a ColumnElement to a textual expression
-        # this is a rare use case which only occurs when matching text()
-        # constructs to ColumnElements
-        if isinstance(key, expression.ColumnElement):
-            if key._label and key._label.lower() in funcs:
-                return funcs[key._label.lower()]
-            elif hasattr(key, 'name') and key.name.lower() in funcs:
-                return funcs[key.name.lower()]
-        
-        return None
-
-    @classmethod
-    def _ambiguous_processor(cls, colname):
-        def process(value):
-            raise exc.InvalidRequestError(
-                    "Ambiguous column name '%s' in result set! "
-                    "try 'use_labels' option on select statement." % colname)
-        return process
-
     def close(self):
         """Close this ResultProxy.
 
@@ -1953,13 +1930,6 @@ class ResultProxy(object):
             if self.connection.should_close_with_result:
                 self.connection.close()
 
-    def _has_key(self, row, key):
-        if key in self._colfuncs:
-            return True
-        else:
-            key = self._key_fallback(key)
-            return key is not None
-
     def __iter__(self):
         while True:
             row = self.fetchone()
@@ -2048,7 +2018,8 @@ class ResultProxy(object):
 
         try:
             process_row = self._process_row
-            l = [process_row(self, row) for row in self._fetchall_impl()]
+            metadata = self._metadata
+            l = [process_row(metadata, row) for row in self._fetchall_impl()]
             self.close()
             return l
         except Exception, e:
@@ -2065,7 +2036,8 @@ class ResultProxy(object):
 
         try:
             process_row = self._process_row
-            l = [process_row(self, row) for row in self._fetchmany_impl(size)]
+            metadata = self._metadata
+            l = [process_row(metadata, row) for row in self._fetchmany_impl(size)]
             if len(l) == 0:
                 self.close()
             return l
@@ -2084,7 +2056,7 @@ class ResultProxy(object):
         try:
             row = self._fetchone_impl()
             if row is not None:
-                return self._process_row(self, row)
+                return self._process_row(self._metadata, row)
             else:
                 self.close()
                 return None
@@ -2106,7 +2078,7 @@ class ResultProxy(object):
 
         try:
             if row is not None:
-                return self._process_row(self, row)
+                return self._process_row(self._metadata, row)
             else:
                 return None
         finally:
@@ -2195,7 +2167,7 @@ class FullyBufferedResultProxy(ResultProxy):
     def _init_metadata(self):
         super(FullyBufferedResultProxy, self)._init_metadata()
         self.__rowbuffer = self._buffer_rows()
-        
+
     def _buffer_rows(self):
         return self.cursor.fetchall()
         
@@ -2240,13 +2212,13 @@ class BufferedColumnResultProxy(ResultProxy):
 
     def _init_metadata(self):
         super(BufferedColumnResultProxy, self)._init_metadata()
-        self._orig_colfuncs = self._colfuncs
-        self._colfuncs = colfuncs = {}
+        self._metadata._orig_colfuncs = self._metadata._colfuncs
+        self._metadata._colfuncs = colfuncs = {}
         # replace the parent's _colfuncs dict, replacing 
         # column processors with straight itemgetters.
         # the original _colfuncs dict is used when each row
         # is constructed.
-        for k, (colfunc, index, type_) in self._orig_colfuncs.iteritems():
+        for k, (colfunc, index, type_) in self._metadata._orig_colfuncs.iteritems():
             if type_ == "colfunc":
                 colfuncs[k] = (operator.itemgetter(index), index, "itemgetter")
             else:
index c711819f96108fecdd045cd88bdfb084978b5b90..953dcab7f6078a2130853e4fd27211e486abd9db 100644 (file)
@@ -717,7 +717,6 @@ class QueryTest(TestBase):
                 self.fail('Should not allow access to private attributes')
             except AttributeError:
                 pass # expected
-            r.close()
         finally:
             shadowed.drop(checkfirst=True)