From: Mike Bayer Date: Thu, 14 Jun 2007 18:37:20 +0000 (+0000) Subject: - merged trunk 2629-2730 X-Git-Tag: rel_0_4_6~199 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=47c3ce38aad47446a611b7a6d9367b73c0b88d0b;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - merged trunk 2629-2730 - fixes to is_select() which is now an important method - mysql unit tests fixes --- diff --git a/doc/build/genhtml.py b/doc/build/genhtml.py index 840d0362ed..3b78da7690 100644 --- a/doc/build/genhtml.py +++ b/doc/build/genhtml.py @@ -29,7 +29,7 @@ files = [ parser = optparse.OptionParser(usage = "usage: %prog [options] [tests...]") parser.add_option("--file", action="store", dest="file", help="only generate file ") parser.add_option("--docstrings", action="store_true", dest="docstrings", help="only generate docstrings") -parser.add_option("--version", action="store", dest="version", default="0.3.7", help="version string") +parser.add_option("--version", action="store", dest="version", default="0.3.8", help="version string") (options, args) = parser.parse_args() if options.file: diff --git a/doc/build/runhtml.py b/doc/build/runhtml.py index c5b34e4bf1..e69de29bb2 100755 --- a/doc/build/runhtml.py +++ b/doc/build/runhtml.py @@ -1,31 +0,0 @@ -#!/usr/bin/env python -import sys,re,os - -"""starts an HTTP server which will serve generated .myt files from the ./components and -./output directories.""" - - -component_root = [ - {'components': './components'}, - {'content' : './output'} -] -doccomp = ['document_base.myt'] -output = os.path.dirname(os.getcwd()) - -sys.path = ['./lib/'] + sys.path - -import myghty.http.HTTPServerHandler as HTTPServerHandler - -port = 8080 -httpd = HTTPServerHandler.HTTPServer( - port = port, - handlers = [ - {'.*(?:\.myt|/$)' : HTTPServerHandler.HSHandler(path_translate=[(r'^/$', r'/index.myt')], data_dir = './cache', component_root = component_root, output_encoding='utf-8')}, - ], - - docroot = [{'.*' : '../'}], - -) - -print "Listening on %d" % port -httpd.serve_forever() diff --git a/examples/polymorph/polymorph.py b/examples/polymorph/polymorph.py index b1f3e75158..d5d747d36c 100644 --- a/examples/polymorph/polymorph.py +++ b/examples/polymorph/polymorph.py @@ -4,7 +4,7 @@ import sets # this example illustrates a polymorphic load of two classes -metadata = BoundMetaData('sqlite://', echo='True') +metadata = BoundMetaData('sqlite://', echo=True) # a table to store companies companies = Table('companies', metadata, diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py index 58d6d246f8..a02781c846 100644 --- a/lib/sqlalchemy/databases/firebird.py +++ b/lib/sqlalchemy/databases/firebird.py @@ -312,18 +312,10 @@ class FBCompiler(ansisql.ANSICompiler): else: self.strings[func] = func.name - def visit_insert(self, insert): - """Inserts are required to have the primary keys be explicitly present. - - mapper will by default not put them in the insert statement - to comply with autoincrement fields that require they not be - present. So, put them all in for all primary key columns. - """ - - for c in insert.table.primary_key: - if not self.parameters.has_key(c.key): - self.parameters[c.key] = None - return ansisql.ANSICompiler.visit_insert(self, insert) + def visit_insert_column(self, column, parameters): + # all column primary key inserts must be explicitly present + if column.primary_key: + parameters[column.key] = None def visit_select_precolumns(self, select): """Called when building a ``SELECT`` statement, position is just @@ -372,7 +364,7 @@ class FBSchemaDropper(ansisql.ANSISchemaDropper): class FBDefaultRunner(ansisql.ANSIDefaultRunner): def exec_default_sql(self, default): - c = sql.select([default.arg], from_obj=["rdb$database"]).compile(engine=self.engine) + c = sql.select([default.arg], from_obj=["rdb$database"]).compile(engine=self.connection) return self.connection.execute_compiled(c).scalar() def visit_sequence(self, seq): diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index 46fe990734..4336296dd9 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -99,14 +99,14 @@ class MSDateTime(sqltypes.DateTime): def get_col_spec(self): return "DATETIME" - def convert_bind_param(self, value, dialect): - if hasattr(value, "isoformat"): - #return value.isoformat(' ') - # isoformat() bings on apodbapi -- reported/suggested by Peter Buschman - return value.strftime('%Y-%m-%d %H:%M:%S') - else: - return value +class MSDate(sqltypes.Date): + def __init__(self, *a, **kw): + super(MSDate, self).__init__(False) + def get_col_spec(self): + return "SMALLDATETIME" + +class MSDateTime_adodbapi(MSDateTime): def convert_result_value(self, value, dialect): # adodbapi will return datetimes with empty time values as datetime.date() objects. # Promote them back to full datetime.datetime() @@ -114,23 +114,34 @@ class MSDateTime(sqltypes.DateTime): return datetime.datetime(value.year, value.month, value.day) return value -class MSDate(sqltypes.Date): - def __init__(self, *a, **kw): - super(MSDate, self).__init__(False) +class MSDateTime_pyodbc(MSDateTime): + def convert_bind_param(self, value, dialect): + if value and not hasattr(value, 'second'): + return datetime.datetime(value.year, value.month, value.day) + else: + return value - def get_col_spec(self): - return "SMALLDATETIME" - +class MSDate_pyodbc(MSDate): def convert_bind_param(self, value, dialect): - if value and hasattr(value, "isoformat"): - return value.strftime('%Y-%m-%d %H:%M') - return value + if value and not hasattr(value, 'second'): + return datetime.datetime(value.year, value.month, value.day) + else: + return value + def convert_result_value(self, value, dialect): + # pyodbc returns SMALLDATETIME values as datetime.datetime(). truncate it back to datetime.date() + if value and hasattr(value, 'second'): + return value.date() + else: + return value + +class MSDate_pymssql(MSDate): def convert_result_value(self, value, dialect): # pymssql will return SMALLDATETIME values as datetime.datetime(), truncate it back to datetime.date() if value and hasattr(value, 'second'): return value.date() - return value + else: + return value class MSText(sqltypes.TEXT): def get_col_spec(self): @@ -143,7 +154,7 @@ class MSString(sqltypes.String): def get_col_spec(self): return "VARCHAR(%(length)s)" % {'length' : self.length} -class MSNVarchar(MSString): +class MSNVarchar(sqltypes.Unicode): def get_col_spec(self): if self.length: return "NVARCHAR(%(length)s)" % {'length' : self.length} @@ -191,6 +202,10 @@ class MSBoolean(sqltypes.Boolean): else: return value and True or False +class MSTimeStamp(sqltypes.TIMESTAMP): + def get_col_spec(self): + return "TIMESTAMP" + def descriptor(): return {'name':'mssql', 'description':'MSSQL', @@ -240,7 +255,7 @@ class MSSQLExecutionContext(default.DefaultExecutionContext): if self.IINSERT: # TODO: quoting rules for table name here ? - self.cursor.execute("SET IDENTITY_INSERT %s ON" % self.compiled.statement.table.name) + self.cursor.execute("SET IDENTITY_INSERT %s ON" % self.compiled.statement.table.fullname) super(MSSQLExecutionContext, self).pre_exec() @@ -253,7 +268,7 @@ class MSSQLExecutionContext(default.DefaultExecutionContext): if self.compiled.isinsert: if self.IINSERT: # TODO: quoting rules for table name here ? - self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.compiled.statement.table.name) + self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.compiled.statement.table.fullname) self.IINSERT = False elif self.HASIDENT: if self.dialect.use_scope_identity: @@ -294,6 +309,7 @@ class MSSQLDialect(ansisql.ANSIDialect): sqltypes.TEXT : MSText, sqltypes.CHAR: MSChar, sqltypes.NCHAR: MSNChar, + sqltypes.TIMESTAMP: MSTimeStamp, } ischema_names = { @@ -314,7 +330,8 @@ class MSSQLDialect(ansisql.ANSIDialect): 'binary' : MSBinary, 'bit': MSBoolean, 'real' : MSFloat, - 'image' : MSBinary + 'image' : MSBinary, + 'timestamp': MSTimeStamp, } def __new__(cls, dbapi=None, *args, **kwargs): @@ -330,7 +347,7 @@ class MSSQLDialect(ansisql.ANSIDialect): super(MSSQLDialect, self).__init__(**params) self.auto_identity_insert = auto_identity_insert self.text_as_varchar = False - self.use_scope_identity = True + self.use_scope_identity = False self.set_default_schema_name("dbo") def dbapi(cls, module_name=None): @@ -570,9 +587,22 @@ class MSSQLDialect_pymssql(MSSQLDialect): return module import_dbapi = classmethod(import_dbapi) + colspecs = MSSQLDialect.colspecs.copy() + colspecs[sqltypes.Date] = MSDate_pymssql + + ischema_names = MSSQLDialect.ischema_names.copy() + ischema_names['smalldatetime'] = MSDate_pymssql + + def __init__(self, **params): + super(MSSQLDialect_pymssql, self).__init__(**params) + self.use_scope_identity = True + def supports_sane_rowcount(self): return True + def max_identifier_length(self): + return 30 + def do_rollback(self, connection): # pymssql throws an error on repeated rollbacks. Ignore it. # TODO: this is normal behavior for most DBs. are we sure we want to ignore it ? @@ -638,12 +668,21 @@ class MSSQLDialect_pyodbc(MSSQLDialect): colspecs = MSSQLDialect.colspecs.copy() colspecs[sqltypes.Unicode] = AdoMSNVarchar + colspecs[sqltypes.Date] = MSDate_pyodbc + colspecs[sqltypes.DateTime] = MSDateTime_pyodbc + ischema_names = MSSQLDialect.ischema_names.copy() ischema_names['nvarchar'] = AdoMSNVarchar + ischema_names['smalldatetime'] = MSDate_pyodbc + ischema_names['datetime'] = MSDateTime_pyodbc def supports_sane_rowcount(self): return False + def supports_unicode_statements(self): + """indicate whether the DBAPI can receive SQL statements as Python unicode strings""" + return True + def make_connect_string(self, keys): connectors = ["Driver={SQL Server}"] connectors.append("Server=%s" % keys.get("host")) @@ -671,12 +710,19 @@ class MSSQLDialect_adodbapi(MSSQLDialect): colspecs = MSSQLDialect.colspecs.copy() colspecs[sqltypes.Unicode] = AdoMSNVarchar + colspecs[sqltypes.DateTime] = MSDateTime_adodbapi + ischema_names = MSSQLDialect.ischema_names.copy() ischema_names['nvarchar'] = AdoMSNVarchar + ischema_names['datetime'] = MSDateTime_adodbapi def supports_sane_rowcount(self): return True + def supports_unicode_statements(self): + """indicate whether the DBAPI can receive SQL statements as Python unicode strings""" + return True + def make_connect_string(self, keys): connectors = ["Provider=SQLOLEDB"] if 'port' in keys: diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 8b4b89d508..63ce05eb68 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -24,7 +24,7 @@ RESERVED_WORDS = util.Set( 'declare', 'default', 'delayed', 'delete', 'desc', 'describe', 'deterministic', 'distinct', 'distinctrow', 'div', 'double', 'drop', 'dual', 'each', 'else', 'elseif', 'enclosed', 'escaped', 'exists', - 'exit', 'explain', 'false', 'fetch', 'fields', 'float', 'float4', 'float8', + 'exit', 'explain', 'false', 'fetch', 'float', 'float4', 'float8', 'for', 'force', 'foreign', 'from', 'fulltext', 'grant', 'group', 'having', 'high_priority', 'hour_microsecond', 'hour_minute', 'hour_second', 'if', 'ignore', 'in', 'index', 'infile', 'inner', 'inout', 'insensitive', @@ -49,9 +49,11 @@ RESERVED_WORDS = util.Set( 'union', 'unique', 'unlock', 'unsigned', 'update', 'usage', 'use', 'using', 'utc_date', 'utc_time', 'utc_timestamp', 'values', 'varbinary', 'varchar', 'varcharacter', 'varying', 'when', 'where', 'while', 'with', - 'write', 'x509', 'xor', 'year_month', 'zerofill', + 'write', 'x509', 'xor', 'year_month', 'zerofill', # 5.0 + 'fields', # 4.1 'accessible', 'linear', 'master_ssl_verify_server_cert', 'range', - 'read_only', 'read_write']) + 'read_only', 'read_write', # 5.1 + ]) class _NumericType(object): "Base for MySQL numeric types." @@ -425,7 +427,8 @@ class MSText(_StringType, sqltypes.TEXT): """ _StringType.__init__(self, **kwargs) - sqltypes.TEXT.__init__(self, length) + sqltypes.TEXT.__init__(self, length, + kwargs.get('convert_unicode', False)) def get_col_spec(self): if self.length: @@ -678,21 +681,12 @@ class MSNChar(_StringType, sqltypes.CHAR): # We'll actually generate the equiv. "NATIONAL CHAR" instead of "NCHAR". return self._extend("CHAR(%(length)s)" % {'length': self.length}) -class MSBaseBinary(sqltypes.Binary): - """Flexible binary type""" - - def __init__(self, length=None, **kw): - """Flexibly construct a binary column type. Will construct a - VARBINARY or BLOB depending on the length requested, if any. - - length - Maximum data length, in bytes. - """ - super(MSBaseBinary, self).__init__(length, **kw) +class _BinaryType(sqltypes.Binary): + """MySQL binary types""" def get_col_spec(self): - if self.length and self.length <= 255: - return "VARBINARY(%d)" % self.length + if self.length: + return "BLOB(%d)" % self.length else: return "BLOB" @@ -702,7 +696,7 @@ class MSBaseBinary(sqltypes.Binary): else: return buffer(value) -class MSVarBinary(MSBaseBinary): +class MSVarBinary(_BinaryType): """MySQL VARBINARY type, for variable length binary data""" def __init__(self, length=None, **kw): @@ -719,7 +713,7 @@ class MSVarBinary(MSBaseBinary): else: return "BLOB" -class MSBinary(MSBaseBinary): +class MSBinary(_BinaryType): """MySQL BINARY type, for fixed length binary data""" def __init__(self, length=None, **kw): @@ -746,7 +740,7 @@ class MSBinary(MSBaseBinary): else: return buffer(value) -class MSBlob(MSBaseBinary): +class MSBlob(_BinaryType): """MySQL BLOB type, for binary data up to 2^16 bytes""" @@ -865,7 +859,7 @@ class MSEnum(MSString): class MSBoolean(sqltypes.Boolean): def get_col_spec(self): - return "BOOLEAN" + return "BOOL" def convert_result_value(self, value, dialect): if value is None: @@ -893,14 +887,14 @@ colspecs = { sqltypes.Date : MSDate, sqltypes.Time : MSTime, sqltypes.String : MSString, - sqltypes.Binary : MSVarBinary, + sqltypes.Binary : MSBlob, sqltypes.Boolean : MSBoolean, sqltypes.TEXT : MSText, sqltypes.CHAR: MSChar, sqltypes.NCHAR: MSNChar, sqltypes.TIMESTAMP: MSTimeStamp, sqltypes.BLOB: MSBlob, - MSBaseBinary: MSBaseBinary, + _BinaryType: _BinaryType, } @@ -951,7 +945,10 @@ def descriptor(): class MySQLExecutionContext(default.DefaultExecutionContext): def post_exec(self): if self.compiled.isinsert: - self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:] + self._last_inserted_ids = [self.cursor.lastrowid] + + def is_select(self): + return re.match(r'SELECT|SHOW|DESCRIBE', self.statement.lstrip(), re.I) is not None class MySQLDialect(ansisql.ANSIDialect): def __init__(self, **kwargs): @@ -1069,6 +1066,19 @@ class MySQLDialect(ansisql.ANSIDialect): else: raise + def get_version_info(self, connectable): + if hasattr(connectable, 'connect'): + con = connectable.connect().connection + else: + con = connectable + version = [] + for n in con.get_server_info().split('.'): + try: + version.append(int(n)) + except ValueError: + version.append(n) + return tuple(version) + def reflecttable(self, connection, table): # reference: http://dev.mysql.com/doc/refman/5.0/en/name-case-sensitivity.html cs = connection.execute("show variables like 'lower_case_table_names'").fetchone()[1] @@ -1125,7 +1135,11 @@ class MySQLDialect(ansisql.ANSIDialect): colargs= [] if default: - colargs.append(schema.PassiveDefault(sql.text(default))) + if col_type == 'timestamp' and default == 'CURRENT_TIMESTAMP': + arg = sql.text(default) + else: + arg = default + colargs.append(schema.PassiveDefault(arg)) table.append_column(schema.Column(name, coltype, *colargs, **dict(primary_key=primary_key, nullable=nullable, diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index cde8ee0981..eca0faf914 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -191,6 +191,8 @@ class DefaultExecutionContext(base.ExecutionContext): return proc(params) def is_select(self): + """return TRUE if the statement is expected to have result rows.""" + return re.match(r'SELECT', self.statement.lstrip(), re.I) is not None def create_cursor(self): diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index d6a995450a..1b363c9acd 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -10,6 +10,7 @@ from sqlalchemy.orm.attributes import InstrumentedList import sqlalchemy.exceptions as exceptions import sqlalchemy.orm as orm import sqlalchemy.util as util +import weakref def association_proxy(targetcollection, attr, **kw): """Convenience function for use in mapped classes. Implements a Python @@ -116,6 +117,18 @@ class AssociationProxy(object): def _target_is_scalar(self): return not self._get_property().uselist + + def _lazy_collection(self, weakobjref): + target = self.target_collection + del self + def lazy_collection(): + obj = weakobjref() + if obj is None: + raise exceptions.InvalidRequestError( + "stale association proxy, parent object has gone out of " + "scope") + return getattr(obj, target) + return lazy_collection def __get__(self, obj, class_): if obj is None: @@ -130,7 +143,7 @@ class AssociationProxy(object): try: return getattr(obj, self.key) except AttributeError: - proxy = self._new(getattr(obj, self.target_collection)) + proxy = self._new(self._lazy_collection(weakref.ref(obj))) setattr(obj, self.key, proxy) return proxy @@ -153,30 +166,32 @@ class AssociationProxy(object): def __delete__(self, obj): delattr(obj, self.key) - def _new(self, collection): + def _new(self, lazy_collection): creator = self.creator and self.creator or self.target_class # Prefer class typing here to spot dicts with the required append() # method. + collection = lazy_collection() if isinstance(collection.data, dict): self.collection_class = dict else: self.collection_class = util.duck_type_collection(collection.data) + del collection if self.proxy_factory: - return self.proxy_factory(collection, creator, self.value_attr) + return self.proxy_factory(lazy_collection, creator, self.value_attr) value_attr = self.value_attr getter = lambda o: getattr(o, value_attr) setter = lambda o, v: setattr(o, value_attr, v) if self.collection_class is list: - return _AssociationList(collection, creator, getter, setter) + return _AssociationList(lazy_collection, creator, getter, setter) elif self.collection_class is dict: kv_setter = lambda o, k, v: setattr(o, value_attr, v) - return _AssociationDict(collection, creator, getter, setter) + return _AssociationDict(lazy_collection, creator, getter, setter) elif self.collection_class is util.Set: - return _AssociationSet(collection, creator, getter, setter) + return _AssociationSet(lazy_collection, creator, getter, setter) else: raise exceptions.ArgumentError( 'could not guess which interface to use for ' @@ -203,11 +218,11 @@ class _AssociationList(object): converting association objects to and from a simplified value. """ - def __init__(self, collection, creator, getter, setter): + def __init__(self, lazy_collection, creator, getter, setter): """ - collection - A list-based collection of entities (usually an object attribute - managed by a SQLAlchemy relation()) + lazy_collection + A callable returning a list-based collection of entities (usually + an object attribute managed by a SQLAlchemy relation()) creator A function that creates new target entities. Given one parameter: @@ -223,11 +238,13 @@ class _AssociationList(object): that value on the object. """ - self.col = collection + self.lazy_collection = lazy_collection self.creator = creator self.getter = getter self.setter = setter + col = property(lambda self: self.lazy_collection()) + # For compatibility with 0.3.1 through 0.3.7- pass kw through to creator. # (see append() below) def _create(self, value, **kw): @@ -320,11 +337,11 @@ class _AssociationDict(object): converting association objects to and from a simplified value. """ - def __init__(self, collection, creator, getter, setter): + def __init__(self, lazy_collection, creator, getter, setter): """ - collection - A list-based collection of entities (usually an object attribute - managed by a SQLAlchemy relation()) + lazy_collection + A callable returning a dict-based collection of entities (usually + an object attribute managed by a SQLAlchemy relation()) creator A function that creates new target entities. Given two parameters: @@ -340,11 +357,13 @@ class _AssociationDict(object): that value on the object. """ - self.col = collection + self.lazy_collection = lazy_collection self.creator = creator self.getter = getter self.setter = setter + col = property(lambda self: self.lazy_collection()) + def _create(self, key, value): return self.creator(key, value) @@ -380,7 +399,7 @@ class _AssociationDict(object): has_key = __contains__ def __iter__(self): - return iter(self.col) + return self.col.iterkeys() def clear(self): self.col.clear() @@ -465,11 +484,11 @@ class _AssociationSet(object): converting association objects to and from a simplified value. """ - def __init__(self, collection, creator, getter, setter): + def __init__(self, lazy_collection, creator, getter, setter): """ collection - A list-based collection of entities (usually an object attribute - managed by a SQLAlchemy relation()) + A callable returning a set-based collection of entities (usually an + object attribute managed by a SQLAlchemy relation()) creator A function that creates new target entities. Given one parameter: @@ -485,11 +504,13 @@ class _AssociationSet(object): that value on the object. """ - self.col = collection + self.lazy_collection = lazy_collection self.creator = creator self.getter = getter self.setter = setter + col = property(lambda self: self.lazy_collection()) + def _create(self, value): return self.creator(value) diff --git a/lib/sqlalchemy/ext/orderinglist.py b/lib/sqlalchemy/ext/orderinglist.py index 27ff408dcd..e02990a263 100644 --- a/lib/sqlalchemy/ext/orderinglist.py +++ b/lib/sqlalchemy/ext/orderinglist.py @@ -165,8 +165,12 @@ class OrderingList(list): return entity def __setitem__(self, index, entity): - super(OrderingList, self).__setitem__(index, entity) - self._order_entity(index, entity, True) + if isinstance(index, slice): + for i in range(index.start or 0, index.stop or 0, index.step or 1): + self.__setitem__(i, entity[i]) + else: + self._order_entity(index, entity, True) + super(OrderingList, self).__setitem__(index, entity) def __delitem__(self, index): super(OrderingList, self).__delitem__(index) diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 08d3fba318..02cd69df77 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -283,6 +283,8 @@ class Table(SchemaItem, sql.TableClause): # store extra kwargs, which should only contain db-specific options self.kwargs = kwargs + key = property(lambda self:_get_table_key(self.name, self.schema)) + def _get_case_sensitive_schema(self): try: return getattr(self, '_case_sensitive_schema') @@ -1116,6 +1118,10 @@ class MetaData(SchemaItem): def clear(self): self.tables.clear() + def remove(self, table): + # TODO: scan all other tables and remove FK _column + del self.tables[table.key] + def table_iterator(self, reverse=True, tables=None): import sqlalchemy.sql_util if tables is None: diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index afdfc9cb09..9b9858cc2f 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -791,7 +791,7 @@ class ClauseParameters(object): """ def __init__(self, dialect, positional=None): - super(ClauseParameters, self).__init__(self) + super(ClauseParameters, self).__init__() self.dialect = dialect self.binds = {} self.binds_to_names = {} @@ -1670,7 +1670,7 @@ class FromClause(Selectable): it merely shares a common anscestor with one of the exported columns of this ``FromClause``. """ - + if column in self.c: return column if require_embedded and column not in util.Set(self._get_all_embedded_columns()): @@ -1734,6 +1734,8 @@ class FromClause(Selectable): for co in self._adjusted_exportable_columns(): cp = self._proxy_column(co) for ci in cp.orig_set: + # note that some ambiguity is raised here, whereby a selectable might have more than + # one column that maps to an "original" column. examples include unions and joins self._orig_cols[ci] = cp if self.oid_column is not None: for ci in self.oid_column.orig_set: @@ -2685,6 +2687,7 @@ class CompoundSelect(_SelectBaseMixin, FromClause): self.is_compound = True self.is_where = False self.is_scalar = False + self.is_subquery = False self.selects = selects diff --git a/test/dialect/mysql.py b/test/dialect/mysql.py index 75c8c01666..bcef161a34 100644 --- a/test/dialect/mysql.py +++ b/test/dialect/mysql.py @@ -104,7 +104,7 @@ class TypesTest(AssertMixin): 'SMALLINT(4) UNSIGNED ZEROFILL'), ] - table_args = ['test_mysql_numeric', db] + table_args = ['test_mysql_numeric', BoundMetaData(db)] for index, spec in enumerate(columns): type_, args, kw, res = spec table_args.append(Column('c%s' % index, type_(*args, **kw))) @@ -188,7 +188,7 @@ class TypesTest(AssertMixin): '''ENUM('foo','bar') UNICODE''') ] - table_args = ['test_mysql_charset', db] + table_args = ['test_mysql_charset', BoundMetaData(db)] for index, spec in enumerate(columns): type_, args, kw, res = spec table_args.append(Column('c%s' % index, type_(*args, **kw))) @@ -212,7 +212,7 @@ class TypesTest(AssertMixin): def test_enum(self): "Exercise the ENUM type" - enum_table = Table('mysql_enum', db, + enum_table = Table('mysql_enum', BoundMetaData(db), Column('e1', mysql.MSEnum('"a"', "'b'")), Column('e2', mysql.MSEnum('"a"', "'b'"), nullable=False), Column('e3', mysql.MSEnum('"a"', "'b'", strict=True)), @@ -294,6 +294,10 @@ class TypesTest(AssertMixin): @testbase.supported('mysql') def test_type_reflection(self): + # FIXME: older versions need their own test + if db.dialect.get_version_info(db) < (5, 0): + return + # (ask_for, roundtripped_as_if_different) specs = [( String(), mysql.MSText(), ), ( String(1), mysql.MSString(1), ), @@ -307,11 +311,9 @@ class TypesTest(AssertMixin): ( Smallinteger(4), mysql.MSSmallInteger(4), ), ( mysql.MSSmallInteger(), ), ( mysql.MSSmallInteger(4), mysql.MSSmallInteger(4), ), - ( Binary(3), mysql.MSVarBinary(3), ), + ( Binary(3), mysql.MSBlob(3), ), ( Binary(), mysql.MSBlob() ), ( mysql.MSBinary(3), mysql.MSBinary(3), ), - ( mysql.MSBaseBinary(), mysql.MSBlob(), ), - ( mysql.MSBaseBinary(3), mysql.MSVarBinary(3), ), ( mysql.MSVarBinary(3),), ( mysql.MSVarBinary(), mysql.MSBlob()), ( mysql.MSTinyBlob(),), @@ -331,13 +333,14 @@ class TypesTest(AssertMixin): m2 = BoundMetaData(db) rt = Table('mysql_types', m2, autoload=True) + #print expected = [len(c) > 1 and c[1] or c[0] for c in specs] for i, reflected in enumerate(rt.c): #print (reflected, specs[i][0], '->', # reflected.type, '==', expected[i]) - assert type(reflected.type) == type(expected[i]) + assert isinstance(reflected.type, type(expected[i])) - #m.drop_all() + m.drop_all() if __name__ == "__main__": testbase.main() diff --git a/test/engine/execute.py b/test/engine/execute.py index af29fb2a53..33c2520182 100644 --- a/test/engine/execute.py +++ b/test/engine/execute.py @@ -44,8 +44,10 @@ class ExecuteTest(testbase.PersistTest): res = conn.execute("select * from users") assert res.fetchall() == [(1, "jack"), (2, "ed"), (3, "horse"), (4, 'sally'), (5, None)] conn.execute("delete from users") - - @testbase.supported('postgres', 'mysql') + + # pyformat is supported for mysql, but skipping because a few driver + # versions have a bug that bombs out on this test. (1.2.2b3, 1.2.2c1, 1.2.2) + @testbase.supported('postgres') def test_raw_python(self): for conn in (testbase.db, testbase.db.connect()): conn.execute("insert into users (user_id, user_name) values (%(id)s, %(name)s)", {'id':1, 'name':'jack'}) diff --git a/test/engine/reflection.py b/test/engine/reflection.py index 8990d262b0..532f55e8a9 100644 --- a/test/engine/reflection.py +++ b/test/engine/reflection.py @@ -25,9 +25,14 @@ class ReflectionTest(PersistTest): if use_string_defaults: deftype2 = String defval2 = "im a default" + #deftype3 = DateTime + # the colon thing isnt working out for PG reflection just yet + #defval3 = '1999-09-09 00:00:00' + deftype3 = Date + defval3 = '1999-09-09' else: - deftype2 = Integer - defval2 = "15" + deftype2, deftype3 = Integer, Integer + defval2, defval3 = "15", "16" meta = BoundMetaData(testbase.db) @@ -46,6 +51,7 @@ class ReflectionTest(PersistTest): Column('test_passivedefault', deftype, PassiveDefault(defval)), Column('test_passivedefault2', Integer, PassiveDefault("5")), Column('test_passivedefault3', deftype2, PassiveDefault(defval2)), + Column('test_passivedefault4', deftype3, PassiveDefault(defval3)), Column('test9', Binary(100)), Column('test_numeric', Numeric(None, None)), mysql_engine='InnoDB' @@ -58,7 +64,7 @@ class ReflectionTest(PersistTest): mysql_engine='InnoDB' ) meta.drop_all() - + users.create() addresses.create() diff --git a/test/ext/associationproxy.py b/test/ext/associationproxy.py index 57efe89e39..7a5731d514 100644 --- a/test/ext/associationproxy.py +++ b/test/ext/associationproxy.py @@ -131,7 +131,10 @@ class _CollectionOperations(PersistTest): self.assert_(len(p1._children) == 3) self.assert_(len(p1.children) == 3) - + + p1._children = [] + self.assert_(len(p1.children) == 0) + class DefaultTest(_CollectionOperations): def __init__(self, *args, **kw): super(DefaultTest, self).__init__(*args, **kw) @@ -209,10 +212,15 @@ class CustomDictTest(DictTest): self.assert_(len(p1._children) == 3) self.assert_(len(p1.children) == 3) + self.assert_(set(p1.children) == set(['d','e','f'])) + del ch p1 = self.roundtrip(p1) self.assert_(len(p1._children) == 3) self.assert_(len(p1.children) == 3) + + p1._children = {} + self.assert_(len(p1.children) == 0) class SetTest(_CollectionOperations): @@ -312,6 +320,9 @@ class SetTest(_CollectionOperations): p1 = self.roundtrip(p1) self.assert_(p1.children == set(['c'])) + p1._children = [] + self.assert_(len(p1.children) == 0) + def test_set_comparisons(self): Parent, Child = self.Parent, self.Child diff --git a/test/sql/query.py b/test/sql/query.py index 5d5a734108..0d12aa1939 100644 --- a/test/sql/query.py +++ b/test/sql/query.py @@ -33,7 +33,49 @@ class QueryTest(PersistTest): def testinsert(self): users.insert().execute(user_id = 7, user_name = 'jack') assert users.count().scalar() == 1 + + @testbase.unsupported('sqlite') + def test_lastrow_accessor(self): + """test the last_inserted_ids() and lastrow_has_id() functions""" + def insert_values(table, values): + result = table.insert().execute(**values) + ret = values.copy() + + for col, id in zip(table.primary_key, result.last_inserted_ids()): + ret[col.key] = id + + if result.lastrow_has_defaults(): + criterion = and_(*[col==id for col, id in zip(table.primary_key, result.last_inserted_ids())]) + row = table.select(criterion).execute().fetchone() + ret.update(row) + return ret + + for table, values, assertvalues in [ + ( + Table("t1", metadata, + Column('id', Integer, primary_key=True), + Column('foo', String(30), primary_key=True)), + {'foo':'hi'}, + {'id':1, 'foo':'hi'} + ), + ( + Table("t2", metadata, + Column('id', Integer, primary_key=True), + Column('foo', String(30), primary_key=True), + Column('bar', String(30), PassiveDefault('hi')) + ), + {'foo':'hi'}, + {'id':1, 'foo':'hi', 'bar':'hi'} + ), + + ]: + try: + table.create() + assert insert_values(table, values) == assertvalues + finally: + table.drop() + def testupdate(self): users.insert().execute(user_id = 7, user_name = 'jack') @@ -360,6 +402,19 @@ class QueryTest(PersistTest): con.execute("""drop trigger paj""") meta.drop_all() + @testbase.supported('mssql') + def test_insertid_schema(self): + meta = BoundMetaData(testbase.db) + con = testbase.db.connect() + con.execute('create schema paj') + tbl = Table('test', meta, Column('id', Integer, primary_key=True), schema='paj') + tbl.create() + try: + tbl.insert().execute({'id':1}) + finally: + tbl.drop() + con.execute('drop schema paj') + class CompoundTest(PersistTest): """test compound statements like UNION, INTERSECT, particularly their ability to nest on diff --git a/test/sql/rowcount.py b/test/sql/rowcount.py index 05d0f21105..95cab898c3 100644 --- a/test/sql/rowcount.py +++ b/test/sql/rowcount.py @@ -31,7 +31,7 @@ class FoundRowsTest(testbase.AssertMixin): i.execute(*[{'name':n, 'department':d} for n, d in data]) def tearDown(self): employees_table.delete().execute() - + def tearDownAll(self): employees_table.drop() @@ -45,23 +45,26 @@ class FoundRowsTest(testbase.AssertMixin): # WHERE matches 3, 3 rows changed department = employees_table.c.department r = employees_table.update(department=='C').execute(department='Z') - assert r.rowcount == 3 - + if testbase.db.dialect.supports_sane_rowcount(): + assert r.rowcount == 3 + def test_update_rowcount2(self): # WHERE matches 3, 0 rows changed department = employees_table.c.department r = employees_table.update(department=='C').execute(department='C') - assert r.rowcount == 3 - + if testbase.db.dialect.supports_sane_rowcount(): + assert r.rowcount == 3 + def test_delete_rowcount(self): # WHERE matches 3, 3 rows deleted department = employees_table.c.department r = employees_table.delete(department=='C').execute() - assert r.rowcount == 3 + if testbase.db.dialect.supports_sane_rowcount(): + assert r.rowcount == 3 if __name__ == '__main__': testbase.main() - + diff --git a/test/sql/selectable.py b/test/sql/selectable.py index f236c60f04..b7ab91ee8f 100755 --- a/test/sql/selectable.py +++ b/test/sql/selectable.py @@ -27,6 +27,11 @@ table2 = Table('table2', metadata, ) class SelectableTest(testbase.AssertMixin): + def testjoinagainstself(self): + jj = select([table.c.col1.label('bar_col1')]) + jjj = join(table, jj, table.c.col1==jj.c.bar_col1) + assert jjj.corresponding_column(jjj.c.table1_col1) is jjj.c.table1_col1 + def testjoinagainstjoin(self): j = outerjoin(table, table2, table.c.col1==table2.c.col2) jj = select([ table.c.col1.label('bar_col1')],from_obj=[j]).alias('foo') diff --git a/test/sql/testtypes.py b/test/sql/testtypes.py index d494909ea8..8406860ee7 100644 --- a/test/sql/testtypes.py +++ b/test/sql/testtypes.py @@ -142,28 +142,35 @@ class UnicodeTest(AssertMixin): metadata = BoundMetaData(db) unicode_table = Table('unicode_table', metadata, Column('id', Integer, Sequence('uni_id_seq', optional=True), primary_key=True), - Column('unicode_data', Unicode(250)), - Column('plain_data', String(250)) + Column('unicode_varchar', Unicode(250)), + Column('unicode_text', Unicode), + Column('plain_varchar', String(250)) ) unicode_table.create() def tearDownAll(self): unicode_table.drop() + def testbasic(self): - assert unicode_table.c.unicode_data.type.length == 250 + assert unicode_table.c.unicode_varchar.type.length == 250 rawdata = 'Alors vous imaginez ma surprise, au lever du jour, quand une dr\xc3\xb4le de petit voix m\xe2\x80\x99a r\xc3\xa9veill\xc3\xa9. Elle disait: \xc2\xab S\xe2\x80\x99il vous pla\xc3\xaet\xe2\x80\xa6 dessine-moi un mouton! \xc2\xbb\n' unicodedata = rawdata.decode('utf-8') - unicode_table.insert().execute(unicode_data=unicodedata, plain_data=rawdata) + unicode_table.insert().execute(unicode_varchar=unicodedata, + unicode_text=unicodedata, + plain_varchar=rawdata) x = unicode_table.select().execute().fetchone() - self.echo(repr(x['unicode_data'])) - self.echo(repr(x['plain_data'])) - self.assert_(isinstance(x['unicode_data'], unicode) and x['unicode_data'] == unicodedata) - if isinstance(x['plain_data'], unicode): + self.echo(repr(x['unicode_varchar'])) + self.echo(repr(x['unicode_text'])) + self.echo(repr(x['plain_varchar'])) + self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata) + self.assert_(isinstance(x['unicode_text'], unicode) and x['unicode_text'] == unicodedata) + if isinstance(x['plain_varchar'], unicode): # SQLLite and MSSQL return non-unicode data as unicode self.assert_(db.name in ('sqlite', 'mssql')) - self.assert_(x['plain_data'] == unicodedata) + self.assert_(x['plain_varchar'] == unicodedata) self.echo("it's %s!" % db.name) else: - self.assert_(not isinstance(x['plain_data'], unicode) and x['plain_data'] == rawdata) + self.assert_(not isinstance(x['plain_varchar'], unicode) and x['plain_varchar'] == rawdata) + def testengineparam(self): """tests engine-wide unicode conversion""" prev_unicode = db.engine.dialect.convert_unicode @@ -171,17 +178,24 @@ class UnicodeTest(AssertMixin): db.engine.dialect.convert_unicode = True rawdata = 'Alors vous imaginez ma surprise, au lever du jour, quand une dr\xc3\xb4le de petit voix m\xe2\x80\x99a r\xc3\xa9veill\xc3\xa9. Elle disait: \xc2\xab S\xe2\x80\x99il vous pla\xc3\xaet\xe2\x80\xa6 dessine-moi un mouton! \xc2\xbb\n' unicodedata = rawdata.decode('utf-8') - unicode_table.insert().execute(unicode_data=unicodedata, plain_data=rawdata) + unicode_table.insert().execute(unicode_varchar=unicodedata, + unicode_text=unicodedata, + plain_varchar=rawdata) x = unicode_table.select().execute().fetchone() - self.echo(repr(x['unicode_data'])) - self.echo(repr(x['plain_data'])) - self.assert_(isinstance(x['unicode_data'], unicode) and x['unicode_data'] == unicodedata) - self.assert_(isinstance(x['plain_data'], unicode) and x['plain_data'] == unicodedata) + self.echo(repr(x['unicode_varchar'])) + self.echo(repr(x['unicode_text'])) + self.echo(repr(x['plain_varchar'])) + self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata) + self.assert_(isinstance(x['unicode_text'], unicode) and x['unicode_text'] == unicodedata) + self.assert_(isinstance(x['plain_varchar'], unicode) and x['plain_varchar'] == unicodedata) finally: db.engine.dialect.convert_unicode = prev_unicode - - + def testlength(self): + """checks the database correctly understands the length of a unicode string""" + teststr = u'aaa\x1234' + self.assert_(db.func.length(teststr).scalar() == len(teststr)) + class BinaryTest(AssertMixin): def setUpAll(self): global binary_table @@ -305,6 +319,24 @@ class DateTest(AssertMixin): #x = db.text("select * from query_users_with_date where user_datetime=:date", bindparams=[bindparam('date', )]).execute(date=datetime.datetime(2005, 11, 10, 11, 52, 35)).fetchall() #print repr(x) + @testbase.unsupported('sqlite') + def testdate2(self): + t = Table('testdate', testbase.metadata, Column('id', Integer, primary_key=True), + Column('adate', Date), Column('adatetime', DateTime)) + t.create() + try: + d1 = datetime.date(2007, 10, 30) + t.insert().execute(adate=d1, adatetime=d1) + d2 = datetime.datetime(2007, 10, 30) + t.insert().execute(adate=d2, adatetime=d2) + + x = t.select().execute().fetchall()[0] + self.assert_(x.adate.__class__ == datetime.date) + self.assert_(x.adatetime.__class__ == datetime.datetime) + + finally: + t.drop() + class TimezoneTest(AssertMixin): """test timezone-aware datetimes. psycopg will return a datetime with a tzinfo attached to it, if postgres returns it. python then will not let you compare a datetime with a tzinfo to a datetime