]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- merged trunk 2629-2730
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 14 Jun 2007 18:37:20 +0000 (18:37 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 14 Jun 2007 18:37:20 +0000 (18:37 +0000)
- fixes to is_select() which is now an important method
- mysql unit tests fixes

19 files changed:
doc/build/genhtml.py
doc/build/runhtml.py
examples/polymorph/polymorph.py
lib/sqlalchemy/databases/firebird.py
lib/sqlalchemy/databases/mssql.py
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/ext/associationproxy.py
lib/sqlalchemy/ext/orderinglist.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql.py
test/dialect/mysql.py
test/engine/execute.py
test/engine/reflection.py
test/ext/associationproxy.py
test/sql/query.py
test/sql/rowcount.py
test/sql/selectable.py
test/sql/testtypes.py

index 840d0362ed451457b44de511d8d960cad45a65e2..3b78da7690b243b3a751935ae4353f24d08da28d 100644 (file)
@@ -29,7 +29,7 @@ files = [
 parser = optparse.OptionParser(usage = "usage: %prog [options] [tests...]")
 parser.add_option("--file", action="store", dest="file", help="only generate file <file>")
 parser.add_option("--docstrings", action="store_true", dest="docstrings", help="only generate docstrings")
-parser.add_option("--version", action="store", dest="version", default="0.3.7", help="version string")
+parser.add_option("--version", action="store", dest="version", default="0.3.8", help="version string")
 
 (options, args) = parser.parse_args()
 if options.file:
index c5b34e4bf192e7f58de4da27208a045917f325e6..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100755 (executable)
@@ -1,31 +0,0 @@
-#!/usr/bin/env python\r
-import sys,re,os\r
-\r
-"""starts an HTTP server which will serve generated .myt files from the ./components and \r
-./output directories."""\r
-\r
-\r
-component_root = [\r
-    {'components': './components'},\r
-    {'content' : './output'}\r
-]\r
-doccomp = ['document_base.myt']\r
-output = os.path.dirname(os.getcwd())\r
-\r
-sys.path = ['./lib/'] + sys.path\r
-\r
-import myghty.http.HTTPServerHandler as HTTPServerHandler\r
-\r
-port = 8080\r
-httpd = HTTPServerHandler.HTTPServer(\r
-    port = port,\r
-    handlers = [\r
-        {'.*(?:\.myt|/$)' : HTTPServerHandler.HSHandler(path_translate=[(r'^/$', r'/index.myt')], data_dir = './cache', component_root = component_root, output_encoding='utf-8')},\r
-    ],\r
-\r
-    docroot = [{'.*' : '../'}],\r
-    \r
-)       \r
-\r
-print "Listening on %d" % port        \r
-httpd.serve_forever()\r
index b1f3e75158f8c9f1e3e62c7f3c24511238ea4cb4..d5d747d36c1a0d67e5586e9e3d6252ffa7b77847 100644 (file)
@@ -4,7 +4,7 @@ import sets
 
 # this example illustrates a polymorphic load of two classes
 
-metadata = BoundMetaData('sqlite://', echo='True')
+metadata = BoundMetaData('sqlite://', echo=True)
 
 # a table to store companies
 companies = Table('companies', metadata, 
index 58d6d246f86ac9f27cd037ec42cc097dd86474e5..a02781c846b195fdeb3007cd1af71c49ee42f5d1 100644 (file)
@@ -312,18 +312,10 @@ class FBCompiler(ansisql.ANSICompiler):
         else:
             self.strings[func] = func.name
 
-    def visit_insert(self, insert):
-        """Inserts are required to have the primary keys be explicitly present.
-
-         mapper will by default not put them in the insert statement
-         to comply with autoincrement fields that require they not be
-         present. So, put them all in for all primary key columns.
-         """
-
-        for c in insert.table.primary_key:
-            if not self.parameters.has_key(c.key):
-                self.parameters[c.key] = None
-        return ansisql.ANSICompiler.visit_insert(self, insert)
+    def visit_insert_column(self, column, parameters):
+        # all column primary key inserts must be explicitly present
+        if column.primary_key:
+            parameters[column.key] = None
 
     def visit_select_precolumns(self, select):
         """Called when building a ``SELECT`` statement, position is just
@@ -372,7 +364,7 @@ class FBSchemaDropper(ansisql.ANSISchemaDropper):
 
 class FBDefaultRunner(ansisql.ANSIDefaultRunner):
     def exec_default_sql(self, default):
-        c = sql.select([default.arg], from_obj=["rdb$database"]).compile(engine=self.engine)
+        c = sql.select([default.arg], from_obj=["rdb$database"]).compile(engine=self.connection)
         return self.connection.execute_compiled(c).scalar()
 
     def visit_sequence(self, seq):
index 46fe990734eae8189cb3a70032a73a0e9377e155..4336296dd90fa520d9615c8b287cbcfdbc7c5415 100644 (file)
@@ -99,14 +99,14 @@ class MSDateTime(sqltypes.DateTime):
     def get_col_spec(self):
         return "DATETIME"
 
-    def convert_bind_param(self, value, dialect):
-        if hasattr(value, "isoformat"):
-            #return value.isoformat(' ')
-            # isoformat() bings on apodbapi -- reported/suggested by Peter Buschman
-            return value.strftime('%Y-%m-%d %H:%M:%S')
-        else:
-            return value
+class MSDate(sqltypes.Date):
+    def __init__(self, *a, **kw):
+        super(MSDate, self).__init__(False)
 
+    def get_col_spec(self):
+        return "SMALLDATETIME"
+
+class MSDateTime_adodbapi(MSDateTime):
     def convert_result_value(self, value, dialect):
         # adodbapi will return datetimes with empty time values as datetime.date() objects.
         # Promote them back to full datetime.datetime()
@@ -114,23 +114,34 @@ class MSDateTime(sqltypes.DateTime):
             return datetime.datetime(value.year, value.month, value.day)
         return value
 
-class MSDate(sqltypes.Date):
-    def __init__(self, *a, **kw):
-        super(MSDate, self).__init__(False)
+class MSDateTime_pyodbc(MSDateTime):
+    def convert_bind_param(self, value, dialect):
+        if value and not hasattr(value, 'second'):
+            return datetime.datetime(value.year, value.month, value.day)
+        else:
+            return value
 
-    def get_col_spec(self):
-        return "SMALLDATETIME"
-    
+class MSDate_pyodbc(MSDate):
     def convert_bind_param(self, value, dialect):
-        if value and hasattr(value, "isoformat"):
-            return value.strftime('%Y-%m-%d %H:%M')
-        return value
+        if value and not hasattr(value, 'second'):
+            return datetime.datetime(value.year, value.month, value.day)
+        else:
+            return value
 
+    def convert_result_value(self, value, dialect):
+        # pyodbc returns SMALLDATETIME values as datetime.datetime(). truncate it back to datetime.date()
+        if value and hasattr(value, 'second'):
+            return value.date()
+        else:
+            return value
+
+class MSDate_pymssql(MSDate):
     def convert_result_value(self, value, dialect):
         # pymssql will return SMALLDATETIME values as datetime.datetime(), truncate it back to datetime.date()
         if value and hasattr(value, 'second'):
             return value.date()
-        return value
+        else:
+            return value
 
 class MSText(sqltypes.TEXT):
     def get_col_spec(self):
@@ -143,7 +154,7 @@ class MSString(sqltypes.String):
     def get_col_spec(self):
         return "VARCHAR(%(length)s)" % {'length' : self.length}
 
-class MSNVarchar(MSString):
+class MSNVarchar(sqltypes.Unicode):
     def get_col_spec(self):
         if self.length:
             return "NVARCHAR(%(length)s)" % {'length' : self.length}
@@ -191,6 +202,10 @@ class MSBoolean(sqltypes.Boolean):
         else:
             return value and True or False
         
+class MSTimeStamp(sqltypes.TIMESTAMP):
+    def get_col_spec(self):
+        return "TIMESTAMP"
+        
 def descriptor():
     return {'name':'mssql',
     'description':'MSSQL',
@@ -240,7 +255,7 @@ class MSSQLExecutionContext(default.DefaultExecutionContext):
 
             if self.IINSERT:
                 # TODO: quoting rules for table name here ?
-                self.cursor.execute("SET IDENTITY_INSERT %s ON" % self.compiled.statement.table.name)
+                self.cursor.execute("SET IDENTITY_INSERT %s ON" % self.compiled.statement.table.fullname)
 
         super(MSSQLExecutionContext, self).pre_exec()
 
@@ -253,7 +268,7 @@ class MSSQLExecutionContext(default.DefaultExecutionContext):
         if self.compiled.isinsert:
             if self.IINSERT:
                 # TODO: quoting rules for table name here ?
-                self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.compiled.statement.table.name)
+                self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.compiled.statement.table.fullname)
                 self.IINSERT = False
             elif self.HASIDENT:
                 if self.dialect.use_scope_identity:
@@ -294,6 +309,7 @@ class MSSQLDialect(ansisql.ANSIDialect):
         sqltypes.TEXT : MSText,
         sqltypes.CHAR: MSChar,
         sqltypes.NCHAR: MSNChar,
+        sqltypes.TIMESTAMP: MSTimeStamp,
     }
 
     ischema_names = {
@@ -314,7 +330,8 @@ class MSSQLDialect(ansisql.ANSIDialect):
         'binary' : MSBinary,
         'bit': MSBoolean,
         'real' : MSFloat,
-        'image' : MSBinary
+        'image' : MSBinary,
+        'timestamp': MSTimeStamp,
     }
 
     def __new__(cls, dbapi=None, *args, **kwargs):
@@ -330,7 +347,7 @@ class MSSQLDialect(ansisql.ANSIDialect):
         super(MSSQLDialect, self).__init__(**params)
         self.auto_identity_insert = auto_identity_insert
         self.text_as_varchar = False
-        self.use_scope_identity = True
+        self.use_scope_identity = False
         self.set_default_schema_name("dbo")
 
     def dbapi(cls, module_name=None):
@@ -570,9 +587,22 @@ class MSSQLDialect_pymssql(MSSQLDialect):
         return module
     import_dbapi = classmethod(import_dbapi)
     
+    colspecs = MSSQLDialect.colspecs.copy()
+    colspecs[sqltypes.Date] = MSDate_pymssql
+
+    ischema_names = MSSQLDialect.ischema_names.copy()
+    ischema_names['smalldatetime'] = MSDate_pymssql
+
+    def __init__(self, **params):
+        super(MSSQLDialect_pymssql, self).__init__(**params)
+        self.use_scope_identity = True
+
     def supports_sane_rowcount(self):
         return True
 
+    def max_identifier_length(self):
+        return 30
+
     def do_rollback(self, connection):
         # pymssql throws an error on repeated rollbacks. Ignore it.
         # TODO: this is normal behavior for most DBs.  are we sure we want to ignore it ?
@@ -638,12 +668,21 @@ class MSSQLDialect_pyodbc(MSSQLDialect):
     
     colspecs = MSSQLDialect.colspecs.copy()
     colspecs[sqltypes.Unicode] = AdoMSNVarchar
+    colspecs[sqltypes.Date] = MSDate_pyodbc
+    colspecs[sqltypes.DateTime] = MSDateTime_pyodbc
+
     ischema_names = MSSQLDialect.ischema_names.copy()
     ischema_names['nvarchar'] = AdoMSNVarchar
+    ischema_names['smalldatetime'] = MSDate_pyodbc
+    ischema_names['datetime'] = MSDateTime_pyodbc
 
     def supports_sane_rowcount(self):
         return False
 
+    def supports_unicode_statements(self):
+        """indicate whether the DBAPI can receive SQL statements as Python unicode strings"""
+        return True
+
     def make_connect_string(self, keys):
         connectors = ["Driver={SQL Server}"]
         connectors.append("Server=%s" % keys.get("host"))
@@ -671,12 +710,19 @@ class MSSQLDialect_adodbapi(MSSQLDialect):
 
     colspecs = MSSQLDialect.colspecs.copy()
     colspecs[sqltypes.Unicode] = AdoMSNVarchar
+    colspecs[sqltypes.DateTime] = MSDateTime_adodbapi
+
     ischema_names = MSSQLDialect.ischema_names.copy()
     ischema_names['nvarchar'] = AdoMSNVarchar
+    ischema_names['datetime'] = MSDateTime_adodbapi
 
     def supports_sane_rowcount(self):
         return True
 
+    def supports_unicode_statements(self):
+        """indicate whether the DBAPI can receive SQL statements as Python unicode strings"""
+        return True
+
     def make_connect_string(self, keys):
         connectors = ["Provider=SQLOLEDB"]
         if 'port' in keys:
index 8b4b89d508d6557f4fbe7ca0b5347e7d81c7830a..63ce05eb68e0c017202b72ea12680736a5374a45 100644 (file)
@@ -24,7 +24,7 @@ RESERVED_WORDS = util.Set(
      'declare', 'default', 'delayed', 'delete', 'desc', 'describe',
      'deterministic', 'distinct', 'distinctrow', 'div', 'double', 'drop',
      'dual', 'each', 'else', 'elseif', 'enclosed', 'escaped', 'exists',
-     'exit', 'explain', 'false', 'fetch', 'fields', 'float', 'float4', 'float8',
+     'exit', 'explain', 'false', 'fetch', 'float', 'float4', 'float8',
      'for', 'force', 'foreign', 'from', 'fulltext', 'grant', 'group', 'having',
      'high_priority', 'hour_microsecond', 'hour_minute', 'hour_second', 'if',
      'ignore', 'in', 'index', 'infile', 'inner', 'inout', 'insensitive',
@@ -49,9 +49,11 @@ RESERVED_WORDS = util.Set(
      'union', 'unique', 'unlock', 'unsigned', 'update', 'usage', 'use',
      'using', 'utc_date', 'utc_time', 'utc_timestamp', 'values', 'varbinary',
      'varchar', 'varcharacter', 'varying', 'when', 'where', 'while', 'with',
-     'write', 'x509', 'xor', 'year_month', 'zerofill',
+     'write', 'x509', 'xor', 'year_month', 'zerofill', # 5.0
+     'fields', # 4.1
      'accessible', 'linear', 'master_ssl_verify_server_cert', 'range',
-     'read_only', 'read_write'])
+     'read_only', 'read_write', # 5.1
+     ])
 
 class _NumericType(object):
     "Base for MySQL numeric types."
@@ -425,7 +427,8 @@ class MSText(_StringType, sqltypes.TEXT):
         """
 
         _StringType.__init__(self, **kwargs)
-        sqltypes.TEXT.__init__(self, length)
+        sqltypes.TEXT.__init__(self, length,
+                               kwargs.get('convert_unicode', False))
 
     def get_col_spec(self):
         if self.length:
@@ -678,21 +681,12 @@ class MSNChar(_StringType, sqltypes.CHAR):
         # We'll actually generate the equiv. "NATIONAL CHAR" instead of "NCHAR".
         return self._extend("CHAR(%(length)s)" % {'length': self.length})
 
-class MSBaseBinary(sqltypes.Binary):
-    """Flexible binary type"""
-
-    def __init__(self, length=None, **kw):
-        """Flexibly construct a binary column type.  Will construct a
-        VARBINARY or BLOB depending on the length requested, if any.
-
-        length
-          Maximum data length, in bytes.
-        """
-        super(MSBaseBinary, self).__init__(length, **kw)
+class _BinaryType(sqltypes.Binary):
+    """MySQL binary types"""
 
     def get_col_spec(self):
-        if self.length and self.length <= 255:
-            return "VARBINARY(%d)" % self.length
+        if self.length:
+            return "BLOB(%d)" % self.length
         else:
             return "BLOB"
 
@@ -702,7 +696,7 @@ class MSBaseBinary(sqltypes.Binary):
         else:
             return buffer(value)
 
-class MSVarBinary(MSBaseBinary):
+class MSVarBinary(_BinaryType):
     """MySQL VARBINARY type, for variable length binary data"""
 
     def __init__(self, length=None, **kw):
@@ -719,7 +713,7 @@ class MSVarBinary(MSBaseBinary):
         else:
             return "BLOB"
 
-class MSBinary(MSBaseBinary):
+class MSBinary(_BinaryType):
     """MySQL BINARY type, for fixed length binary data"""
 
     def __init__(self, length=None, **kw):
@@ -746,7 +740,7 @@ class MSBinary(MSBaseBinary):
         else:
             return buffer(value)
 
-class MSBlob(MSBaseBinary):
+class MSBlob(_BinaryType):
     """MySQL BLOB type, for binary data up to 2^16 bytes""" 
 
 
@@ -865,7 +859,7 @@ class MSEnum(MSString):
 
 class MSBoolean(sqltypes.Boolean):
     def get_col_spec(self):
-        return "BOOLEAN"
+        return "BOOL"
 
     def convert_result_value(self, value, dialect):
         if value is None:
@@ -893,14 +887,14 @@ colspecs = {
     sqltypes.Date : MSDate,
     sqltypes.Time : MSTime,
     sqltypes.String : MSString,
-    sqltypes.Binary : MSVarBinary,
+    sqltypes.Binary : MSBlob,
     sqltypes.Boolean : MSBoolean,
     sqltypes.TEXT : MSText,
     sqltypes.CHAR: MSChar,
     sqltypes.NCHAR: MSNChar,
     sqltypes.TIMESTAMP: MSTimeStamp,
     sqltypes.BLOB: MSBlob,
-    MSBaseBinary: MSBaseBinary,
+    _BinaryType: _BinaryType,
 }
 
 
@@ -951,7 +945,10 @@ def descriptor():
 class MySQLExecutionContext(default.DefaultExecutionContext):
     def post_exec(self):
         if self.compiled.isinsert:
-            self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:]
+            self._last_inserted_ids = [self.cursor.lastrowid]
+            
+    def is_select(self):
+        return re.match(r'SELECT|SHOW|DESCRIBE', self.statement.lstrip(), re.I) is not None
 
 class MySQLDialect(ansisql.ANSIDialect):
     def __init__(self, **kwargs):
@@ -1069,6 +1066,19 @@ class MySQLDialect(ansisql.ANSIDialect):
             else:
                 raise
 
+    def get_version_info(self, connectable):
+        if hasattr(connectable, 'connect'):
+            con = connectable.connect().connection
+        else:
+            con = connectable
+        version = []
+        for n in con.get_server_info().split('.'):
+            try:
+                version.append(int(n))
+            except ValueError:
+                version.append(n)
+        return tuple(version)
+
     def reflecttable(self, connection, table):
         # reference:  http://dev.mysql.com/doc/refman/5.0/en/name-case-sensitivity.html
         cs = connection.execute("show variables like 'lower_case_table_names'").fetchone()[1]
@@ -1125,7 +1135,11 @@ class MySQLDialect(ansisql.ANSIDialect):
 
             colargs= []
             if default:
-                colargs.append(schema.PassiveDefault(sql.text(default)))
+                if col_type == 'timestamp' and default == 'CURRENT_TIMESTAMP':
+                    arg = sql.text(default)
+                else:
+                    arg = default
+                colargs.append(schema.PassiveDefault(arg))
             table.append_column(schema.Column(name, coltype, *colargs,
                                             **dict(primary_key=primary_key,
                                                    nullable=nullable,
index cde8ee0981d4c27cc161aaa1d5ca4feb6002e9ca..eca0faf914118c8e666ac9eee6dba67a06c5e0ea 100644 (file)
@@ -191,6 +191,8 @@ class DefaultExecutionContext(base.ExecutionContext):
                 return proc(params)
                 
     def is_select(self):
+        """return TRUE if the statement is expected to have result rows."""
+        
         return re.match(r'SELECT', self.statement.lstrip(), re.I) is not None
 
     def create_cursor(self):
index d6a995450ac97d4a4ceb95a80412ff13f1a03444..1b363c9acdfdf35fa894d75ab52bf31ca9db9690 100644 (file)
@@ -10,6 +10,7 @@ from sqlalchemy.orm.attributes import InstrumentedList
 import sqlalchemy.exceptions as exceptions
 import sqlalchemy.orm as orm
 import sqlalchemy.util as util
+import weakref
 
 def association_proxy(targetcollection, attr, **kw):
     """Convenience function for use in mapped classes.  Implements a Python
@@ -116,6 +117,18 @@ class AssociationProxy(object):
 
     def _target_is_scalar(self):
         return not self._get_property().uselist
+
+    def _lazy_collection(self, weakobjref):
+        target = self.target_collection
+        del self
+        def lazy_collection():
+            obj = weakobjref()
+            if obj is None:
+                raise exceptions.InvalidRequestError(
+                    "stale association proxy, parent object has gone out of "
+                    "scope")
+            return getattr(obj, target)
+        return lazy_collection
         
     def __get__(self, obj, class_):
         if obj is None:
@@ -130,7 +143,7 @@ class AssociationProxy(object):
             try:
                 return getattr(obj, self.key)
             except AttributeError:
-                proxy = self._new(getattr(obj, self.target_collection))
+                proxy = self._new(self._lazy_collection(weakref.ref(obj)))
                 setattr(obj, self.key, proxy)
                 return proxy
 
@@ -153,30 +166,32 @@ class AssociationProxy(object):
     def __delete__(self, obj):
         delattr(obj, self.key)
 
-    def _new(self, collection):
+    def _new(self, lazy_collection):
         creator = self.creator and self.creator or self.target_class
 
         # Prefer class typing here to spot dicts with the required append()
         # method.
+        collection = lazy_collection()
         if isinstance(collection.data, dict):
             self.collection_class = dict
         else:
             self.collection_class = util.duck_type_collection(collection.data)
+        del collection
 
         if self.proxy_factory:
-            return self.proxy_factory(collection, creator, self.value_attr)
+            return self.proxy_factory(lazy_collection, creator, self.value_attr)
 
         value_attr = self.value_attr
         getter = lambda o: getattr(o, value_attr)
         setter = lambda o, v: setattr(o, value_attr, v)
         
         if self.collection_class is list:
-            return _AssociationList(collection, creator, getter, setter)
+            return _AssociationList(lazy_collection, creator, getter, setter)
         elif self.collection_class is dict:
             kv_setter = lambda o, k, v: setattr(o, value_attr, v)
-            return _AssociationDict(collection, creator, getter, setter)
+            return _AssociationDict(lazy_collection, creator, getter, setter)
         elif self.collection_class is util.Set:
-            return _AssociationSet(collection, creator, getter, setter)
+            return _AssociationSet(lazy_collection, creator, getter, setter)
         else:
             raise exceptions.ArgumentError(
                 'could not guess which interface to use for '
@@ -203,11 +218,11 @@ class _AssociationList(object):
     converting association objects to and from a simplified value.
     """
 
-    def __init__(self, collection, creator, getter, setter):
+    def __init__(self, lazy_collection, creator, getter, setter):
         """
-        collection
-          A list-based collection of entities (usually an object attribute
-          managed by a SQLAlchemy relation())
+        lazy_collection
+          A callable returning a list-based collection of entities (usually
+          an object attribute managed by a SQLAlchemy relation())
           
         creator
           A function that creates new target entities.  Given one parameter:
@@ -223,11 +238,13 @@ class _AssociationList(object):
           that value on the object.
         """
 
-        self.col = collection
+        self.lazy_collection = lazy_collection
         self.creator = creator
         self.getter = getter
         self.setter = setter
 
+    col = property(lambda self: self.lazy_collection())
+
     # For compatibility with 0.3.1 through 0.3.7- pass kw through to creator.
     # (see append() below)
     def _create(self, value, **kw):
@@ -320,11 +337,11 @@ class _AssociationDict(object):
     converting association objects to and from a simplified value.
     """
 
-    def __init__(self, collection, creator, getter, setter):
+    def __init__(self, lazy_collection, creator, getter, setter):
         """
-        collection
-          A list-based collection of entities (usually an object attribute
-          managed by a SQLAlchemy relation())
+        lazy_collection
+          A callable returning a dict-based collection of entities (usually
+          an object attribute managed by a SQLAlchemy relation())
           
         creator
           A function that creates new target entities.  Given two parameters:
@@ -340,11 +357,13 @@ class _AssociationDict(object):
           that value on the object.
         """
 
-        self.col = collection
+        self.lazy_collection = lazy_collection
         self.creator = creator
         self.getter = getter
         self.setter = setter
 
+    col = property(lambda self: self.lazy_collection())
+
     def _create(self, key, value):
         return self.creator(key, value)
 
@@ -380,7 +399,7 @@ class _AssociationDict(object):
     has_key = __contains__
 
     def __iter__(self):
-        return iter(self.col)
+        return self.col.iterkeys()
 
     def clear(self):
         self.col.clear()
@@ -465,11 +484,11 @@ class _AssociationSet(object):
     converting association objects to and from a simplified value.
     """
 
-    def __init__(self, collection, creator, getter, setter):
+    def __init__(self, lazy_collection, creator, getter, setter):
         """
         collection
-          A list-based collection of entities (usually an object attribute
-          managed by a SQLAlchemy relation())
+          A callable returning a set-based collection of entities (usually an
+          object attribute managed by a SQLAlchemy relation())
           
         creator
           A function that creates new target entities.  Given one parameter:
@@ -485,11 +504,13 @@ class _AssociationSet(object):
           that value on the object.
         """
 
-        self.col = collection
+        self.lazy_collection = lazy_collection
         self.creator = creator
         self.getter = getter
         self.setter = setter
 
+    col = property(lambda self: self.lazy_collection())
+
     def _create(self, value):
         return self.creator(value)
 
index 27ff408dcdd279b93391231d308b2f94f930a27d..e02990a2633d2787b014ff163c91128a1795b7d1 100644 (file)
@@ -165,8 +165,12 @@ class OrderingList(list):
         return entity
         
     def __setitem__(self, index, entity):
-        super(OrderingList, self).__setitem__(index, entity)
-        self._order_entity(index, entity, True)
+        if isinstance(index, slice):
+            for i in range(index.start or 0, index.stop or 0, index.step or 1):
+                self.__setitem__(i, entity[i])
+        else:
+            self._order_entity(index, entity, True)
+            super(OrderingList, self).__setitem__(index, entity)
             
     def __delitem__(self, index):
         super(OrderingList, self).__delitem__(index)
index 08d3fba31814a4ffd3198d3cd185dff7da9e6676..02cd69df77e5038166cf1ec8f30ce421b2dffa15 100644 (file)
@@ -283,6 +283,8 @@ class Table(SchemaItem, sql.TableClause):
         # store extra kwargs, which should only contain db-specific options
         self.kwargs = kwargs
 
+    key = property(lambda self:_get_table_key(self.name, self.schema))
+    
     def _get_case_sensitive_schema(self):
         try:
             return getattr(self, '_case_sensitive_schema')
@@ -1116,6 +1118,10 @@ class MetaData(SchemaItem):
     def clear(self):
         self.tables.clear()
 
+    def remove(self, table):
+        # TODO: scan all other tables and remove FK _column 
+        del self.tables[table.key]
+        
     def table_iterator(self, reverse=True, tables=None):
         import sqlalchemy.sql_util
         if tables is None:
index afdfc9cb0919dd08926bf202da0441ddc6c5be8b..9b9858cc2ff2538eac8a3d5e01fef167ae10ce83 100644 (file)
@@ -791,7 +791,7 @@ class ClauseParameters(object):
     """
 
     def __init__(self, dialect, positional=None):
-        super(ClauseParameters, self).__init__(self)
+        super(ClauseParameters, self).__init__()
         self.dialect = dialect
         self.binds = {}
         self.binds_to_names = {}
@@ -1670,7 +1670,7 @@ class FromClause(Selectable):
           it merely shares a common anscestor with one of
           the exported columns of this ``FromClause``.
         """
-
+            
         if column in self.c:
             return column
         if require_embedded and column not in util.Set(self._get_all_embedded_columns()):
@@ -1734,6 +1734,8 @@ class FromClause(Selectable):
         for co in self._adjusted_exportable_columns():
             cp = self._proxy_column(co)
             for ci in cp.orig_set:
+                # note that some ambiguity is raised here, whereby a selectable might have more than 
+                # one column that maps to an "original" column.  examples include unions and joins
                 self._orig_cols[ci] = cp
         if self.oid_column is not None:
             for ci in self.oid_column.orig_set:
@@ -2685,6 +2687,7 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
         self.is_compound = True
         self.is_where = False
         self.is_scalar = False
+        self.is_subquery = False
 
         self.selects = selects
 
index 75c8c016661c8c01ebe1849a20a5e9dbcd1100b7..bcef161a3475aea5dc7a103c8ed60e57145c4db4 100644 (file)
@@ -104,7 +104,7 @@ class TypesTest(AssertMixin):
              'SMALLINT(4) UNSIGNED ZEROFILL'),
            ]
 
-        table_args = ['test_mysql_numeric', db]
+        table_args = ['test_mysql_numeric', BoundMetaData(db)]
         for index, spec in enumerate(columns):
             type_, args, kw, res = spec
             table_args.append(Column('c%s' % index, type_(*args, **kw)))
@@ -188,7 +188,7 @@ class TypesTest(AssertMixin):
              '''ENUM('foo','bar') UNICODE''')
            ]
 
-        table_args = ['test_mysql_charset', db]
+        table_args = ['test_mysql_charset', BoundMetaData(db)]
         for index, spec in enumerate(columns):
             type_, args, kw, res = spec
             table_args.append(Column('c%s' % index, type_(*args, **kw)))
@@ -212,7 +212,7 @@ class TypesTest(AssertMixin):
     def test_enum(self):
         "Exercise the ENUM type"
 
-        enum_table = Table('mysql_enum', db,
+        enum_table = Table('mysql_enum', BoundMetaData(db),
             Column('e1', mysql.MSEnum('"a"', "'b'")),
             Column('e2', mysql.MSEnum('"a"', "'b'"), nullable=False),
             Column('e3', mysql.MSEnum('"a"', "'b'", strict=True)),
@@ -294,6 +294,10 @@ class TypesTest(AssertMixin):
 
     @testbase.supported('mysql')
     def test_type_reflection(self):
+        # FIXME: older versions need their own test
+        if db.dialect.get_version_info(db) < (5, 0):
+            return
+
         # (ask_for, roundtripped_as_if_different)
         specs = [( String(), mysql.MSText(), ),
                  ( String(1), mysql.MSString(1), ),
@@ -307,11 +311,9 @@ class TypesTest(AssertMixin):
                  ( Smallinteger(4), mysql.MSSmallInteger(4), ),
                  ( mysql.MSSmallInteger(), ),
                  ( mysql.MSSmallInteger(4), mysql.MSSmallInteger(4), ),
-                 ( Binary(3), mysql.MSVarBinary(3), ),
+                 ( Binary(3), mysql.MSBlob(3), ),
                  ( Binary(), mysql.MSBlob() ),
                  ( mysql.MSBinary(3), mysql.MSBinary(3), ),
-                 ( mysql.MSBaseBinary(), mysql.MSBlob(), ),
-                 ( mysql.MSBaseBinary(3), mysql.MSVarBinary(3), ),
                  ( mysql.MSVarBinary(3),),
                  ( mysql.MSVarBinary(), mysql.MSBlob()),
                  ( mysql.MSTinyBlob(),),
@@ -331,13 +333,14 @@ class TypesTest(AssertMixin):
         m2 = BoundMetaData(db)
         rt = Table('mysql_types', m2, autoload=True)
 
+        #print
         expected = [len(c) > 1 and c[1] or c[0] for c in specs]
         for i, reflected in enumerate(rt.c):
             #print (reflected, specs[i][0], '->',
             #       reflected.type, '==', expected[i])
-            assert type(reflected.type) == type(expected[i])
+            assert isinstance(reflected.type, type(expected[i]))
 
-        #m.drop_all()
+        m.drop_all()
 
 if __name__ == "__main__":
     testbase.main()
index af29fb2a53231a6d25e924609e83f33a492041a2..33c25201824faecd20fb1a4292b31c2520bcf220 100644 (file)
@@ -44,8 +44,10 @@ class ExecuteTest(testbase.PersistTest):
             res = conn.execute("select * from users")
             assert res.fetchall() == [(1, "jack"), (2, "ed"), (3, "horse"), (4, 'sally'), (5, None)]
             conn.execute("delete from users")
-            
-    @testbase.supported('postgres', 'mysql')
+
+    # pyformat is supported for mysql, but skipping because a few driver
+    # versions have a bug that bombs out on this test. (1.2.2b3, 1.2.2c1, 1.2.2)
+    @testbase.supported('postgres')
     def test_raw_python(self):
         for conn in (testbase.db, testbase.db.connect()):
             conn.execute("insert into users (user_id, user_name) values (%(id)s, %(name)s)", {'id':1, 'name':'jack'})
index 8990d262b07a044e20880b019a82675c5fffb06e..532f55e8a9ade7afcb6d8bee6d73c9ddd761ac63 100644 (file)
@@ -25,9 +25,14 @@ class ReflectionTest(PersistTest):
         if use_string_defaults:
             deftype2 = String
             defval2 = "im a default"
+            #deftype3 = DateTime
+            # the colon thing isnt working out for PG reflection just yet
+            #defval3 = '1999-09-09 00:00:00'
+            deftype3 = Date
+            defval3 = '1999-09-09'
         else:
-            deftype2 = Integer
-            defval2 = "15"
+            deftype2, deftype3 = Integer, Integer
+            defval2, defval3 = "15", "16"
         
         meta = BoundMetaData(testbase.db)
         
@@ -46,6 +51,7 @@ class ReflectionTest(PersistTest):
             Column('test_passivedefault', deftype, PassiveDefault(defval)),
             Column('test_passivedefault2', Integer, PassiveDefault("5")),
             Column('test_passivedefault3', deftype2, PassiveDefault(defval2)),
+            Column('test_passivedefault4', deftype3, PassiveDefault(defval3)),
             Column('test9', Binary(100)),
             Column('test_numeric', Numeric(None, None)),
             mysql_engine='InnoDB'
@@ -58,7 +64,7 @@ class ReflectionTest(PersistTest):
             mysql_engine='InnoDB'
         )
         meta.drop_all()
-        
+
         users.create()
         addresses.create()
 
index 57efe89e39d33edd896715eed137574e6b75a4bc..7a5731d5141fb8b01f45da2a7337b99276bbf14a 100644 (file)
@@ -131,7 +131,10 @@ class _CollectionOperations(PersistTest):
 
         self.assert_(len(p1._children) == 3)
         self.assert_(len(p1.children) == 3)
-        
+
+        p1._children = []
+        self.assert_(len(p1.children) == 0)
+
 class DefaultTest(_CollectionOperations):
     def __init__(self, *args, **kw):
         super(DefaultTest, self).__init__(*args, **kw)
@@ -209,10 +212,15 @@ class CustomDictTest(DictTest):
         self.assert_(len(p1._children) == 3)
         self.assert_(len(p1.children) == 3)
 
+        self.assert_(set(p1.children) == set(['d','e','f']))
+
         del ch
         p1 = self.roundtrip(p1)
         self.assert_(len(p1._children) == 3)
         self.assert_(len(p1.children) == 3)
+
+        p1._children = {}
+        self.assert_(len(p1.children) == 0)
     
 
 class SetTest(_CollectionOperations):
@@ -312,6 +320,9 @@ class SetTest(_CollectionOperations):
         p1 = self.roundtrip(p1)
         self.assert_(p1.children == set(['c']))
 
+        p1._children = []
+        self.assert_(len(p1.children) == 0)
+
     def test_set_comparisons(self):
         Parent, Child = self.Parent, self.Child
 
index 5d5a734108b530ba63bd97607028410e95d9fa20..0d12aa19399b6e8458618c34a80d4343c2df6b4a 100644 (file)
@@ -33,7 +33,49 @@ class QueryTest(PersistTest):
     def testinsert(self):
         users.insert().execute(user_id = 7, user_name = 'jack')
         assert users.count().scalar() == 1
+    
+    @testbase.unsupported('sqlite')
+    def test_lastrow_accessor(self):
+        """test the last_inserted_ids() and lastrow_has_id() functions"""
         
+        def insert_values(table, values):
+            result = table.insert().execute(**values)
+            ret = values.copy()
+            
+            for col, id in zip(table.primary_key, result.last_inserted_ids()):
+                ret[col.key] = id
+                
+            if result.lastrow_has_defaults():
+                criterion = and_(*[col==id for col, id in zip(table.primary_key, result.last_inserted_ids())])
+                row = table.select(criterion).execute().fetchone()
+                ret.update(row)
+            return ret
+            
+        for table, values, assertvalues in [
+            (
+                Table("t1", metadata, 
+                    Column('id', Integer, primary_key=True),
+                    Column('foo', String(30), primary_key=True)),
+                {'foo':'hi'},
+                {'id':1, 'foo':'hi'}
+            ),
+            (
+                Table("t2", metadata, 
+                    Column('id', Integer, primary_key=True),
+                    Column('foo', String(30), primary_key=True),
+                    Column('bar', String(30), PassiveDefault('hi'))
+                ),
+                {'foo':'hi'},
+                {'id':1, 'foo':'hi', 'bar':'hi'}
+            ),
+            
+        ]:
+            try:
+                table.create()
+                assert insert_values(table, values) == assertvalues
+            finally:
+                table.drop()
+            
     def testupdate(self):
 
         users.insert().execute(user_id = 7, user_name = 'jack')
@@ -360,6 +402,19 @@ class QueryTest(PersistTest):
             con.execute("""drop trigger paj""")
             meta.drop_all()
 
+    @testbase.supported('mssql')
+    def test_insertid_schema(self):
+        meta = BoundMetaData(testbase.db)
+        con = testbase.db.connect()
+        con.execute('create schema paj')
+        tbl = Table('test', meta, Column('id', Integer, primary_key=True), schema='paj')
+        tbl.create()        
+        try:
+            tbl.insert().execute({'id':1})        
+        finally:
+            tbl.drop()
+            con.execute('drop schema paj')
+        
 
 class CompoundTest(PersistTest):
     """test compound statements like UNION, INTERSECT, particularly their ability to nest on
index 05d0f21105a76f7512c1e9d8c7f94ee23d823ba6..95cab898c3cd223deef184ebe2b2ff2132774a17 100644 (file)
@@ -31,7 +31,7 @@ class FoundRowsTest(testbase.AssertMixin):
         i.execute(*[{'name':n, 'department':d} for n, d in data])
     def tearDown(self):
         employees_table.delete().execute()
-        
+
     def tearDownAll(self):
         employees_table.drop()
 
@@ -45,23 +45,26 @@ class FoundRowsTest(testbase.AssertMixin):
         # WHERE matches 3, 3 rows changed
         department = employees_table.c.department
         r = employees_table.update(department=='C').execute(department='Z')
-        assert r.rowcount == 3
-        
+        if testbase.db.dialect.supports_sane_rowcount():
+            assert r.rowcount == 3
+
     def test_update_rowcount2(self):
         # WHERE matches 3, 0 rows changed
         department = employees_table.c.department
         r = employees_table.update(department=='C').execute(department='C')
-        assert r.rowcount == 3
-        
+        if testbase.db.dialect.supports_sane_rowcount():
+            assert r.rowcount == 3
+
     def test_delete_rowcount(self):
         # WHERE matches 3, 3 rows deleted
         department = employees_table.c.department
         r = employees_table.delete(department=='C').execute()
-        assert r.rowcount == 3
+        if testbase.db.dialect.supports_sane_rowcount():
+            assert r.rowcount == 3
 
 if __name__ == '__main__':
     testbase.main()
-    
+
 
 
 
index f236c60f0436602603cbf77927fbe51dfeeabfbd..b7ab91ee8fbbd6c0d236c5d99bf154e742ce6c99 100755 (executable)
@@ -27,6 +27,11 @@ table2 = Table('table2', metadata,
 )\r
 \r
 class SelectableTest(testbase.AssertMixin):\r
+    def testjoinagainstself(self):\r
+        jj = select([table.c.col1.label('bar_col1')])\r
+        jjj = join(table, jj, table.c.col1==jj.c.bar_col1)\r
+        assert jjj.corresponding_column(jjj.c.table1_col1) is jjj.c.table1_col1\r
+\r
     def testjoinagainstjoin(self):\r
         j  = outerjoin(table, table2, table.c.col1==table2.c.col2)\r
         jj = select([ table.c.col1.label('bar_col1')],from_obj=[j]).alias('foo')\r
index d494909ea89addfece5c5fc47dea95208168bb23..8406860ee7ceb34c49a284a70b56d638d1fa2a50 100644 (file)
@@ -142,28 +142,35 @@ class UnicodeTest(AssertMixin):
         metadata = BoundMetaData(db)
         unicode_table = Table('unicode_table', metadata, 
             Column('id', Integer, Sequence('uni_id_seq', optional=True), primary_key=True),
-            Column('unicode_data', Unicode(250)),
-            Column('plain_data', String(250))
+            Column('unicode_varchar', Unicode(250)),
+            Column('unicode_text', Unicode),
+            Column('plain_varchar', String(250))
             )
         unicode_table.create()
     def tearDownAll(self):
         unicode_table.drop()
+
     def testbasic(self):
-        assert unicode_table.c.unicode_data.type.length == 250
+        assert unicode_table.c.unicode_varchar.type.length == 250
         rawdata = 'Alors vous imaginez ma surprise, au lever du jour, quand une dr\xc3\xb4le de petit voix m\xe2\x80\x99a r\xc3\xa9veill\xc3\xa9. Elle disait: \xc2\xab S\xe2\x80\x99il vous pla\xc3\xaet\xe2\x80\xa6 dessine-moi un mouton! \xc2\xbb\n'
         unicodedata = rawdata.decode('utf-8')
-        unicode_table.insert().execute(unicode_data=unicodedata, plain_data=rawdata)
+        unicode_table.insert().execute(unicode_varchar=unicodedata,
+                                       unicode_text=unicodedata,
+                                       plain_varchar=rawdata)
         x = unicode_table.select().execute().fetchone()
-        self.echo(repr(x['unicode_data']))
-        self.echo(repr(x['plain_data']))
-        self.assert_(isinstance(x['unicode_data'], unicode) and x['unicode_data'] == unicodedata)
-        if isinstance(x['plain_data'], unicode):
+        self.echo(repr(x['unicode_varchar']))
+        self.echo(repr(x['unicode_text']))
+        self.echo(repr(x['plain_varchar']))
+        self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata)
+        self.assert_(isinstance(x['unicode_text'], unicode) and x['unicode_text'] == unicodedata)
+        if isinstance(x['plain_varchar'], unicode):
             # SQLLite and MSSQL return non-unicode data as unicode
             self.assert_(db.name in ('sqlite', 'mssql'))
-            self.assert_(x['plain_data'] == unicodedata)
+            self.assert_(x['plain_varchar'] == unicodedata)
             self.echo("it's %s!" % db.name)
         else:
-            self.assert_(not isinstance(x['plain_data'], unicode) and x['plain_data'] == rawdata)
+            self.assert_(not isinstance(x['plain_varchar'], unicode) and x['plain_varchar'] == rawdata)
+
     def testengineparam(self):
         """tests engine-wide unicode conversion"""
         prev_unicode = db.engine.dialect.convert_unicode
@@ -171,17 +178,24 @@ class UnicodeTest(AssertMixin):
             db.engine.dialect.convert_unicode = True
             rawdata = 'Alors vous imaginez ma surprise, au lever du jour, quand une dr\xc3\xb4le de petit voix m\xe2\x80\x99a r\xc3\xa9veill\xc3\xa9. Elle disait: \xc2\xab S\xe2\x80\x99il vous pla\xc3\xaet\xe2\x80\xa6 dessine-moi un mouton! \xc2\xbb\n'
             unicodedata = rawdata.decode('utf-8')
-            unicode_table.insert().execute(unicode_data=unicodedata, plain_data=rawdata)
+            unicode_table.insert().execute(unicode_varchar=unicodedata,
+                                           unicode_text=unicodedata,
+                                           plain_varchar=rawdata)
             x = unicode_table.select().execute().fetchone()
-            self.echo(repr(x['unicode_data']))
-            self.echo(repr(x['plain_data']))
-            self.assert_(isinstance(x['unicode_data'], unicode) and x['unicode_data'] == unicodedata)
-            self.assert_(isinstance(x['plain_data'], unicode) and x['plain_data'] == unicodedata)
+            self.echo(repr(x['unicode_varchar']))
+            self.echo(repr(x['unicode_text']))
+            self.echo(repr(x['plain_varchar']))
+            self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata)
+            self.assert_(isinstance(x['unicode_text'], unicode) and x['unicode_text'] == unicodedata)
+            self.assert_(isinstance(x['plain_varchar'], unicode) and x['plain_varchar'] == unicodedata)
         finally:
             db.engine.dialect.convert_unicode = prev_unicode
-    
-
 
+    def testlength(self):
+        """checks the database correctly understands the length of a unicode string"""
+        teststr = u'aaa\x1234'
+        self.assert_(db.func.length(teststr).scalar() == len(teststr))
+  
 class BinaryTest(AssertMixin):
     def setUpAll(self):
         global binary_table
@@ -305,6 +319,24 @@ class DateTest(AssertMixin):
         #x = db.text("select * from query_users_with_date where user_datetime=:date", bindparams=[bindparam('date', )]).execute(date=datetime.datetime(2005, 11, 10, 11, 52, 35)).fetchall()
         #print repr(x)
 
+    @testbase.unsupported('sqlite')
+    def testdate2(self):
+        t = Table('testdate', testbase.metadata, Column('id', Integer, primary_key=True),
+                Column('adate', Date), Column('adatetime', DateTime))
+        t.create()
+        try:
+            d1 = datetime.date(2007, 10, 30)
+            t.insert().execute(adate=d1, adatetime=d1)
+            d2 = datetime.datetime(2007, 10, 30)
+            t.insert().execute(adate=d2, adatetime=d2)
+
+            x = t.select().execute().fetchall()[0]
+            self.assert_(x.adate.__class__ == datetime.date)
+            self.assert_(x.adatetime.__class__ == datetime.datetime)
+
+        finally:
+            t.drop()
+
 class TimezoneTest(AssertMixin):
     """test timezone-aware datetimes.  psycopg will return a datetime with a tzinfo attached to it,
     if postgres returns it.  python then will not let you compare a datetime with a tzinfo to a datetime