]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- merged trunk r2653
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 26 May 2007 20:10:06 +0000 (20:10 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 26 May 2007 20:10:06 +0000 (20:10 +0000)
- small orm fixes

16 files changed:
CHANGES
doc/build/content/plugins.txt
lib/sqlalchemy/databases/firebird.py
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/sql.py
lib/sqlalchemy/sql_util.py
test/dialect/mysql.py
test/engine/reflection.py
test/orm/generative.py
test/orm/mapper.py
test/sql/select.py
test/testbase.py

diff --git a/CHANGES b/CHANGES
index 0730a5c27ce0977722c8b35e10764305d28a7815..d6799da84ee435d81c59560bf6ddc3e5a6452846 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -30,6 +30,8 @@
     - removed "no group by's in a select thats part of a UNION"
       restriction [ticket:578]
 - orm
+    - fixed bug in query.instances() that wouldnt handle more than
+      on additional mapper or one additional column.
     - "delete-orphan" no longer implies "delete". ongoing effort to 
       separate the behavior of these two operations.
     - many-to-many relationships properly set the type of bind params
     - the "primary_key" argument to mapper() is propigated to the "polymorphic"
       mapper.  primary key columns in this list get normalized to that of the mapper's 
       local table.
+    - restored logging of "lazy loading clause" under sa.orm.strategies logger,
+      got removed in 0.3.7
+    - improved support for eagerloading of properties off of mappers that are mapped
+      to select() statements; i.e. eagerloader is better at locating the correct
+      selectable with which to attach its LEFT OUTER JOIN.
 - mysql
+    - Nearly all MySQL column types are now supported for declaration and
+      reflection. Added NCHAR, NVARCHAR, VARBINARY, TINYBLOB, LONGBLOB, YEAR
+    - The sqltypes.Binary passthrough now builds a VARBINARY rather than a
+      BINARY if given a length
     - support for column-level CHARACTER SET and COLLATE declarations,
       as well as ASCII, UNICODE, NATIONAL and BINARY shorthand.
 - firebird
     - set max identifier length to 31
+    - supports_sane_rowcount() set to False due to ticket #370.
+      versioned_id_col feature wont work in FB.
 -extensions
     - new association proxy implementation, implementing complete
       proxies to list, dict and set-based relation collections
index 040c703fd968d21b960bd81ce6b7f800e47390e0..b4a0bebae9f68e3517eb6bce46041e66ecc6850e 100644 (file)
@@ -367,8 +367,6 @@ directly:
        keywords = AssociationProxy('keyword_associations', 'keyword')
 
 
-The `association_proxy` function is
-
 ### orderinglist
 
 **Author:** Jason Kirtland
index 0f55c856c19e8fb93a2e5b2322a695e7e92d9daf..58d6d246f86ac9f27cd037ec42cc097dd86474e5 100644 (file)
@@ -137,7 +137,7 @@ class FBDialect(ansisql.ANSIDialect):
         return sqltypes.adapt_type(typeobj, colspecs)
 
     def supports_sane_rowcount(self):
-        return True
+        return False
 
     def compiler(self, statement, bindparams, **kwargs):
         return FBCompiler(self, statement, bindparams, **kwargs)
index 9154d03d80fc66fa96d13f59f65258b3e42a6352..2fd8ca4a57d7399e004bcf7f6bc70311e5536e79 100644 (file)
@@ -4,7 +4,7 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-import sys, StringIO, string, types, re, datetime
+import sys, StringIO, string, types, re, datetime, inspect
 
 from sqlalchemy import sql,engine,schema,ansisql
 from sqlalchemy.engine import default
@@ -84,7 +84,9 @@ class _StringType(object):
         self.national = national
 
     def _extend(self, spec):
-        "Extend a string-type declaration with MySQL specific extensions."
+        """Extend a string-type declaration with standard SQL CHARACTER SET /
+        COLLATE annotations and MySQL specific extensions.
+        """
         
         if self.charset:
             charset = 'CHARACTER SET %s' % self.charset
@@ -109,8 +111,41 @@ class _StringType(object):
         return ' '.join([c for c in (spec, charset, collation)
                          if c is not None])
 
+    def __repr__(self):
+        attributes = inspect.getargspec(self.__init__)[0][1:]
+        attributes.extend(inspect.getargspec(_StringType.__init__)[0][1:])
+        
+        params = {}
+        for attr in attributes:
+            val = getattr(self, attr)
+            if val is not None and val is not False:
+                params[attr] = val
+
+        return "%s(%s)" % (self.__class__.__name__,
+                           ','.join(['%s=%s' % (k, params[k]) for k in params]))
+
 class MSNumeric(sqltypes.Numeric, _NumericType):
+    """MySQL NUMERIC type"""
+    
     def __init__(self, precision = 10, length = 2, **kw):
+        """Construct a NUMERIC.
+
+        precision
+          Total digits in this number.  If length and precision are both
+          None, values are stored to limits allowed by the server.
+
+        length
+          The number of digits after the decimal point.
+
+        unsigned
+          Optional.
+
+        zerofill
+          Optional. If true, values will be stored as strings left-padded with
+          zeros. Note that this does not effect the values returned by the
+          underlying database API, which continue to be numeric.
+        """
+
         _NumericType.__init__(self, **kw)
         sqltypes.Numeric.__init__(self, precision, length)
 
@@ -121,6 +156,29 @@ class MSNumeric(sqltypes.Numeric, _NumericType):
             return self._extend("NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length})
 
 class MSDecimal(MSNumeric):
+    """MySQL DECIMAL type"""
+
+    def __init__(self, precision=10, length=2, **kw):
+        """Construct a DECIMAL.
+
+        precision
+          Total digits in this number.  If length and precision are both None,
+          values are stored to limits allowed by the server.
+
+        length
+          The number of digits after the decimal point.
+
+        unsigned
+          Optional.
+
+        zerofill
+          Optional. If true, values will be stored as strings left-padded with
+          zeros. Note that this does not effect the values returned by the
+          underlying database API, which continue to be numeric.
+        """
+
+        super(MSDecimal, self).__init__(precision, length, **kw)
+    
     def get_col_spec(self):
         if self.precision is None:
             return self._extend("DECIMAL")
@@ -130,19 +188,62 @@ class MSDecimal(MSNumeric):
             return self._extend("DECIMAL(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length})
 
 class MSDouble(MSNumeric):
+    """MySQL DOUBLE type"""
+
     def __init__(self, precision=10, length=2, **kw):
-        if (precision is None and length is not None) or (precision is not None and length is None):
+        """Construct a DOUBLE.
+
+        precision
+          Total digits in this number.  If length and precision are both None,
+          values are stored to limits allowed by the server.
+
+        length
+          The number of digits after the decimal point.
+
+        unsigned
+          Optional.
+
+        zerofill
+          Optional. If true, values will be stored as strings left-padded with
+          zeros. Note that this does not effect the values returned by the
+          underlying database API, which continue to be numeric.
+        """
+
+        if ((precision is None and length is not None) or
+            (precision is not None and length is None)):
             raise exceptions.ArgumentError("You must specify both precision and length or omit both altogether.")
         super(MSDouble, self).__init__(precision, length, **kw)
 
     def get_col_spec(self):
         if self.precision is not None and self.length is not None:
-            return self._extend("DOUBLE(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length})
+            return self._extend("DOUBLE(%(precision)s, %(length)s)" %
+                                {'precision': self.precision,
+                                 'length' : self.length})
         else:
             return self._extend('DOUBLE')
 
 class MSFloat(sqltypes.Float, _NumericType):
+    """MySQL FLOAT type"""
+
     def __init__(self, precision=10, length=None, **kw):
+        """Construct a FLOAT.
+          
+        precision
+          Total digits in this number.  If length and precision are both None,
+          values are stored to limits allowed by the server.
+
+        length
+          The number of digits after the decimal point.
+
+        unsigned
+          Optional.
+
+        zerofill
+          Optional. If true, values will be stored as strings left-padded with
+          zeros. Note that this does not effect the values returned by the
+          underlying database API, which continue to be numeric.
+        """
+
         if length is not None:
             self.length=length
         _NumericType.__init__(self, **kw)
@@ -157,7 +258,23 @@ class MSFloat(sqltypes.Float, _NumericType):
             return self._extend("FLOAT")
 
 class MSInteger(sqltypes.Integer, _NumericType):
+    """MySQL INTEGER type"""
+
     def __init__(self, length=None, **kw):
+        """Construct an INTEGER.
+
+        length
+          Optional, maximum display width for this number.
+
+        unsigned
+          Optional.
+
+        zerofill
+          Optional. If true, values will be stored as strings left-padded with
+          zeros. Note that this does not effect the values returned by the
+          underlying database API, which continue to be numeric.
+        """
+
         self.length = length
         _NumericType.__init__(self, **kw)
         sqltypes.Integer.__init__(self)
@@ -169,6 +286,25 @@ class MSInteger(sqltypes.Integer, _NumericType):
             return self._extend("INTEGER")
 
 class MSBigInteger(MSInteger):
+    """MySQL BIGINTEGER type"""
+
+    def __init__(self, length=None, **kw):
+        """Construct a BIGINTEGER.
+
+        length
+          Optional, maximum display width for this number.
+
+        unsigned
+          Optional.
+
+        zerofill
+          Optional. If true, values will be stored as strings left-padded with
+          zeros. Note that this does not effect the values returned by the
+          underlying database API, which continue to be numeric.
+        """
+
+        super(MSBigInteger, self).__init__(length, **kw)
+
     def get_col_spec(self):
         if self.length is not None:
             return self._extend("BIGINT(%(length)s)" % {'length': self.length})
@@ -176,10 +312,26 @@ class MSBigInteger(MSInteger):
             return self._extend("BIGINT")
 
 class MSSmallInteger(sqltypes.Smallinteger, _NumericType):
+    """MySQL SMALLINTEGER type"""
+
     def __init__(self, length=None, **kw):
+        """Construct a SMALLINTEGER.
+
+        length
+          Optional, maximum display width for this number.
+
+        unsigned
+          Optional.
+
+        zerofill
+          Optional. If true, values will be stored as strings left-padded with
+          zeros. Note that this does not effect the values returned by the
+          underlying database API, which continue to be numeric.
+        """
+
         self.length = length
         _NumericType.__init__(self, **kw)
-        sqltypes.Smallinteger.__init__(self)
+        sqltypes.Smallinteger.__init__(self, length)
 
     def get_col_spec(self):
         if self.length is not None:
@@ -188,14 +340,20 @@ class MSSmallInteger(sqltypes.Smallinteger, _NumericType):
             return self._extend("SMALLINT")
 
 class MSDateTime(sqltypes.DateTime):
+    """MySQL DATETIME type"""
+
     def get_col_spec(self):
         return "DATETIME"
 
 class MSDate(sqltypes.Date):
+    """MySQL DATE type"""
+
     def get_col_spec(self):
         return "DATE"
 
 class MSTime(sqltypes.Time):
+    """MySQL TIME type"""
+
     def get_col_spec(self):
         return "TIME"
 
@@ -207,49 +365,377 @@ class MSTime(sqltypes.Time):
             return None
 
 class MSTimeStamp(sqltypes.TIMESTAMP):
+    """MySQL TIMESTAMP type
+
+    To signal the orm to automatically re-select modified rows to retrieve
+    the timestamp, add a PassiveDefault to your column specification:
+
+        from sqlalchemy.databases import mysql
+        Column('updated', mysql.MSTimeStamp, PassiveDefault(text('CURRENT_TIMESTAMP()')))
+    """
+
     def get_col_spec(self):
         return "TIMESTAMP"
 
-class MSText(sqltypes.TEXT, _StringType):
-    def __init__(self, **kwargs):
+class MSYear(sqltypes.String):
+    """MySQL YEAR type, for single byte storage of years 1901-2155"""
+
+    def get_col_spec(self):
+        if self.length is None:
+            return "YEAR"
+        else:
+            return "YEAR(%d)" % self.length
+
+class MSText(_StringType, sqltypes.TEXT):
+    """MySQL TEXT type, for text up to 2^16 characters""" 
+    
+    def __init__(self, length=None, **kwargs):
+        """Construct a TEXT.
+        
+        length
+          Optional, if provided the server may optimize storage by
+          subsitituting the smallest TEXT type sufficient to store
+          ``length`` characters.
+
+        charset
+          Optional, a column-level character set for this string
+          value.  Takes precendence to 'ascii' or 'unicode' short-hand.
+
+        collation
+          Optional, a column-level collation for this string value.
+          Takes precedence to 'binary' short-hand.
+
+        ascii
+          Defaults to False: short-hand for the ``latin1`` character set,
+          generates ASCII in schema.
+
+        unicode
+          Defaults to False: short-hand for the ``ucs2`` character set,
+          generates UNICODE in schema.
+
+        national
+          Optional. If true, use the server's configured national
+          character set.
+
+        binary
+          Defaults to False: short-hand, pick the binary collation type
+          that matches the column's character set.  Generates BINARY in
+          schema.  This does not affect the type of data stored, only the
+          collation of character data.
+        """
+
         _StringType.__init__(self, **kwargs)
-        sqltypes.TEXT.__init__(self)
+        sqltypes.TEXT.__init__(self, length)
 
     def get_col_spec(self):
-        return self._extend("TEXT")
+        if self.length:
+            return self._extend("TEXT(%d)" % self.length)
+        else:
+            return self._extend("TEXT")
+            
 
 class MSTinyText(MSText):
+    """MySQL TINYTEXT type, for text up to 2^8 characters""" 
+
+    def __init__(self, **kwargs):
+        """Construct a TINYTEXT.
+        
+        charset
+          Optional, a column-level character set for this string
+          value.  Takes precendence to 'ascii' or 'unicode' short-hand.
+
+        collation
+          Optional, a column-level collation for this string value.
+          Takes precedence to 'binary' short-hand.
+
+        ascii
+          Defaults to False: short-hand for the ``latin1`` character set,
+          generates ASCII in schema.
+
+        unicode
+          Defaults to False: short-hand for the ``ucs2`` character set,
+          generates UNICODE in schema.
+
+        national
+          Optional. If true, use the server's configured national
+          character set.
+
+        binary
+          Defaults to False: short-hand, pick the binary collation type
+          that matches the column's character set.  Generates BINARY in
+          schema.  This does not affect the type of data stored, only the
+          collation of character data.
+        """
+
+        super(MSTinyText, self).__init__(**kwargs)
+
     def get_col_spec(self):
         return self._extend("TINYTEXT")
 
 class MSMediumText(MSText):
+    """MySQL MEDIUMTEXT type, for text up to 2^24 characters""" 
+
+    def __init__(self, **kwargs):
+        """Construct a MEDIUMTEXT.
+        
+        charset
+          Optional, a column-level character set for this string
+          value.  Takes precendence to 'ascii' or 'unicode' short-hand.
+
+        collation
+          Optional, a column-level collation for this string value.
+          Takes precedence to 'binary' short-hand.
+
+        ascii
+          Defaults to False: short-hand for the ``latin1`` character set,
+          generates ASCII in schema.
+
+        unicode
+          Defaults to False: short-hand for the ``ucs2`` character set,
+          generates UNICODE in schema.
+
+        national
+          Optional. If true, use the server's configured national
+          character set.
+
+        binary
+          Defaults to False: short-hand, pick the binary collation type
+          that matches the column's character set.  Generates BINARY in
+          schema.  This does not affect the type of data stored, only the
+          collation of character data.
+        """
+
+        super(MSMediumText, self).__init__(**kwargs)
+
     def get_col_spec(self):
         return self._extend("MEDIUMTEXT")
 
 class MSLongText(MSText):
+    """MySQL LONGTEXT type, for text up to 2^32 characters""" 
+
+    def __init__(self, **kwargs):
+        """Construct a LONGTEXT.
+        
+        charset
+          Optional, a column-level character set for this string
+          value.  Takes precendence to 'ascii' or 'unicode' short-hand.
+
+        collation
+          Optional, a column-level collation for this string value.
+          Takes precedence to 'binary' short-hand.
+
+        ascii
+          Defaults to False: short-hand for the ``latin1`` character set,
+          generates ASCII in schema.
+
+        unicode
+          Defaults to False: short-hand for the ``ucs2`` character set,
+          generates UNICODE in schema.
+
+        national
+          Optional. If true, use the server's configured national
+          character set.
+
+        binary
+          Defaults to False: short-hand, pick the binary collation type
+          that matches the column's character set.  Generates BINARY in
+          schema.  This does not affect the type of data stored, only the
+          collation of character data.
+        """
+
+        super(MSLongText, self).__init__(**kwargs)
+
     def get_col_spec(self):
         return self._extend("LONGTEXT")
 
-class MSString(sqltypes.String, _StringType):
-    def __init__(self, length, national=False, **kwargs):
-        _StringType.__init__(self, national=national, **kwargs)
-        sqltypes.String.__init__(self, length, kwargs.get('convert_unicode', False))
+class MSString(_StringType, sqltypes.String):
+    """MySQL VARCHAR type, for variable-length character data."""
+
+    def __init__(self, length=None, **kwargs):
+        """Construct a VARCHAR.
+        
+        length
+          Maximum data length, in characters.
+
+        charset
+          Optional, a column-level character set for this string
+          value.  Takes precendence to 'ascii' or 'unicode' short-hand.
+
+        collation
+          Optional, a column-level collation for this string value.
+          Takes precedence to 'binary' short-hand.
+
+        ascii
+          Defaults to False: short-hand for the ``latin1`` character set,
+          generates ASCII in schema.
+
+        unicode
+          Defaults to False: short-hand for the ``ucs2`` character set,
+          generates UNICODE in schema.
+
+        national
+          Optional. If true, use the server's configured national
+          character set.
+
+        binary
+          Defaults to False: short-hand, pick the binary collation type
+          that matches the column's character set.  Generates BINARY in
+          schema.  This does not affect the type of data stored, only the
+          collation of character data.
+        """
+
+        _StringType.__init__(self, **kwargs)
+        sqltypes.String.__init__(self, length,
+                                 kwargs.get('convert_unicode', False))
 
     def get_col_spec(self):
-        return self._extend("VARCHAR(%(length)s)" % {'length' : self.length})
+        if self.length:
+            return self._extend("VARCHAR(%d)" % self.length)
+        else:
+            return self._extend("TEXT")
+
+class MSChar(_StringType, sqltypes.CHAR):
+    """MySQL CHAR type, for fixed-length character data."""
+    
+    def __init__(self, length, **kwargs):
+        """Construct an NCHAR.
+        
+        length
+          Maximum data length, in characters.
+
+        binary
+          Optional, use the default binary collation for the national character
+          set.  This does not affect the type of data stored, use a BINARY
+          type for binary data.
 
-class MSChar(sqltypes.CHAR, _StringType):
-    def __init__(self, length, national=False, **kwargs):
-        _StringType.__init__(self, national=national, **kwargs)
-        sqltypes.CHAR.__init__(self, length, kwargs.get('convert_unicode', False))
+        collation
+          Optional, request a particular collation.  Must be compatibile
+          with the national character set.
+        """
+        _StringType.__init__(self, **kwargs)
+        sqltypes.CHAR.__init__(self, length,
+                               kwargs.get('convert_unicode', False))
 
     def get_col_spec(self):
         return self._extend("CHAR(%(length)s)" % {'length' : self.length})
 
-class MSBinary(sqltypes.Binary):
+class MSNVarChar(_StringType, sqltypes.String):
+    """MySQL NVARCHAR type, for variable-length character data in the
+    server's configured national character set.
+    """
+
+    def __init__(self, length=None, **kwargs):
+        """Construct an NVARCHAR.
+        
+        length
+          Maximum data length, in characters.
+
+        binary
+          Optional, use the default binary collation for the national character
+          set.  This does not affect the type of data stored, use a VARBINARY
+          type for binary data.
+
+        collation
+          Optional, request a particular collation.  Must be compatibile
+          with the national character set.
+        """
+
+        kwargs['national'] = True
+        _StringType.__init__(self, **kwargs)
+        sqltypes.String.__init__(self, length,
+                                 kwargs.get('convert_unicode', False))
+
+    def get_col_spec(self):
+        # We'll actually generate the equiv. "NATIONAL VARCHAR" instead
+        # of "NVARCHAR".
+        return self._extend("VARCHAR(%(length)s)" % {'length': self.length})
+    
+class MSNChar(_StringType, sqltypes.CHAR):
+    """MySQL NCHAR type, for fixed-length character data in the
+    server's configured national character set.
+    """
+
+    def __init__(self, length=None, **kwargs):
+        """Construct an NCHAR.  Arguments are:
+
+        length
+          Maximum data length, in characters.
+
+        binary
+          Optional, request the default binary collation for the
+          national character set.
+
+        collation
+          Optional, request a particular collation.  Must be compatibile
+          with the national character set.
+        """
+
+        kwargs['national'] = True
+        _StringType.__init__(self, **kwargs)
+        sqltypes.CHAR.__init__(self, length,
+                               kwargs.get('convert_unicode', False))
+    def get_col_spec(self):
+        # We'll actually generate the equiv. "NATIONAL CHAR" instead of "NCHAR".
+        return self._extend("CHAR(%(length)s)" % {'length': self.length})
+
+class MSBaseBinary(sqltypes.Binary):
+    """Flexible binary type"""
+
+    def __init__(self, length=None, **kw):
+        """Flexibly construct a binary column type.  Will construct a
+        VARBINARY or BLOB depending on the length requested, if any.
+
+        length
+          Maximum data length, in bytes.
+        """
+        super(MSBaseBinary, self).__init__(length, **kw)
+
     def get_col_spec(self):
-        if self.length is not None and self.length <=255:
-            # the binary2G type seems to return a value that is null-padded
+        if self.length and self.length <= 255:
+            return "VARBINARY(%d)" % self.length
+        else:
+            return "BLOB"
+
+    def convert_result_value(self, value, dialect):
+        if value is None:
+            return None
+        else:
+            return buffer(value)
+
+class MSVarBinary(MSBaseBinary):
+    """MySQL VARBINARY type, for variable length binary data"""
+
+    def __init__(self, length=None, **kw):
+        """Construct a VARBINARY.  Arguments are:
+
+        length
+          Maximum data length, in bytes.
+        """
+        super(MSVarBinary, self).__init__(length, **kw)
+
+    def get_col_spec(self):
+        if self.length:
+            return "VARBINARY(%d)" % self.length
+        else:
+            return "BLOB"
+
+class MSBinary(MSBaseBinary):
+    """MySQL BINARY type, for fixed length binary data"""
+
+    def __init__(self, length=None, **kw):
+        """Construct a BINARY.  This is a fixed length type, and short
+        values will be right-padded with a server-version-specific
+        pad value.
+
+        length
+          Maximum data length, in bytes.  If not length is specified, this
+          will generate a BLOB.  This usage is deprecated.
+        """
+
+        super(MSBinary, self).__init__(length, **kw)
+
+    def get_col_spec(self):
+        if self.length:
             return "BINARY(%d)" % self.length
         else:
             return "BLOB"
@@ -260,20 +746,64 @@ class MSBinary(sqltypes.Binary):
         else:
             return buffer(value)
 
-class MSMediumBlob(MSBinary):
+class MSBlob(MSBaseBinary):
+    """MySQL BLOB type, for binary data up to 2^16 bytes""" 
+
+
+    def __init__(self, length=None, **kw):
+        """Construct a BLOB.  Arguments are:
+
+        length
+          Optional, if provided the server may optimize storage by
+          subsitituting the smallest TEXT type sufficient to store
+          ``length`` characters.
+        """
+
+        super(MSBlob, self).__init__(length, **kw)
+
+    def get_col_spec(self):
+        if self.length:
+            return "BLOB(%d)" % self.length
+        else:
+            return "BLOB"
+
+    def convert_result_value(self, value, dialect):
+        if value is None:
+            return None
+        else:
+            return buffer(value)
+
+    def __repr__(self):
+        return "%s()" % self.__class__.__name__
+
+class MSTinyBlob(MSBlob):
+    """MySQL TINYBLOB type, for binary data up to 2^8 bytes""" 
+
+    def get_col_spec(self):
+        return "TINYBLOB"
+
+class MSMediumBlob(MSBlob): 
+    """MySQL MEDIUMBLOB type, for binary data up to 2^24 bytes"""
+
     def get_col_spec(self):
         return "MEDIUMBLOB"
 
+class MSLongBlob(MSBlob):
+    """MySQL LONGBLOB type, for binary data up to 2^32 bytes"""
+
+    def get_col_spec(self):
+        return "LONGBLOB"
+
 class MSEnum(MSString):
-    """MySQL ENUM datatype."""
+    """MySQL ENUM type."""
     
     def __init__(self, *enums, **kw):
         """
         Construct an ENUM.
 
         Example:
+
           Column('myenum', MSEnum("'foo'", "'bar'", "'baz'"))
-          Column('another', MSEnum("'foo'", "'bar'", "'baz'", strict=True))
 
         Arguments are:
         
@@ -289,24 +819,26 @@ class MSEnum(MSString):
           instead.  (See MySQL ENUM documentation.)
 
         charset
-          Defaults to None: a column-level character set for this string
+          Optional, a column-level character set for this string
           value.  Takes precendence to 'ascii' or 'unicode' short-hand.
 
         collation
-          Defaults to None: a column-level collation for this string value.
+          Optional, a column-level collation for this string value.
           Takes precedence to 'binary' short-hand.
 
         ascii
-          Defaults to False: short-hand for the ascii character set,
+          Defaults to False: short-hand for the ``latin1`` character set,
           generates ASCII in schema.
 
         unicode
-          Defaults to False: short-hand for the utf8 character set,
+          Defaults to False: short-hand for the ``ucs2`` character set,
           generates UNICODE in schema.
 
         binary
           Defaults to False: short-hand, pick the binary collation type
-          that matches the column's character set.  Generates BINARY in schema.
+          that matches the column's character set.  Generates BINARY in
+          schema.  This does not affect the type of data stored, only the
+          collation of character data.
         """
         
         self.__ddl_values = enums
@@ -350,7 +882,7 @@ class MSBoolean(sqltypes.Boolean):
         else:
             return value and True or False
 
-# TODO: NCHAR, NVARCHAR, SET
+# TODO: SET, BIT
 
 colspecs = {
     sqltypes.Integer : MSInteger,
@@ -361,38 +893,49 @@ colspecs = {
     sqltypes.Date : MSDate,
     sqltypes.Time : MSTime,
     sqltypes.String : MSString,
-    sqltypes.Binary : MSBinary,
+    sqltypes.Binary : MSVarBinary,
     sqltypes.Boolean : MSBoolean,
     sqltypes.TEXT : MSText,
     sqltypes.CHAR: MSChar,
-    sqltypes.TIMESTAMP: MSTimeStamp
+    sqltypes.NCHAR: MSNChar,
+    sqltypes.TIMESTAMP: MSTimeStamp,
+    sqltypes.BLOB: MSBlob,
+    MSBaseBinary: MSBaseBinary,
 }
 
 
 ischema_names = {
-    'boolean':MSBoolean,
     'bigint' : MSBigInteger,
+    'binary' : MSBinary,
+    'blob' : MSBlob,
+    'boolean':MSBoolean,
+    'char' : MSChar,
+    'date' : MSDate,
+    'datetime' : MSDateTime,
+    'decimal' : MSDecimal,
+    'double' : MSDouble,
+    'enum': MSEnum,
+    'fixed': MSDecimal,
+    'float' : MSFloat,
     'int' : MSInteger,
+    'integer' : MSInteger,
+    'longblob': MSLongBlob,
+    'longtext': MSLongText,
+    'mediumblob': MSMediumBlob,
     'mediumint' : MSInteger,
-    'smallint' : MSSmallInteger,
-    'tinyint' : MSSmallInteger,
-    'varchar' : MSString,
-    'char' : MSChar,
-    'text' : MSText,
-    'tinytext' : MSTinyText,
     'mediumtext': MSMediumText,
-    'longtext': MSLongText,
-    'decimal' : MSDecimal,
+    'nchar': MSNChar,
+    'nvarchar': MSNVarChar,
     'numeric' : MSNumeric,
-    'float' : MSFloat,
-    'double' : MSDouble,
-    'timestamp' : MSTimeStamp,
-    'datetime' : MSDateTime,
-    'date' : MSDate,
+    'smallint' : MSSmallInteger,
+    'text' : MSText,
     'time' : MSTime,
-    'binary' : MSBinary,
-    'blob' : MSBinary,
-    'enum': MSEnum,
+    'timestamp' : MSTimeStamp,
+    'tinyblob': MSTinyBlob,
+    'tinyint' : MSSmallInteger,
+    'tinytext' : MSTinyText,
+    'varbinary' : MSVarBinary,
+    'varchar' : MSString,
 }
 
 def descriptor():
@@ -426,9 +969,11 @@ class MySQLDialect(ansisql.ANSIDialect):
         util.coerce_kw_type(opts, 'compress', bool)
         util.coerce_kw_type(opts, 'connect_timeout', int)
         util.coerce_kw_type(opts, 'client_flag', int)
+        util.coerce_kw_type(opts, 'local_infile', int)
         # note: these two could break SA Unicode type
         util.coerce_kw_type(opts, 'use_unicode', bool)   
         util.coerce_kw_type(opts, 'charset', str)
+        # TODO: cursorclass and conv:  support via query string or punt?
         
         # ssl
         ssl = {}
@@ -439,8 +984,9 @@ class MySQLDialect(ansisql.ANSIDialect):
                 del opts[key]
         if len(ssl):
             opts['ssl'] = ssl
-
-        # TODO: what about options like "cursorclass" and "conv" ?
+        
+        # FOUND_ROWS must be set in CLIENT_FLAGS for to enable
+        # supports_sane_rowcount.
         client_flag = opts.get('client_flag', 0)
         if self.dbapi is not None:
             try:
@@ -561,6 +1107,7 @@ class MySQLDialect(ansisql.ANSIDialect):
 
             #print "coltype: " + repr(col_type) + " args: " + repr(args) + "extras:" + repr(extra_1) + ' ' + repr(extra_2)
             coltype = ischema_names.get(col_type, MSString)
+
             kw = {}
             if extra_1 is not None:
                 kw[extra_1] = True
index 7e79a889f8eaf9d1e19808b233b9184c713cef05..98226a8064c1fc0326b7f4e1e3e40fa7ee2459ee 100644 (file)
@@ -194,6 +194,7 @@ class OracleDialect(ansisql.ANSIDialect):
             threaded = self.threaded
             )
         opts.update(url.query)
+        util.coerce_kw_type(opts, 'use_ansi', bool)
         return ([], opts)
 
     def type_descriptor(self, typeobj):
index bfddb9c994a3c7defb035b68ddf8c8ae99973597..4f93737bdc726a81950415e2bba20f65b3eac373 100644 (file)
@@ -20,7 +20,7 @@ class MapperProperty(object):
         pass
 
     def create_row_processor(self, selectcontext, mapper, row):
-        """return a tuple of a row processing and a row post-processing function.
+        """return a tuple of a row processing and an instance post-processing function.
         
         Input arguments are the query.SelectionContext and the *first*
         applicable row of a result set obtained within query.Query.instances(), called
@@ -28,8 +28,7 @@ class MapperProperty(object):
         result, and only once per result.
         
         By looking at the columns present within the row, MapperProperty
-        returns two callables which will be used to process the instance 
-        that results from the row.
+        returns two callables which will be used to process all rows and instances.
         
         callables are of the following form:
         
index 25e78b7fe55923d8d1894256e5b33d812b3c7e9c..3fe0d150e6870500613decf67acfd11c2da1e51d 100644 (file)
@@ -1586,7 +1586,7 @@ class Mapper(object):
     def _post_instance(self, selectcontext, instance):
         post_processors = selectcontext.attributes[('post_processors', self, None)]
         for p in post_processors:
-            p(instance, {})
+            p(instance)
 
     def _get_poly_select_loader(self, selectcontext, row):
         # 'select' or 'union'+col not present
index 95cd1b88781383c15fdb422c18eeebbd40db336e..1d592f4e59d58c8ff97d2764e5d5fd458394a64d 100644 (file)
@@ -372,7 +372,6 @@ class LazyLoader(AbstractRelationLoader):
             if reverse_direction:
                 li.traverse(secondaryjoin)
             lazywhere = sql.and_(lazywhere, secondaryjoin)
         return (lazywhere, binds, reverse)
     _create_lazy_clause = classmethod(_create_lazy_clause)
     
@@ -538,22 +537,23 @@ class EagerLoader(AbstractRelationLoader):
         
         if hasattr(statement, '_outerjoin'):
             towrap = statement._outerjoin
-        elif isinstance(localparent.mapped_table, schema.Table):
-            # if the mapper is against a plain Table, look in the from_obj of the select statement
-            # to join against whats already there.
-            for (fromclause, finder) in [(x, sql_util.TableFinder(x)) for x in statement.froms]:
-                # dont join against an Alias'ed Select.  we are really looking either for the 
-                # table itself or a Join that contains the table.  this logic still might need
-                # adjustments for scenarios not thought of yet.
-                if not isinstance(fromclause, sql.Alias) and localparent.mapped_table in finder:
+        elif isinstance(localparent.mapped_table, sql.Join):
+            towrap = localparent.mapped_table
+        else:
+            # look for the mapper's selectable expressed within the current "from" criterion.
+            # this will locate the selectable inside of any containers it may be a part of (such
+            # as a join).  if its inside of a join, we want to outer join on that join, not the 
+            # selectable.
+            for fromclause in statement.froms:
+                if fromclause is localparent.mapped_table:
                     towrap = fromclause
                     break
+                elif isinstance(fromclause, sql.Join):
+                    if localparent.mapped_table in sql_util.TableFinder(fromclause, include_aliases=True):
+                        towrap = fromclause
+                        break
             else:
-                raise exceptions.InvalidRequestError("EagerLoader cannot locate a clause with which to outer join to, in query '%s' %s" % (str(statement), self.localparent.mapped_table))
-        else:
-            # if the mapper is against a select statement or something, we cant handle that at the
-            # same time as a custom FROM clause right now.
-            towrap = localparent.mapped_table
+                raise exceptions.InvalidRequestError("EagerLoader cannot locate a clause with which to outer join to, in query '%s' %s" % (str(statement), localparent.mapped_table))
         
         try:
             clauses = self.clauses[parentclauses]
index ec22179efacb2910010c7ba31b916825f3bc5e10..265a9c1fdfabb13c8fac9495b892b357d97084cc 100644 (file)
@@ -999,7 +999,7 @@ class Compiled(ClauseVisitor):
     defaults.
     """
 
-    def __init__(self, dialect, statement, parameters, engine=None, traversal=None):
+    def __init__(self, dialect, statement, parameters, engine=None):
         """Construct a new ``Compiled`` object.
 
         statement
@@ -1022,7 +1022,6 @@ class Compiled(ClauseVisitor):
         engine
           Optional Engine to compile this statement against.
         """
-        ClauseVisitor.__init__(self, traversal=traversal)
         self.dialect = dialect
         self.statement = statement
         self.parameters = parameters
@@ -2699,6 +2698,9 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
 
     name = property(lambda s:s.keyword + " statement")
 
+    def self_group(self, against=None):
+        return _Grouping(self)
+
     def _locate_oid_column(self):
         return self.selects[0].oid_column
 
index 1f5ac168118d59c04f803968ae4011d2958b4a8f..debf1da4f9c5b5ca37216b2d9e360e3be74ec39b 100644 (file)
@@ -65,14 +65,19 @@ class TableCollection(object):
 
 
 class TableFinder(TableCollection, sql.NoColumnVisitor):
-    """Given a ``Clause``, locate all the ``Tables`` within it into a list."""
+    """locate all Tables within a clause."""
 
-    def __init__(self, table, check_columns=False):
+    def __init__(self, table, check_columns=False, include_aliases=False):
         TableCollection.__init__(self)
         self.check_columns = check_columns
+        self.include_aliases = include_aliases
         if table is not None:
             self.traverse(table)
 
+    def visit_alias(self, alias):
+        if self.include_aliases:
+            self.tables.append(alias)
+            
     def visit_table(self, table):
         self.tables.append(table)
 
index c74de74ee3cb97d1ef537a5163da13ca5f493b44..75c8c016661c8c01ebe1849a20a5e9dbcd1100b7 100644 (file)
@@ -223,7 +223,7 @@ class TypesTest(AssertMixin):
         self.assertEqual(spec(enum_table.c.e2), """e2 ENUM("a",'b') NOT NULL""")
         self.assertEqual(spec(enum_table.c.e3), """e3 ENUM("a",'b')""")
         self.assertEqual(spec(enum_table.c.e4), """e4 ENUM("a",'b') NOT NULL""")
-        enum_table.drop()
+        enum_table.drop(checkfirst=True)
         enum_table.create()
 
         try:
@@ -244,26 +244,100 @@ class TypesTest(AssertMixin):
 
         # Insert out of range enums, push stderr aside to avoid expected
         # warnings cluttering test output
-        try:
-            aside = sys.stderr
-            sys.stderr = StringIO.StringIO()
-
-            con = db.connect()
-            self.assert_(not con.connection.show_warnings())
+        con = db.connect()
+        if not hasattr(con.connection, 'show_warnings'):
             con.execute(insert(enum_table, {'e1':'c', 'e2':'c',
                                             'e3':'a', 'e4':'a'}))
-            self.assert_(con.connection.show_warnings())
-        finally:
-            sys.stderr = aside
+        else:
+            try:
+                aside = sys.stderr
+                sys.stderr = StringIO.StringIO()
+
+                self.assert_(not con.connection.show_warnings())
+
+                con.execute(insert(enum_table, {'e1':'c', 'e2':'c',
+                                                'e3':'a', 'e4':'a'}))
+
+                self.assert_(con.connection.show_warnings())
+            finally:
+                sys.stderr = aside
 
         res = enum_table.select().execute().fetchall()
 
-        # This is known to fail with MySQLDB versions < 1.2.2
-        self.assertEqual(res, [(None, 'a', None, 'a'),
-                               ('a', 'a', 'a', 'a'),
-                               ('b', 'b', 'b', 'b'),
-                               ('', '', 'a', 'a')])
+        expected = [(None, 'a', None, 'a'),
+                    ('a', 'a', 'a', 'a'),
+                    ('b', 'b', 'b', 'b'),
+                    ('', '', 'a', 'a')]
+
+        # This is known to fail with MySQLDB 1.2.2 beta versions
+        # which return these as sets.Set(['a']), sets.Set(['b'])
+        # (even on Pythons with __builtin__.set)
+        if db.dialect.dbapi.version_info < (1, 2, 2, 'beta', 3) and \
+           db.dialect.dbapi.version_info >= (1, 2, 2):
+            # these mysqldb seem to always uses 'sets', even on later pythons
+            import sets 
+            def convert(value):
+                if value is None:
+                    return value
+                if value == '':
+                    return sets.Set([])
+                else:
+                    return sets.Set([value])
+                
+            e = []
+            for row in expected:
+                e.append(tuple([convert(c) for c in row]))
+            expected = e
+
+        self.assertEqual(res, expected)
         enum_table.drop()
 
+    @testbase.supported('mysql')
+    def test_type_reflection(self):
+        # (ask_for, roundtripped_as_if_different)
+        specs = [( String(), mysql.MSText(), ),
+                 ( String(1), mysql.MSString(1), ),
+                 ( String(3), mysql.MSString(3), ),
+                 ( mysql.MSChar(1), ),
+                 ( mysql.MSChar(3), ),
+                 ( NCHAR(2), mysql.MSChar(2), ),
+                 ( mysql.MSNChar(2), mysql.MSChar(2), ), # N is CREATE only
+                 ( mysql.MSNVarChar(22), mysql.MSString(22), ),
+                 ( Smallinteger(), mysql.MSSmallInteger(), ),
+                 ( Smallinteger(4), mysql.MSSmallInteger(4), ),
+                 ( mysql.MSSmallInteger(), ),
+                 ( mysql.MSSmallInteger(4), mysql.MSSmallInteger(4), ),
+                 ( Binary(3), mysql.MSVarBinary(3), ),
+                 ( Binary(), mysql.MSBlob() ),
+                 ( mysql.MSBinary(3), mysql.MSBinary(3), ),
+                 ( mysql.MSBaseBinary(), mysql.MSBlob(), ),
+                 ( mysql.MSBaseBinary(3), mysql.MSVarBinary(3), ),
+                 ( mysql.MSVarBinary(3),),
+                 ( mysql.MSVarBinary(), mysql.MSBlob()),
+                 ( mysql.MSTinyBlob(),),
+                 ( mysql.MSBlob(),),
+                 ( mysql.MSBlob(1234), mysql.MSBlob()),
+                 ( mysql.MSMediumBlob(),),
+                 ( mysql.MSLongBlob(),),
+                 ]
+
+        columns = [Column('c%i' % (i + 1), t[0]) for i, t in enumerate(specs)]
+
+        m = BoundMetaData(db)
+        t_table = Table('mysql_types', m, *columns)
+        m.drop_all()
+        m.create_all()
+        
+        m2 = BoundMetaData(db)
+        rt = Table('mysql_types', m2, autoload=True)
+
+        expected = [len(c) > 1 and c[1] or c[0] for c in specs]
+        for i, reflected in enumerate(rt.c):
+            #print (reflected, specs[i][0], '->',
+            #       reflected.type, '==', expected[i])
+            assert type(reflected.type) == type(expected[i])
+
+        #m.drop_all()
+
 if __name__ == "__main__":
     testbase.main()
index fa359ba50f5b6c1473782da631581c96ef59e886..ea26bb64db90eee51d8944da46460d67a9e813f4 100644 (file)
@@ -206,14 +206,15 @@ class ReflectionTest(PersistTest):
             Column('num1', mysql.MSInteger(unsigned=True)),
             Column('text1', mysql.MSLongText),
             Column('text2', mysql.MSLongText()),
-             Column('num2', mysql.MSBigInteger),
-             Column('num3', mysql.MSBigInteger()),
-             Column('num4', mysql.MSDouble),
-             Column('num5', mysql.MSDouble()),
-             Column('enum1', mysql.MSEnum('"black"', '"white"')),
+            Column('num2', mysql.MSBigInteger),
+            Column('num3', mysql.MSBigInteger()),
+            Column('num4', mysql.MSDouble),
+            Column('num5', mysql.MSDouble()),
+            Column('enum1', mysql.MSEnum('"black"', '"white"')),
             )
         try:
-            table.create(checkfirst=True)
+            table.drop(checkfirst=True)
+            table.create()
             meta2 = BoundMetaData(testbase.db)
             t2 = Table('mysql_types', meta2, autoload=True)
             assert isinstance(t2.c.num1.type, mysql.MSInteger)
@@ -518,26 +519,26 @@ class SchemaTest(PersistTest):
     
     @testbase.supported('postgres')
     def testpg(self):
-        """note: this test requires that the 'test_schema' schema be separate and accessible by the test user"""
+        """note: this test requires that the 'alt_schema' schema be separate and accessible by the test user"""
         
         meta1 = BoundMetaData(testbase.db)
         users = Table('users', meta1,
             Column('user_id', Integer, primary_key = True),
             Column('user_name', String(30), nullable = False),
-            schema="test_schema"
+            schema="alt_schema"
             )
 
         addresses = Table('email_addresses', meta1,
             Column('address_id', Integer, primary_key = True),
             Column('remote_user_id', Integer, ForeignKey(users.c.user_id)),
             Column('email_address', String(20)),
-            schema="test_schema"
+            schema="alt_schema"
         )
         meta1.create_all()
         try:
             meta2 = BoundMetaData(testbase.db)
-            addresses = Table('email_addresses', meta2, autoload=True, schema="test_schema")
-            users = Table('users', meta2, mustexist=True, schema="test_schema")
+            addresses = Table('email_addresses', meta2, autoload=True, schema="alt_schema")
+            users = Table('users', meta2, mustexist=True, schema="alt_schema")
 
             print users
             print addresses
index 11daf6348d5106ddd8447e3949f479b506344759..bda34ba1c8253f8bf9a084d7f804048c7dac160c 100644 (file)
@@ -86,7 +86,7 @@ class GenerativeQueryTest(PersistTest):
     
     def test_options(self):
         class ext1(MapperExtension):
-            def populate_instance(self, mapper, selectcontext, row, instance, identitykey, isnew):
+            def populate_instance(self, mapper, selectcontext, row, instance, **flags):
                 instance.TEST = "hello world"
                 return EXT_PASS
         objectstore.clear()
index 2d8879d13ff90fa04612386b275f4ae0b229616e..de4dfa0ce03ba004381e866a2717797809962d1d 100644 (file)
@@ -378,8 +378,13 @@ class MapperTest(MapperSuperTest):
 
         l = q.select_by(items=item)
         self.assert_result(l, User, user_result[0])
-    
-    
+        
+        # TODO: this works differently from:
+        #q = sess.query(User).join(['orders', 'items']).select_by(order_id=3)
+        # because select_by() doesnt respect query._joinpoint, whereas filter_by does
+        q = sess.query(User).join(['orders', 'items']).filter_by(order_id=3).list()
+        self.assert_result(l, User, user_result[0])
+        
         try:
             # this should raise AttributeError
             l = q.select_by(items=5)
@@ -1357,6 +1362,27 @@ class EagerTest(MapperSuperTest):
         
         l = m.instances(s.execute(emailad = 'jack@bean.com'), session)
         self.echo(repr(l))
+    
+    def testonselect(self):
+        """test eager loading of a mapper which is against a select"""
+        
+        s = select([orders], orders.c.isopen==1).alias('openorders')
+        mapper(Order, s, properties={
+            'user':relation(User, lazy=False)
+        })
+        mapper(User, users)
+        
+        q = create_session().query(Order)
+        self.assert_result(q.list(), Order,
+            {'order_id':3, 'user' : (User, {'user_id':7})},
+            {'order_id':4, 'user' : (User, {'user_id':9})},
+        )
+
+        q = q.select_from(s.outerjoin(orderitems)).filter(orderitems.c.item_name != 'item 2')
+        self.assert_result(q.list(), Order,
+            {'order_id':3, 'user' : (User, {'user_id':7})},
+        )
+        
         
     def testmulti(self):
         """tests eager loading with two relations simultaneously"""
@@ -1670,7 +1696,14 @@ class InstancesTest(MapperSuperTest):
         
     def testmappersplustwocolumns(self):
         mapper(User, users)
-        s = select([users, func.count(addresses.c.address_id).label('count'), ("Name:" + users.c.user_name).label('concat')], from_obj=[users.outerjoin(addresses)], group_by=[c for c in users.c], order_by=[users.c.user_id])
+
+        # Fixme ticket #475!
+        if db.engine.name == 'mysql':
+            col2 = func.concat("Name:", users.c.user_name).label('concat')
+        else:
+            col2 = ("Name:" + users.c.user_name).label('concat')
+        
+        s = select([users, func.count(addresses.c.address_id).label('count'), col2], from_obj=[users.outerjoin(addresses)], group_by=[c for c in users.c], order_by=[users.c.user_id])
         sess = create_session()
         (user7, user8, user9) = sess.query(User).select()
         q = sess.query(User)
index 281a0f6a381f39a991c4a5ffe86d2683201b5c3c..2f627ee8fe4fcd5567ada32b9ea9c6bcdc1540c6 100644 (file)
@@ -821,6 +821,16 @@ myothertable.othername != :myothertable_othername OR EXISTS (select yay from foo
 
         self.runtest(select([table1], ~table1.c.myid.in_(select([table2.c.otherid]))),
         "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid NOT IN (SELECT myothertable.otherid FROM myothertable)")
+
+        self.runtest(select([table1], table1.c.myid.in_(
+            union(
+                  select([table1], table1.c.myid == 5),
+                  select([table1], table1.c.myid == 12),
+            )
+        )), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable \
+WHERE mytable.myid IN (\
+SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :mytable_myid \
+UNION SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :mytable_myid_1)")
         
         # test that putting a select in an IN clause does not blow away its ORDER BY clause
         self.runtest(
index 1bcc5c142dee2a612ae631da0d48e4d09c12be4a..9ee79da20dc0e730f308e56f6fa3647ab456fe31 100644 (file)
@@ -1,6 +1,6 @@
 import sys
 sys.path.insert(0, './lib/')
-import os, unittest, StringIO, re
+import os, unittest, StringIO, re, ConfigParser
 import sqlalchemy
 from sqlalchemy import sql, engine, pool
 import sqlalchemy.engine.base as base
@@ -33,10 +33,24 @@ def parse_argv():
     DBTYPE = 'sqlite'
     PROXY = False
 
+    base_config = """
+[db]
+sqlite=sqlite:///:memory:
+sqlite_file=sqlite:///querytest.db
+postgres=postgres://scott:tiger@127.0.0.1:5432/test
+mysql=mysql://scott:tiger@127.0.0.1:3306/test
+oracle=oracle://scott:tiger@127.0.0.1:1521
+oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0
+mssql=mssql://scott:tiger@SQUAWK\\SQLEXPRESS/test
+firebird=firebird://sysdba:s@localhost/tmp/test.fdb
+"""
+    config = ConfigParser.ConfigParser()
+    config.readfp(StringIO.StringIO(base_config))
+    config.read(['test.cfg', os.path.expanduser('~/.satest.cfg')])
 
     parser = optparse.OptionParser(usage = "usage: %prog [options] [tests...]")
     parser.add_option("--dburi", action="store", dest="dburi", help="database uri (overrides --db)")
-    parser.add_option("--db", action="store", dest="db", default="sqlite", help="prefab database uri (sqlite, sqlite_file, postgres, mysql, oracle, oracle8, mssql, firebird)")
+    parser.add_option("--db", action="store", dest="db", default="sqlite", help="prefab database uri (%s)" % ', '.join(config.options('db')))
     parser.add_option("--mockpool", action="store_true", dest="mockpool", help="use mock pool (asserts only one connection used)")
     parser.add_option("--verbose", action="store_true", dest="verbose", help="enable stdout echoing/printing")
     parser.add_option("--quiet", action="store_true", dest="quiet", help="suppress unittest output")
@@ -47,6 +61,7 @@ def parse_argv():
     parser.add_option("--coverage", action="store_true", dest="coverage", help="Dump a full coverage report after running")
     parser.add_option("--reversetop", action="store_true", dest="topological", help="Reverse the collection ordering for topological sorts (helps reveal dependency issues)")
     parser.add_option("--serverside", action="store_true", dest="serverside", help="Turn on server side cursors for PG")
+    parser.add_option("--require", action="append", dest="require", help="Require a particular driver or module version", default=[])
     
     (options, args) = parser.parse_args()
     sys.argv[1:] = args
@@ -57,25 +72,34 @@ def parse_argv():
     elif options.db:
         DBTYPE = param = options.db
 
+    if options.require or (config.has_section('require') and
+                           config.items('require')):
+        try:
+            import pkg_resources
+        except ImportError:
+            raise "setuptools is required for version requirements"
+
+        cmdline = []
+        for requirement in options.require:
+            pkg_resources.require(requirement)
+            cmdline.append(re.split('\s*(<!>=)', requirement, 1)[0])
+
+        if config.has_section('require'):
+            for label, requirement in config.items('require'):
+                if not label == DBTYPE or label.startswith('%s.' % DBTYPE):
+                    continue
+                seen = [c for c in cmdline if requirement.startswith(c)]
+                if seen:
+                    continue
+                pkg_resources.require(requirement)
+        
     opts = {}
     if (None == db_uri):
-        if DBTYPE == 'sqlite':
-            db_uri = 'sqlite:///:memory:'
-        elif DBTYPE == 'sqlite_file':
-            db_uri = 'sqlite:///querytest.db'
-        elif DBTYPE == 'postgres':
-            db_uri = 'postgres://scott:tiger@127.0.0.1:5432/test'
-        elif DBTYPE == 'mysql':
-            db_uri = 'mysql://scott:tiger@127.0.0.1:3306/test'
-        elif DBTYPE == 'oracle':
-            db_uri = 'oracle://scott:tiger@127.0.0.1:1521'
-        elif DBTYPE == 'oracle8':
-            db_uri = 'oracle://scott:tiger@127.0.0.1:1521'
-            opts['use_ansi'] = False
-        elif DBTYPE == 'mssql':
-            db_uri = 'mssql://scott:tiger@SQUAWK\\SQLEXPRESS/test'
-        elif DBTYPE == 'firebird':
-            db_uri = 'firebird://sysdba:s@localhost/tmp/test.fdb'
+        if DBTYPE not in config.options('db'):
+            raise ("Could not create engine.  specify --db <%s> to " 
+                   "test runner." % '|'.join(config.options('db')))
+
+        db_uri = config.get('db', DBTYPE)
 
     if not db_uri:
         raise "Could not create engine.  specify --db <sqlite|sqlite_file|postgres|mysql|oracle|oracle8|mssql|firebird> to test runner."