]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- implemented RowProxy.__ne__ [ticket:945], thanks knutroy
authorJason Kirtland <jek@discorporate.us>
Thu, 31 Jan 2008 04:49:31 +0000 (04:49 +0000)
committerJason Kirtland <jek@discorporate.us>
Thu, 31 Jan 2008 04:49:31 +0000 (04:49 +0000)
- test coverage for same

lib/sqlalchemy/engine/base.py
test/sql/query.py

index 61755013b049e246f164c1746d8a6887d131cc1e..733d77a69635a505e02c2942f96dc5a0c4aa7614 100644 (file)
@@ -96,8 +96,8 @@ class Dialect(object):
       This is used to apply types to result sets based on the DB-API
       types present in cursor.description; it only takes effect for
       result sets against textual statements where no explicit
-      typemap was present.  
-      
+      typemap was present.
+
     """
 
     def create_connect_args(self, url):
@@ -317,14 +317,14 @@ class ExecutionContext(object):
 
     should_autocommit
       True if the statement is a "committable" statement
-      
+
     returns_rows
       True if the statement should return result rows
 
     postfetch_cols
      a list of Column objects for which a server-side default
      or inline SQL expression value was fired off.  applies to inserts and updates.
-      
+
     The Dialect should provide an ExecutionContext via the
     create_execution_context() method.  The `pre_exec` and `post_exec`
     methods will be called for compiled statements.
@@ -375,9 +375,9 @@ class ExecutionContext(object):
 
     def should_autocommit_compiled(self, compiled):
         """return True if the given Compiled object refers to a "committable" statement."""
-        
+
         raise NotImplementedError()
-        
+
     def should_autocommit_text(self, statement):
         """Parse the given textual statement and return True if it refers to a "committable" statement"""
 
@@ -411,7 +411,7 @@ class ExecutionContext(object):
         raise NotImplementedError()
 
     def lastrow_has_defaults(self):
-        """Return True if the last INSERT or UPDATE row contained 
+        """Return True if the last INSERT or UPDATE row contained
         inlined or database-side defaults.
         """
 
@@ -436,7 +436,7 @@ class Compiled(object):
 
         dialect
           ``Dialect`` to compile against.
-          
+
         statement
           ``ClauseElement`` to be compiled.
 
@@ -452,12 +452,12 @@ class Compiled(object):
         self.column_keys = column_keys
         self.bind = bind
         self.can_execute = statement.supports_execution()
-    
+
     def compile(self):
         """Produce the internal string representation of this element."""
-        
+
         raise NotImplementedError()
-        
+
     def __str__(self):
         """Return the string text of the generated SQL statement."""
 
@@ -473,7 +473,7 @@ class Compiled(object):
     def construct_params(self, params):
         """Return the bind params for this compiled object.
 
-        `params` is a dict of string/object pairs whos 
+        `params` is a dict of string/object pairs whos
         values will override bind values compiled in
         to the statement.
         """
@@ -545,7 +545,7 @@ class Connection(Connectable):
         self.__invalid = False
 
     def _branch(self):
-        """Return a new Connection which references this Connection's 
+        """Return a new Connection which references this Connection's
         engine and connection; but does not have close_with_result enabled,
         and also whose close() method does nothing.
 
@@ -556,22 +556,22 @@ class Connection(Connectable):
 
     def dialect(self):
         "Dialect used by this Connection."
-        
+
         return self.engine.dialect
     dialect = property(dialect)
-    
+
     def closed(self):
         """return True if this connection is closed."""
-        
+
         return not self.__invalid and '_Connection__connection' not in self.__dict__
     closed = property(closed)
-    
+
     def invalidated(self):
         """return True if this connection was invalidated."""
-        
+
         return self.__invalid
     invalidated = property(invalidated)
-    
+
     def connection(self):
         "The underlying DB-API connection managed by this Connection."
 
@@ -586,7 +586,7 @@ class Connection(Connectable):
                 return self.__connection
             raise exceptions.InvalidRequestError("This Connection is closed")
     connection = property(connection)
-    
+
     def should_close_with_result(self):
         """Indicates if this Connection should be closed when a corresponding
         ResultProxy is closed; this is essentially an auto-release mode.
@@ -628,16 +628,16 @@ class Connection(Connectable):
         The underlying DB-API connection is literally closed (if
         possible), and is discarded.  Its source connection pool will
         typically lazily create a new connection to replace it.
-        
+
         Upon the next usage, this Connection will attempt to reconnect
         to the pool with a new connection.
 
         Transactions in progress remain in an "opened" state (even though
-        the actual transaction is gone); these must be explicitly 
+        the actual transaction is gone); these must be explicitly
         rolled back before a reconnect on this Connection can proceed.  This
         is to prevent applications from accidentally continuing their transactional
         operations in a non-transactional state.
-        
+
         """
 
         if self.__connection.is_valid:
@@ -757,7 +757,7 @@ class Connection(Connectable):
         except Exception, e:
             self._handle_dbapi_exception(e, None, None, None)
             raise
-        
+
     def _savepoint_impl(self, name=None):
         if name is None:
             self.__savepoint_seq += 1
@@ -829,7 +829,7 @@ class Connection(Connectable):
 
     def scalar(self, object, *multiparams, **params):
         """Executes and returns the first column of the first row.
-        
+
         The underlying result/cursor is closed after execution.
         """
 
@@ -859,11 +859,11 @@ class Connection(Connectable):
 
     def __distill_params(self, multiparams, params):
         """given arguments from the calling form *multiparams, **params, return a list
-        of bind parameter structures, usually a list of dictionaries.  
-        
-        in the case of 'raw' execution which accepts positional parameters, 
+        of bind parameter structures, usually a list of dictionaries.
+
+        in the case of 'raw' execution which accepts positional parameters,
         it may be a list of tuples or lists."""
-        
+
         if multiparams is None or len(multiparams) == 0:
             if params:
                 return [params]
@@ -937,7 +937,7 @@ class Connection(Connectable):
             raise exceptions.DBAPIError.instance(statement, parameters, e, connection_invalidated=is_disconnect)
         finally:
             del self._reentrant_error
-        
+
     def __create_execution_context(self, **kwargs):
         try:
             return self.engine.dialect.create_execution_context(connection=self, **kwargs)
@@ -1005,7 +1005,7 @@ class Transaction(object):
         self._connection = connection
         self._parent = parent or self
         self._is_active = True
-    
+
     def connection(self):
         "The Connection object referenced by this Transaction"
         return self._connection
@@ -1061,7 +1061,7 @@ class RootTransaction(Transaction):
     def __init__(self, connection):
         super(RootTransaction, self).__init__(connection, None)
         self._connection._begin_impl()
-    
+
     def _do_rollback(self):
         self._connection._rollback_impl()
 
@@ -1072,7 +1072,7 @@ class NestedTransaction(Transaction):
     def __init__(self, connection, parent):
         super(NestedTransaction, self).__init__(connection, parent)
         self._savepoint = self._connection._savepoint_impl()
-    
+
     def _do_rollback(self):
         self._connection._rollback_to_savepoint_impl(self._savepoint, self._parent)
 
@@ -1085,16 +1085,16 @@ class TwoPhaseTransaction(Transaction):
         self._is_prepared = False
         self.xid = xid
         self._connection._begin_twophase_impl(self.xid)
-    
+
     def prepare(self):
         if not self._parent._is_active:
             raise exceptions.InvalidRequestError("This transaction is inactive")
         self._connection._prepare_twophase_impl(self.xid)
         self._is_prepared = True
-    
+
     def _do_rollback(self):
         self._connection._rollback_twophase_impl(self.xid, self._is_prepared)
-    
+
     def commit(self):
         self._connection._commit_twophase_impl(self.xid, self._is_prepared)
 
@@ -1114,15 +1114,15 @@ class Engine(Connectable):
 
     def name(self):
         "String name of the [sqlalchemy.engine#Dialect] in use by this ``Engine``."
-        
+
         return sys.modules[self.dialect.__module__].descriptor()['name']
     name = property(name)
-    
+
     echo = logging.echo_property()
-    
+
     def __repr__(self):
         return 'Engine(%s)' % str(self.url)
-    
+
     def dispose(self):
         self.pool.dispose()
         self.pool = self.pool.recreate()
@@ -1231,7 +1231,7 @@ class Engine(Connectable):
         """
 
         return Connection(self, self.pool.connect(), close_with_result=close_with_result, **kwargs)
-    
+
     def table_names(self, schema=None, connection=None):
         """Return a list of all table names available in the database.
 
@@ -1308,7 +1308,7 @@ class RowProxy(object):
 
     def __len__(self):
         return len(self.__row)
-        
+
     def __iter__(self):
         for i in xrange(len(self.__row)):
             yield self.__parent._get_col(self.__row, i)
@@ -1318,6 +1318,9 @@ class RowProxy(object):
                 (other == tuple([self.__parent._get_col(self.__row, key)
                                  for key in xrange(len(self.__row))])))
 
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
     def __repr__(self):
         return repr(tuple(self))
 
@@ -1402,11 +1405,11 @@ class ResultProxy(object):
         else:
             return self.context.get_rowcount()
     rowcount = property(rowcount)
-    
+
     def lastrowid(self):
         return self.cursor.lastrowid
     lastrowid = property(lastrowid)
-    
+
     def out_parameters(self):
         return self.context.out_parameters
     out_parameters = property(out_parameters)
@@ -1422,14 +1425,14 @@ class ResultProxy(object):
 
             for i, item in enumerate(metadata):
                 colname = item[0].decode(self.dialect.encoding)
-                
+
                 if '.' in colname:
                     # sqlite will in some circumstances prepend table name to colnames, so strip
                     origname = colname
                     colname = colname.split('.')[-1]
                 else:
                     origname = None
-                    
+
                 if self.context.result_map:
                     try:
                         (name, obj, type_) = self.context.result_map[colname.lower()]
@@ -1442,12 +1445,12 @@ class ResultProxy(object):
 
                 if self.__props.setdefault(name.lower(), rec) is not rec:
                     self.__props[name.lower()] = (type_, self.__ambiguous_processor(name), 0)
-                
+
                 # store the "origname" if we truncated (sqlite only)
                 if origname:
                     if self.__props.setdefault(origname.lower(), rec) is not rec:
                         self.__props[origname.lower()] = (type_, self.__ambiguous_processor(origname), 0)
-                    
+
                 self.__keys.append(colname)
                 self.__props[i] = rec
                 if obj:
@@ -1465,10 +1468,10 @@ class ResultProxy(object):
             """Given a key, which could be a ColumnElement, string, etc.,
             matches it to the appropriate key we got from the result set's
             metadata; then cache it locally for quick re-access."""
-            
+
             if isinstance(key, basestring):
                 key = key.lower()
-            
+
             try:
                 rec = props[key]
             except KeyError:
@@ -1556,14 +1559,14 @@ class ResultProxy(object):
         """
 
         return self.context.lastrow_has_defaults()
-    
+
     def postfetch_cols(self):
         """Return ``postfetch_cols()`` from the underlying ExecutionContext.
 
         See ExecutionContext for details.
         """
         return self.context.postfetch_cols
-        
+
     def supports_sane_rowcount(self):
         """Return ``supports_sane_rowcount`` from the dialect.
 
@@ -1587,7 +1590,7 @@ class ResultProxy(object):
                 return tuple([self._get_col(row, i) for i in xrange(*indices)])
             else:
                 raise
-                
+
         if processor:
             return processor(row[index])
         else:
@@ -1595,10 +1598,10 @@ class ResultProxy(object):
 
     def _fetchone_impl(self):
         return self.cursor.fetchone()
-        
+
     def _fetchmany_impl(self, size=None):
         return self.cursor.fetchmany(size)
-        
+
     def _fetchall_impl(self):
         return self.cursor.fetchall()
 
@@ -1809,7 +1812,7 @@ class DefaultRunner(schema.SchemaVisitor):
         conn = self.context.connection
         c = expression.select([default.arg]).compile(bind=conn)
         return conn._execute_compiled(c).scalar()
-    
+
     def execute_string(self, stmt, params=None):
         """execute a string statement, using the raw cursor,
         and return a scalar result."""
@@ -1818,7 +1821,7 @@ class DefaultRunner(schema.SchemaVisitor):
             stmt = stmt.encode(self.dialect.encoding)
         conn._cursor_execute(self.context.cursor, stmt, params)
         return self.context.cursor.fetchone()[0]
-        
+
     def visit_column_onupdate(self, onupdate):
         if isinstance(onupdate.arg, expression.ClauseElement):
             return self.exec_default_sql(onupdate)
index 784ab040792606fb64977fafeca1874540527f70..19d11a2f1f69f6873e2690b8931fe2c3066b8ac6 100644 (file)
@@ -138,6 +138,20 @@ class QueryTest(PersistTest):
             l.append(row)
         self.assert_(len(l) == 3)
 
+    def test_row_comparison(self):
+        users.insert().execute(user_id = 7, user_name = 'jack')
+        rp = users.select().execute().fetchone()
+
+        self.assert_(rp == rp)
+        self.assert_(not(rp != rp))
+
+        equal = (7, 'jack')
+
+        self.assert_(rp == equal)
+        self.assert_(equal == rp)
+        self.assert_(not (rp != equal))
+        self.assert_(not (equal != equal))
+
     def test_fetchmany(self):
         users.insert().execute(user_id = 7, user_name = 'jack')
         users.insert().execute(user_id = 8, user_name = 'ed')