From: Mike Bayer Date: Sat, 25 Feb 2006 07:12:50 +0000 (+0000) Subject: merged sql_rearrangement branch , refactors sql package to work standalone with X-Git-Tag: rel_0_1_3~39 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=72dd2b08beb9803269983aa220e75b44007e5158;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git merged sql_rearrangement branch , refactors sql package to work standalone with clause elements including tables and columns, schema package deals with "physical" representations --- diff --git a/CHANGES b/CHANGES index 6b85598d9e..52b2846fff 100644 --- a/CHANGES +++ b/CHANGES @@ -2,6 +2,9 @@ - fix to Oracle "row_number over" clause with mulitple tables - mapper.get() was not selecting multiple-keyed objects if the mapper's table was a join, such as in an inheritance relationship, this is fixed. +- overhaul to sql/schema packages so that the sql package can run all on its own, +producing selects, inserts, etc. without any engine dependencies. Table/Column +are the "physical" subclasses of TableClause/ColumnClause. 0.1.2 - fixed a recursive call in schema that was somehow running 994 times then returning normally. broke nothing, slowed down everything. thanks to jpellerin for finding this. diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py index 0c8aa2fe01..d38a557f97 100644 --- a/lib/sqlalchemy/__init__.py +++ b/lib/sqlalchemy/__init__.py @@ -6,8 +6,8 @@ from engine import * from types import * +from sql import * from schema import * from exceptions import * -from sql import * import mapping as mapperlib from mapping import * diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index ac10d27f16..c25a55c7ac 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -152,16 +152,11 @@ class ANSICompiler(sql.Compiled): # if we are within a visit to a Select, set up the "typemap" # for this column which is used to translate result set values self.typemap.setdefault(column.key.lower(), column.type) - if column.table.name is None: + if column.table is not None and column.table.name is None: self.strings[column] = column.name else: self.strings[column] = "%s.%s" % (column.table.name, column.name) - def visit_columnclause(self, column): - if column.table is not None and column.table.name is not None: - self.strings[column] = "%s.%s" % (column.table.name, column.text) - else: - self.strings[column] = column.text def visit_fromclause(self, fromclause): self.froms[fromclause] = fromclause.from_name @@ -257,11 +252,13 @@ class ANSICompiler(sql.Compiled): l = co.label(co._label) l.accept_visitor(self) inner_columns[co._label] = l - elif select.issubquery and isinstance(co, Column): + # TODO: figure this out, a ColumnClause with a select as a parent + # is different from any other kind of parent + elif select.issubquery and isinstance(co, sql.ColumnClause) and co.table is not None and not isinstance(co.table, sql.Select): # SQLite doesnt like selecting from a subquery where the column # names look like table.colname, so add a label synonomous with # the column name - l = co.label(co.key) + l = co.label(co.text) l.accept_visitor(self) inner_columns[self.get_str(l.obj)] = l else: @@ -379,7 +376,7 @@ class ANSICompiler(sql.Compiled): contains a Sequence object.""" pass - def visit_insert_column(selef, column): + def visit_insert_column(self, column): """called when visiting an Insert statement, for each column in the table that is a NULL insert into the table""" pass @@ -395,8 +392,8 @@ class ANSICompiler(sql.Compiled): self.visit_insert_sequence(c, seq) vis = DefaultVisitor() for c in insert_stmt.table.c: - if (self.parameters is None or self.parameters.get(c.key, None) is None): - c.accept_visitor(vis) + if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)): + c.accept_schema_visitor(vis) self.isinsert = True colparams = self._get_colparams(insert_stmt) @@ -419,7 +416,7 @@ class ANSICompiler(sql.Compiled): return self.bindparam_string(p.key) else: p.accept_visitor(self) - if isinstance(p, sql.ClauseElement): + if isinstance(p, sql.ClauseElement) and not isinstance(p, sql.ColumnClause): return "(" + self.get_str(p) + ")" else: return self.get_str(p) @@ -466,7 +463,7 @@ class ANSICompiler(sql.Compiled): # now go thru compiled params, get the Column object for each key d = {} for key, value in parameters.iteritems(): - if isinstance(key, schema.Column): + if isinstance(key, sql.ColumnClause): d[key] = value else: try: diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 04bdc24fa4..d660db7bdc 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -131,11 +131,6 @@ class MySQLEngine(ansisql.ANSISQLEngine): def supports_sane_rowcount(self): return False - def tableimpl(self, table, **kwargs): - """returns a new sql.TableImpl object to correspond to the given Table object.""" - mysql_engine = kwargs.pop('mysql_engine', None) - return MySQLTableImpl(table, mysql_engine=mysql_engine) - def compiler(self, statement, bindparams, **kwargs): return MySQLCompiler(self, statement, bindparams, **kwargs) @@ -175,7 +170,7 @@ class MySQLEngine(ansisql.ANSISQLEngine): #ischema.reflecttable(self, table, ischema_names, use_mysql=True) tabletype, foreignkeyD = self.moretableinfo(table=table) - table._impl.mysql_engine = tabletype + table.kwargs['mysql_engine'] = tabletype c = self.execute("describe " + table.name, {}) while True: @@ -235,14 +230,6 @@ class MySQLEngine(ansisql.ANSISQLEngine): return (tabletype, foreignkeyD) -class MySQLTableImpl(sql.TableImpl): - """attached to a schema.Table to provide it with a Selectable interface - as well as other functions - """ - def __init__(self, table, mysql_engine=None): - super(MySQLTableImpl, self).__init__(table) - self.mysql_engine = mysql_engine - class MySQLCompiler(ansisql.ANSICompiler): def visit_function(self, func): @@ -277,12 +264,13 @@ class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator): if first_pk and isinstance(column.type, types.Integer): colspec += " AUTO_INCREMENT" if column.foreign_key: - colspec += ", FOREIGN KEY (%s) REFERENCES %s(%s)" % (column.name, column.column.foreign_key.column.table.name, column.column.foreign_key.column.name) + colspec += ", FOREIGN KEY (%s) REFERENCES %s(%s)" % (column.name, column.foreign_key.column.table.name, column.foreign_key.column.name) return colspec def post_create_table(self, table): - if table.mysql_engine is not None: - return " ENGINE=%s" % table.mysql_engine + mysql_engine = table.kwargs.get('mysql_engine', None) + if mysql_engine is not None: + return " ENGINE=%s" % mysql_engine else: return "" diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index 2115f5d568..238310b1b1 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -312,7 +312,7 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator): if column.primary_key and not override_pk: colspec += " PRIMARY KEY" if column.foreign_key: - colspec += " REFERENCES %s(%s)" % (column.column.foreign_key.column.table.fullname, column.column.foreign_key.column.name) + colspec += " REFERENCES %s(%s)" % (column.foreign_key.column.table.fullname, column.foreign_key.column.name) return colspec def visit_sequence(self, sequence): diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py index 73b8769f25..56488a197c 100644 --- a/lib/sqlalchemy/engine.py +++ b/lib/sqlalchemy/engine.py @@ -16,8 +16,7 @@ A SQLEngine is provided to an application as a subclass that is specific to a pa of DBAPI, and is the central switching point for abstracting different kinds of database behavior into a consistent set of behaviors. It provides a variety of factory methods to produce everything specific to a certain kind of database, including a Compiler, -schema creation/dropping objects, and TableImpl and ColumnImpl objects to augment the -behavior of table metadata objects. +schema creation/dropping objects. The term "database-specific" will be used to describe any object or function that has behavior corresponding to a particular vendor, such as mysql-specific, sqlite-specific, etc. @@ -131,7 +130,7 @@ class DefaultRunner(schema.SchemaVisitor): def get_column_default(self, column): if column.default is not None: - return column.default.accept_visitor(self) + return column.default.accept_schema_visitor(self) else: return None @@ -296,11 +295,11 @@ class SQLEngine(schema.SchemaEngine): def create(self, entity, **params): """creates a table or index within this engine's database connection given a schema.Table object.""" - entity.accept_visitor(self.schemagenerator(**params)) + entity.accept_schema_visitor(self.schemagenerator(**params)) def drop(self, entity, **params): """drops a table or index within this engine's database connection given a schema.Table object.""" - entity.accept_visitor(self.schemadropper(**params)) + entity.accept_schema_visitor(self.schemadropper(**params)) def compile(self, statement, parameters, **kwargs): """given a sql.ClauseElement statement plus optional bind parameters, creates a new @@ -315,28 +314,6 @@ class SQLEngine(schema.SchemaEngine): """given a Table object, reflects its columns and properties from the database.""" raise NotImplementedError() - def tableimpl(self, table, **kwargs): - """returns a new sql.TableImpl object to correspond to the given Table object. - A TableImpl provides SQL statement builder operations on a Table metadata object, - and a subclass of this object may be provided by a SQLEngine subclass to provide - database-specific behavior.""" - return sql.TableImpl(table) - - def columnimpl(self, column): - """returns a new sql.ColumnImpl object to correspond to the given Column object. - A ColumnImpl provides SQL statement builder operations on a Column metadata object, - and a subclass of this object may be provided by a SQLEngine subclass to provide - database-specific behavior.""" - return sql.ColumnImpl(column) - - def indeximpl(self, index): - """returns a new sql.IndexImpl object to correspond to the given Index - object. An IndexImpl provides SQL statement builder operations on an - Index metadata object, and a subclass of this object may be provided - by a SQLEngine subclass to provide database-specific behavior. - """ - return sql.IndexImpl(index) - def get_default_schema_name(self): """returns the currently selected schema in the current connection.""" return None diff --git a/lib/sqlalchemy/ext/proxy.py b/lib/sqlalchemy/ext/proxy.py index c1bdd9fa53..4f1f9e4010 100644 --- a/lib/sqlalchemy/ext/proxy.py +++ b/lib/sqlalchemy/ext/proxy.py @@ -13,7 +13,7 @@ class ProxyEngine(object): """ SQLEngine proxy. Supports lazy and late initialization by delegating to a real engine (set with connect()), and using proxy - classes for TableImpl, ColumnImpl and TypeEngine. + classes for TypeEngine. """ def __init__(self): @@ -61,15 +61,6 @@ class ProxyEngine(object): return None return self.get_engine().oid_column_name() - def columnimpl(self, column): - """Proxy point: return a ProxyColumnImpl - """ - return ProxyColumnImpl(self, column) - - def tableimpl(self, table): - """Proxy point: return a ProxyTableImpl - """ - return ProxyTableImpl(self, table) def type_descriptor(self, typeobj): """Proxy point: return a ProxyTypeEngine @@ -84,45 +75,6 @@ class ProxyEngine(object): raise AttributeError('No connection established in ProxyEngine: ' ' no access to %s' % attr) - -class ProxyColumnImpl(sql.ColumnImpl): - """Proxy column; defers engine access to ProxyEngine - """ - def __init__(self, engine, column): - sql.ColumnImpl.__init__(self, column) - self._engine = engine - self.impls = weakref.WeakKeyDictionary() - def _get_impl(self): - e = self._engine.engine - try: - return self.impls[e] - except KeyError: - impl = e.columnimpl(self.column) - self.impls[e] = impl - def __getattr__(self, key): - return getattr(self._get_impl(), key) - engine = property(lambda self: self._engine.engine) - -class ProxyTableImpl(sql.TableImpl): - """Proxy table; defers engine access to ProxyEngine - """ - def __init__(self, engine, table): - sql.TableImpl.__init__(self, table) - self._engine = engine - self.impls = weakref.WeakKeyDictionary() - def _get_impl(self): - e = self._engine.engine - try: - return self.impls[e] - except KeyError: - impl = e.tableimpl(self.table) - self.impls[e] = impl - return impl - def __getattr__(self, key): - return getattr(self._get_impl(), key) - - engine = property(lambda self: self._engine.engine) - class ProxyType(object): """ProxyType base class; used by ProxyTypeEngine to construct proxying types diff --git a/lib/sqlalchemy/mapping/mapper.py b/lib/sqlalchemy/mapping/mapper.py index 61d9a3c2ef..33bec863e8 100644 --- a/lib/sqlalchemy/mapping/mapper.py +++ b/lib/sqlalchemy/mapping/mapper.py @@ -262,7 +262,7 @@ class Mapper(object): """returns an instance of the object based on the given identifier, or None if not found. The *ident argument is a list of primary key columns in the order of the table def's primary key columns.""" - key = objectstore.get_id_key(ident, self.class_, self.primarytable) + key = objectstore.get_id_key(ident, self.class_) #print "key: " + repr(key) + " ident: " + repr(ident) return self._get(key, ident) @@ -284,7 +284,7 @@ class Mapper(object): def identity_key(self, *primary_key): - return objectstore.get_id_key(tuple(primary_key), self.class_, self.primarytable) + return objectstore.get_id_key(tuple(primary_key), self.class_) def instance_key(self, instance): return self.identity_key(*[self._getattrbycolumn(instance, column) for column in self.pks_by_table[self.table]]) @@ -683,7 +683,7 @@ class Mapper(object): return statement def _identity_key(self, row): - return objectstore.get_row_key(row, self.class_, self.identitytable, self.pks_by_table[self.table]) + return objectstore.get_row_key(row, self.class_, self.pks_by_table[self.table]) def _instance(self, row, imap, result = None, populate_existing = False): """pulls an object instance from the given row and appends it to the given result diff --git a/lib/sqlalchemy/mapping/objectstore.py b/lib/sqlalchemy/mapping/objectstore.py index 311a6c5420..4f0dc4dafd 100644 --- a/lib/sqlalchemy/mapping/objectstore.py +++ b/lib/sqlalchemy/mapping/objectstore.py @@ -48,7 +48,7 @@ class Session(object): self.hash_key = hash_key _sessions[self.hash_key] = self - def get_id_key(ident, class_, table): + def get_id_key(ident, class_): """returns an identity-map key for use in storing/retrieving an item from the identity map, given a tuple of the object's primary key values. @@ -62,10 +62,10 @@ class Session(object): selectable - a Selectable object which represents all the object's column-based fields. this Selectable may be synonymous with the table argument or can be a larger construct containing that table. return value: a tuple object which is used as an identity key. """ - return (class_, table.hash_key(), tuple(ident)) + return (class_, tuple(ident)) get_id_key = staticmethod(get_id_key) - def get_row_key(row, class_, table, primary_key): + def get_row_key(row, class_, primary_key): """returns an identity-map key for use in storing/retrieving an item from the identity map, given a result set row. @@ -80,7 +80,7 @@ class Session(object): this Selectable may be synonymous with the table argument or can be a larger construct containing that table. return value: a tuple object which is used as an identity key. """ - return (class_, table.hash_key(), tuple([row[column] for column in primary_key])) + return (class_, tuple([row[column] for column in primary_key])) get_row_key = staticmethod(get_row_key) class SessionTrans(object): @@ -181,7 +181,6 @@ class Session(object): return None key = getattr(instance, '_instance_key', None) mapper = object_mapper(instance) - key = (key[0], mapper.table.hash_key(), key[2]) u = self.uow if key is not None: if u.identity_map.has_key(key): @@ -194,11 +193,11 @@ class Session(object): u.register_new(instance) return instance -def get_id_key(ident, class_, table): - return Session.get_id_key(ident, class_, table) +def get_id_key(ident, class_): + return Session.get_id_key(ident, class_) -def get_row_key(row, class_, table, primary_key): - return Session.get_row_key(row, class_, table, primary_key) +def get_row_key(row, class_, primary_key): + return Session.get_row_key(row, class_, primary_key) def begin(): """begins a new UnitOfWork transaction. the next commit will affect only diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 8f55231380..8e9a434825 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -14,7 +14,7 @@ structure with its own clause-specific objects as well as the visitor interface, the schema package "plugs in" to the SQL package. """ - +import sql from util import * from types import * from exceptions import * @@ -29,30 +29,12 @@ class SchemaItem(object): for item in args: if item is not None: item._set_parent(self) - - def accept_visitor(self, visitor): - """all schema items implement an accept_visitor method that should call the appropriate - visit_XXXX method upon the given visitor object.""" - raise NotImplementedError() - def _set_parent(self, parent): """a child item attaches itself to its parent via this method.""" raise NotImplementedError() - - def hash_key(self): - """returns a string that identifies this SchemaItem uniquely""" - return "%s(%d)" % (self.__class__.__name__, id(self)) - def __repr__(self): return "%s()" % self.__class__.__name__ - def __getattr__(self, key): - """proxies method calls to an underlying implementation object for methods not found - locally""" - if not self.__dict__.has_key('_impl'): - raise AttributeError(key) - return getattr(self._impl, key) - def _get_table_key(engine, name, schema): if schema is not None and schema == engine.get_default_schema_name(): schema = None @@ -95,8 +77,10 @@ class TableSingleton(type): return table -class Table(SchemaItem): - """represents a relational database table. +class Table(sql.TableClause, SchemaItem): + """represents a relational database table. This subclasses sql.TableClause to provide + a table that is "wired" to an engine. Whereas TableClause represents a table as its + used in a SQL expression, Table represents a table as its created in the database. Be sure to look at sqlalchemy.sql.TableImpl for additional methods defined on a Table.""" __metaclass__ = TableSingleton @@ -134,19 +118,15 @@ class Table(SchemaItem): the same table twice will result in an exception. """ - self.name = name - self.columns = OrderedProperties() - self.c = self.columns - self.foreign_keys = [] - self.primary_key = [] - self.engine = engine + super(Table, self).__init__(name) + self._engine = engine self.schema = kwargs.pop('schema', None) - self._impl = self.engine.tableimpl(self, **kwargs) if self.schema is not None: self.fullname = "%s.%s" % (self.schema, self.name) else: self.fullname = self.name - + self.kwargs = kwargs + def __repr__(self): return "Table(%s)" % string.join( [repr(self.name)] + [repr(self.engine)] + @@ -160,44 +140,45 @@ class Table(SchemaItem): else: return self.schema + "." + self.name - def hash_key(self): - return "Table(%s)" % string.join( - [repr(self.name)] + [self.engine.hash_key()] + - ["%s=%s" % (k, repr(getattr(self, k))) for k in ['schema']], ',' - ) - def reload_values(self, *args): """clears out the columns and other properties of this Table, and reloads them from the given argument list. This is used with the "redefine" keyword argument sent to the metaclass constructor.""" - self.columns = OrderedProperties() - self.c = self.columns - self.foreign_keys = [] - self.primary_key = [] - self._impl = self.engine.tableimpl(self) + self._clear() + + print "RELOAD VALUES", args self._init_items(*args) def append_item(self, item): """appends a Column item or other schema item to this Table.""" self._init_items(item) - + + def append_column(self, column): + if not column.hidden: + self._columns[column.key] = column + if column.primary_key: + self.primary_key.append(column) + column.table = self + column.type = self.engine.type_descriptor(column.type) + def _set_parent(self, schema): schema.tables[self.name] = self self.schema = schema - - def accept_visitor(self, visitor): + def accept_schema_visitor(self, visitor): """traverses the given visitor across the Column objects inside this Table, then calls the visit_table method on the visitor.""" for c in self.columns: - c.accept_visitor(visitor) + c.accept_schema_visitor(visitor) return visitor.visit_table(self) - def deregister(self): """removes this table from it's engines table registry. this does not issue a SQL DROP statement.""" key = _get_table_key(self.engine, self.name, self.schema) del self.engine.tables[key] - + def create(self, **params): + self.engine.create(self) + def drop(self, **params): + self.engine.drop(self) def toengine(self, engine, schema=None): """returns a singleton instance of this Table with a different engine""" try: @@ -211,8 +192,9 @@ class Table(SchemaItem): args.append(c.copy()) return Table(self.name, engine, schema=schema, *args) -class Column(SchemaItem): - """represents a column in a database table.""" +class Column(sql.ColumnClause, SchemaItem): + """represents a column in a database table. this is a subclass of sql.ColumnClause and + represents an actual existing table in the database, in a similar fashion as TableClause/Table.""" def __init__(self, name, type, *args, **kwargs): """constructs a new Column object. Arguments are: @@ -244,24 +226,27 @@ class Column(SchemaItem): hidden=False : indicates this column should not be listed in the table's list of columns. Used for the "oid" column, which generally isnt in column lists. """ - self.name = str(name) # in case of incoming unicode - self.type = type + name = str(name) # in case of incoming unicode + super(Column, self).__init__(name, None, type) self.args = args self.key = kwargs.pop('key', name) - self.primary_key = kwargs.pop('primary_key', False) + self._primary_key = kwargs.pop('primary_key', False) self.nullable = kwargs.pop('nullable', not self.primary_key) self.hidden = kwargs.pop('hidden', False) self.default = kwargs.pop('default', None) - self.foreign_key = None + self._foreign_key = None self._orig = None self._parent = None if len(kwargs): raise ArgumentError("Unknown arguments passed to Column: " + repr(kwargs.keys())) - + + primary_key = AttrProp('_primary_key') + foreign_key = AttrProp('_foreign_key') original = property(lambda s: s._orig or s) parent = property(lambda s:s._parent or s) engine = property(lambda s: s.table.engine) - + columns = property(lambda self:[self]) + def __repr__(self): return "Column(%s)" % string.join( [repr(self.name)] + [repr(self.type)] + @@ -282,16 +267,7 @@ class Column(SchemaItem): def _set_parent(self, table): if getattr(self, 'table', None) is not None: raise ArgumentError("this Column already has a table!") - if not self.hidden: - table.columns[self.key] = self - if self.primary_key: - table.primary_key.append(self) - self.table = table - if self.table.engine is not None: - self.type = self.table.engine.type_descriptor(self.type) - - self._impl = self.table.engine.columnimpl(self) - + table.append_column(self) if self.default is not None: self.default = ColumnDefault(self.default) self._init_items(self.default) @@ -320,35 +296,19 @@ class Column(SchemaItem): selectable.columns[c.key] = c if self.primary_key: selectable.primary_key.append(c) - c._impl = self.engine.columnimpl(c) if fk is not None: c._init_items(fk) return c - def accept_visitor(self, visitor): + def accept_schema_visitor(self, visitor): """traverses the given visitor to this Column's default and foreign key object, then calls visit_column on the visitor.""" if self.default is not None: - self.default.accept_visitor(visitor) + self.default.accept_schema_visitor(visitor) if self.foreign_key is not None: - self.foreign_key.accept_visitor(visitor) + self.foreign_key.accept_schema_visitor(visitor) visitor.visit_column(self) - def __lt__(self, other): return self._impl.__lt__(other) - def __le__(self, other): return self._impl.__le__(other) - def __eq__(self, other): return self._impl.__eq__(other) - def __ne__(self, other): return self._impl.__ne__(other) - def __gt__(self, other): return self._impl.__gt__(other) - def __ge__(self, other): return self._impl.__ge__(other) - def __add__(self, other): return self._impl.__add__(other) - def __sub__(self, other): return self._impl.__sub__(other) - def __mul__(self, other): return self._impl.__mul__(other) - def __and__(self, other): return self._impl.__and__(other) - def __or__(self, other): return self._impl.__or__(other) - def __div__(self, other): return self._impl.__div__(other) - def __truediv__(self, other): return self._impl.__truediv__(other) - def __invert__(self, other): return self._impl.__invert__(other) - def __str__(self): return self._impl.__str__() class ForeignKey(SchemaItem): """defines a ForeignKey constraint between two columns. ForeignKey is @@ -374,7 +334,7 @@ class ForeignKey(SchemaItem): elif self._colspec.table.schema is not None: return "%s.%s.%s" % (self._colspec.table.schema, self._colspec.table.name, self._colspec.column.key) else: - return "%s.%s" % (self._colspec.table.name, self._colspec.column.key) + return "%s.%s" % (self._colspec.table.name, self._colspec.key) def references(self, table): """returns True if the given table is referenced by this ForeignKey.""" @@ -406,7 +366,7 @@ class ForeignKey(SchemaItem): column = property(lambda s: s._init_column()) - def accept_visitor(self, visitor): + def accept_schema_visitor(self, visitor): """calls the visit_foreign_key method on the given visitor.""" visitor.visit_foreign_key(self) @@ -432,7 +392,7 @@ class PassiveDefault(DefaultGenerator): """a default that takes effect on the database side""" def __init__(self, arg): self.arg = arg - def accept_visitor(self, visitor): + def accept_schema_visitor(self, visitor): return visitor.visit_passive_default(self) def __repr__(self): return "PassiveDefault(%s)" % repr(self.arg) @@ -442,7 +402,7 @@ class ColumnDefault(DefaultGenerator): a callable function, or a SQL clause.""" def __init__(self, arg): self.arg = arg - def accept_visitor(self, visitor): + def accept_schema_visitor(self, visitor): """calls the visit_column_default method on the given visitor.""" return visitor.visit_column_default(self) def __repr__(self): @@ -461,7 +421,7 @@ class Sequence(DefaultGenerator): ["%s=%s" % (k, repr(getattr(self, k))) for k in ['start', 'increment', 'optional']] , ',') - def accept_visitor(self, visitor): + def accept_schema_visitor(self, visitor): """calls the visit_seauence method on the given visitor.""" return visitor.visit_sequence(self) @@ -486,6 +446,7 @@ class Index(SchemaItem): self.unique = kw.pop('unique', False) self._init_items() + engine = property(lambda s:s.table.engine) def _init_items(self): # make sure all columns are from the same table # FIXME: and no column is repeated @@ -499,10 +460,13 @@ class Index(SchemaItem): "%s is from %s not %s" % (column, column.table, self.table)) - # set my _impl from col.table.engine - self._impl = self.table.engine.indeximpl(self) - - def accept_visitor(self, visitor): + def create(self): + self.engine.create(self) + def drop(self): + self.engine.drop(self) + def execute(self): + self.create() + def accept_schema_visitor(self, visitor): visitor.visit_index(self) def __str__(self): return repr(self) @@ -515,24 +479,13 @@ class Index(SchemaItem): class SchemaEngine(object): """a factory object used to create implementations for schema objects. This object is the ultimate base class for the engine.SQLEngine class.""" - def tableimpl(self, table): - """returns a new implementation object for a Table (usually sql.TableImpl)""" - raise NotImplementedError() - def columnimpl(self, column): - """returns a new implementation object for a Column (usually sql.ColumnImpl)""" - raise NotImplementedError() - def indeximpl(self, index): - """returns a new implementation object for an Index (usually - sql.IndexImpl) - """ - raise NotImplementedError() def reflecttable(self, table): """given a table, will query the database and populate its Column and ForeignKey objects.""" raise NotImplementedError() -class SchemaVisitor(object): - """base class for an object that traverses across Schema structures.""" +class SchemaVisitor(sql.ClauseVisitor): + """defines the visiting for SchemaItem objects""" def visit_schema(self, schema): """visit a generic SchemaItem""" pass diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index cbd9a82f31..8ebf7624ef 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -13,7 +13,7 @@ from exceptions import * import string, re, random types = __import__('types') -__all__ = ['text', 'column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'union', 'union_all', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'literal', 'bindparam', 'exists'] +__all__ = ['text', 'table', 'column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'union', 'union_all', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'literal', 'bindparam', 'exists'] def desc(column): """returns a descending ORDER BY clause element, e.g.: @@ -160,11 +160,15 @@ def label(name, obj): """returns a Label object for the given selectable, used in the column list for a select statement.""" return Label(name, obj) -def column(table, text): - """returns a textual column clause, relative to a table. this differs from using straight text - or text() in that the column is treated like a regular column, i.e. gets added to a Selectable's list - of columns.""" - return ColumnClause(text, table) +def column(text, table=None, type=None): + """returns a textual column clause, relative to a table. this is also the primitive version of + a schema.Column which is a subclass. """ + return ColumnClause(text, table, type) + +def table(name, *columns): + """returns a table clause. this is a primitive version of the schema.Table object, which is a subclass + of this object.""" + return TableClause(name, *columns) def bindparam(key, value = None, type=None): """creates a bind parameter clause with the given key. @@ -172,7 +176,7 @@ def bindparam(key, value = None, type=None): An optional default value can be specified by the value parameter, and the optional type parameter is a sqlalchemy.types.TypeEngine object which indicates bind-parameter and result-set translation for this bind parameter.""" - if isinstance(key, schema.Column): + if isinstance(key, ColumnClause): return BindParamClause(key.name, value, type=key.type) else: return BindParamClause(key, value, type=type) @@ -190,7 +194,7 @@ def text(text, engine=None, *args, **kwargs): text - the text of the SQL statement to be created. use : to specify bind parameters; they will be compiled to their engine-specific format. - engine - the engine to be used for this text query. Alternatively, call the + engine - an optional engine to be used for this text query. Alternatively, call the text() method off the engine directly. bindparams - a list of bindparam() instances which can be used to define the @@ -222,15 +226,15 @@ def _compound_select(keyword, *selects, **kwargs): return CompoundSelect(keyword, *selects, **kwargs) def _is_literal(element): - return not isinstance(element, ClauseElement) and not isinstance(element, schema.SchemaItem) + return not isinstance(element, ClauseElement) def is_column(col): - return isinstance(col, schema.Column) or isinstance(col, ColumnElement) + return isinstance(col, ColumnElement) -class ClauseVisitor(schema.SchemaVisitor): - """builds upon SchemaVisitor to define the visiting of SQL statement elements in - addition to Schema elements.""" - def visit_columnclause(self, column):pass +class ClauseVisitor(object): + """Defines the visiting of ClauseElements.""" + def visit_column(self, column):pass + def visit_table(self, column):pass def visit_fromclause(self, fromclause):pass def visit_bindparam(self, bindparam):pass def visit_textclause(self, textclause):pass @@ -309,18 +313,6 @@ class Compiled(ClauseVisitor): class ClauseElement(object): """base class for elements of a programmatically constructed SQL expression.""" - def hash_key(self): - """returns a string that uniquely identifies the concept this ClauseElement - represents. - - two ClauseElements can have the same value for hash_key() iff they both correspond to - the exact same generated SQL. This allows the hash_key() values of a collection of - ClauseElements to be constructed into a larger identifying string for the purpose of - caching a SQL expression. - - Note that since ClauseElements may be mutable, the hash_key() value is subject to - change if the underlying structure of the ClauseElement changes.""" - raise NotImplementedError(repr(self)) def _get_from_objects(self): """returns objects represented in this ClauseElement that should be added to the FROM list of a query.""" @@ -357,19 +349,24 @@ class ClauseElement(object): return False def _find_engine(self): + """default strategy for locating an engine within the clause element. + relies upon a local engine property, or looks in the "from" objects which + ultimately have to contain Tables or TableClauses. """ try: if self._engine is not None: return self._engine except AttributeError: pass for f in self._get_from_objects(): + if f is self: + continue engine = f.engine if engine is not None: return engine else: return None - engine = property(lambda s: s._find_engine()) + engine = property(lambda s: s._find_engine(), doc="attempts to locate a SQLEngine within this ClauseElement structure, or returns None if none found.") def compile(self, engine = None, parameters = None, typemap=None): """compiles this SQL expression using its underlying SQLEngine to produce @@ -380,16 +377,13 @@ class ClauseElement(object): engine = self.engine if engine is None: - raise InvalidRequestError("no SQLEngine could be located within this ClauseElement.") + import sqlalchemy.ansisql as ansisql + engine = ansisql.engine() return engine.compile(self, parameters=parameters, typemap=typemap) def __str__(self): - e = self.engine - if e is None: - import sqlalchemy.ansisql as ansisql - e = ansisql.engine() - return str(self.compile(e)) + return str(self.compile()) def execute(self, *multiparams, **params): """compiles and executes this SQL expression using its underlying SQLEngine. the @@ -425,6 +419,7 @@ class ClauseElement(object): return not_(self) class CompareMixin(object): + """defines comparison operations for ClauseElements.""" def __lt__(self, other): return self._compare('<', other) def __le__(self, other): @@ -500,19 +495,15 @@ class Selectable(ClauseElement): def accept_visitor(self, visitor): raise NotImplementedError(repr(self)) - def is_selectable(self): return True - def select(self, whereclauses = None, **params): return select([self], whereclauses, **params) - def _group_parenthesized(self): """indicates if this Selectable requires parenthesis when grouped into a compound statement""" return True - class ColumnElement(Selectable, CompareMixin): """represents a column element within the list of a Selectable's columns. Provides default implementations for the things a "column" needs, including a "primary_key" flag, @@ -552,8 +543,6 @@ class FromClause(Selectable): return [self.oid_column] else: return self.primary_key - def hash_key(self): - return "FromClause(%s, %s)" % (repr(self.id), repr(self.from_name)) def accept_visitor(self, visitor): visitor.visit_fromclause(self) def count(self, whereclause=None, **params): @@ -627,8 +616,6 @@ class BindParamClause(ClauseElement, CompareMixin): visitor.visit_bindparam(self) def _get_from_objects(self): return [] - def hash_key(self): - return "BindParam(%s, %s, %s)" % (repr(self.key), repr(self.value), repr(self.shortname)) def typeprocess(self, value, engine): return self._get_convert_type(engine).convert_bind_param(value, engine) def compare(self, other): @@ -674,8 +661,6 @@ class TextClause(ClauseElement): for item in self.bindparams.values(): item.accept_visitor(visitor) visitor.visit_textclause(self) - def hash_key(self): - return "TextClause(%s)" % repr(self.text) def _get_from_objects(self): return [] @@ -686,8 +671,6 @@ class Null(ClauseElement): visitor.visit_null(self) def _get_from_objects(self): return [] - def hash_key(self): - return "Null" class ClauseList(ClauseElement): """describes a list of clauses. by default, is comma-separated, @@ -698,8 +681,6 @@ class ClauseList(ClauseElement): if c is None: continue self.append(c) self.parens = kwargs.get('parens', False) - def hash_key(self): - return string.join([c.hash_key() for c in self.clauses], ",") def copy_container(self): clauses = [clause.copy_container() for clause in self.clauses] return ClauseList(parens=self.parens, *clauses) @@ -753,8 +734,6 @@ class CompoundClause(ClauseList): for c in self.clauses: f += c._get_from_objects() return f - def hash_key(self): - return string.join([c.hash_key() for c in self.clauses], self.operator or " ") def compare(self, other): """compares this CompoundClause to the given item. @@ -794,8 +773,6 @@ class Function(ClauseList, ColumnElement): return BindParamClause(self.name, obj, shortname=self.name, type=self.type) def select(self): return select([self]) - def hash_key(self): - return self.name + "(" + string.join([c.hash_key() for c in self.clauses], ", ") + ")" def _compare_type(self, obj): return self.type @@ -811,8 +788,6 @@ class BinaryClause(ClauseElement): return BinaryClause(self.left.copy_container(), self.right.copy_container(), self.operator) def _get_from_objects(self): return self.left._get_from_objects() + self.right._get_from_objects() - def hash_key(self): - return self.left.hash_key() + (self.operator or " ") + self.right.hash_key() def accept_visitor(self, visitor): self.left.accept_visitor(visitor) self.right.accept_visitor(visitor) @@ -879,16 +854,9 @@ class Join(FromClause): return and_(*crit) def _group_parenthesized(self): - """indicates if this Selectable requires parenthesis when grouped into a compound - statement""" return True - - def hash_key(self): - return "Join(%s, %s, %s, %s)" % (repr(self.left.hash_key()), repr(self.right.hash_key()), repr(self.onclause.hash_key()), repr(self.isouter)) - def select(self, whereclauses = None, **params): return select([self.left, self.right], whereclauses, from_obj=[self], **params) - def accept_visitor(self, visitor): self.left.accept_visitor(visitor) self.right.accept_visitor(visitor) @@ -941,9 +909,6 @@ class Alias(FromClause): def _exportable_columns(self): return self.selectable.columns - def hash_key(self): - return "Alias(%s, %s)" % (self.selectable.hash_key(), repr(self.name)) - def accept_visitor(self, visitor): self.selectable.accept_visitor(visitor) visitor.visit_alias(self) @@ -975,35 +940,27 @@ class Label(ColumnElement): return self.obj._get_from_objects() def _make_proxy(self, selectable, name = None): return self.obj._make_proxy(selectable, name=self.name) - - def hash_key(self): - return "Label(%s, %s)" % (self.name, self.obj.hash_key()) class ColumnClause(ColumnElement): - """represents a textual column clause in a SQL statement. allows the creation - of an additional ad-hoc column that is compiled against a particular table.""" - - def __init__(self, text, selectable=None): - self.text = text + """represents a textual column clause in a SQL statement. May or may not + be bound to an underlying Selectable.""" + def __init__(self, text, selectable=None, type=None): + self.key = self.name = self.text = text self.table = selectable - self.type = sqltypes.NullTypeEngine() - - name = property(lambda self:self.text) - key = property(lambda self:self.text) - _label = property(lambda self:self.text) - - def accept_visitor(self, visitor): - visitor.visit_columnclause(self) - - def hash_key(self): + self.type = type or sqltypes.NullTypeEngine() + def _get_label(self): if self.table is not None: - return "ColumnClause(%s, %s)" % (self.text, util.hash_key(self.table)) + return self.table.name + "_" + self.text else: - return "ColumnClause(%s)" % self.text - + return self.text + _label = property(_get_label) + def accept_visitor(self, visitor): + visitor.visit_column(self) def _get_from_objects(self): - return [] - + if self.table is not None: + return [self.table] + else: + return [] def _bind_param(self, obj): if self.table.name is None: return BindParamClause(self.text, obj, shortname=self.text, type=self.type) @@ -1013,79 +970,35 @@ class ColumnClause(ColumnElement): c = ColumnClause(name or self.text, selectable) selectable.columns[c.key] = c return c - -class ColumnImpl(ColumnElement): - """gets attached to a schema.Column object.""" - - def __init__(self, column): - self.column = column - self.name = column.name - - if column.table.name: - self._label = column.table.name + "_" + self.column.name - else: - self._label = self.column.name - - engine = property(lambda s: s.column.engine) - default_label = property(lambda s:s._label) - original = property(lambda self:self.column.original) - parent = property(lambda self:self.column.parent) - columns = property(lambda self:[self.column]) - - def label(self, name): - return Label(name, self.column) - - def copy_container(self): - return self.column - - def compare(self, other): - """compares this ColumnImpl's column to the other given Column""" - return self.column is other - + def _compare_type(self, obj): + return self.type def _group_parenthesized(self): return False - - def _get_from_objects(self): - return [self.column.table] - - def _bind_param(self, obj): - if self.column.table.name is None: - return BindParamClause(self.name, obj, shortname = self.name, type = self.column.type) - else: - return BindParamClause(self.column.table.name + "_" + self.name, obj, shortname = self.name, type = self.column.type) - def _compare_self(self): - """allows ColumnImpl to return its Column object for usage in ClauseElements, all others to - just return self""" - return self.column - def _compare_type(self, obj): - return self.column.type - - def compile(self, engine = None, parameters = None, typemap=None): - if engine is None: - engine = self.engine - if engine is None: - raise InvalidRequestError("no SQLEngine could be located within this ClauseElement.") - return engine.compile(self.column, parameters=parameters, typemap=typemap) -class TableImpl(FromClause): - """attached to a schema.Table to provide it with a Selectable interface - as well as other functions - """ - - def __init__(self, table): - self.table = table - self.id = self.table.name +class TableClause(FromClause): + def __init__(self, name, *columns): + super(TableClause, self).__init__(name) + self.name = self.id = self.fullname = name + self._columns = util.OrderedProperties() + self._foreign_keys = [] + self._primary_key = [] + for c in columns: + self.append_column(c) + def append_column(self, c): + self._columns[c.text] = c + c.table = self def _oid_col(self): + if self.engine is None: + return None # OID remains a little hackish so far if not hasattr(self, '_oid_column'): - if self.table.engine.oid_column_name() is not None: - self._oid_column = schema.Column(self.table.engine.oid_column_name(), sqltypes.Integer, hidden=True) - self._oid_column._set_parent(self.table) + if self.engine.oid_column_name() is not None: + self._oid_column = schema.Column(self.engine.oid_column_name(), sqltypes.Integer, hidden=True) + self._oid_column._set_parent(self) else: self._oid_column = None return self._oid_column - def _orig_columns(self): try: return self._orig_cols @@ -1097,47 +1010,52 @@ class TableImpl(FromClause): if oid is not None: self._orig_cols[oid.original] = oid return self._orig_cols - - oid_column = property(_oid_col) - engine = property(lambda s: s.table.engine) - columns = property(lambda self: self.table.columns) - primary_key = property(lambda self:self.table.primary_key) - foreign_keys = property(lambda self:self.table.foreign_keys) + columns = property(lambda s:s._columns) + c = property(lambda s:s._columns) + primary_key = property(lambda s:s._primary_key) + foreign_keys = property(lambda s:s._foreign_keys) original_columns = property(_orig_columns) + oid_column = property(_oid_col) + + def _clear(self): + """clears all attributes on this TableClause so that new items can be added again""" + self.columns.clear() + self.foreign_keys[:] = [] + self.primary_key[:] = [] + try: + delattr(self, '_orig_cols') + except AttributeError: + pass + def accept_visitor(self, visitor): + visitor.visit_table(self) def _exportable_columns(self): raise NotImplementedError() - def _group_parenthesized(self): return False - def _process_from_dict(self, data, asfrom): for f in self._get_from_objects(): data.setdefault(f.id, f) if asfrom: - data[self.id] = self.table + data[self.id] = self def count(self, whereclause=None, **params): - return select([func.count(1).label('count')], whereclause, from_obj=[self.table], **params) + return select([func.count(1).label('count')], whereclause, from_obj=[self], **params) def join(self, right, *args, **kwargs): - return Join(self.table, right, *args, **kwargs) + return Join(self, right, *args, **kwargs) def outerjoin(self, right, *args, **kwargs): - return Join(self.table, right, isouter = True, *args, **kwargs) + return Join(self, right, isouter = True, *args, **kwargs) def alias(self, name=None): - return Alias(self.table, name) + return Alias(self, name) def select(self, whereclause = None, **params): - return select([self.table], whereclause, **params) + return select([self], whereclause, **params) def insert(self, values = None): - return insert(self.table, values=values) + return insert(self, values=values) def update(self, whereclause = None, values = None): - return update(self.table, whereclause, values) + return update(self, whereclause, values) def delete(self, whereclause = None): - return delete(self.table, whereclause) - def create(self, **params): - self.table.engine.create(self.table) - def drop(self, **params): - self.table.engine.drop(self.table) + return delete(self, whereclause) def _get_from_objects(self): - return [self.table] + return [self] class SelectBaseMixin(object): """base class for Select and CompoundSelects""" @@ -1191,11 +1109,6 @@ class CompoundSelect(SelectBaseMixin, FromClause): order_by = kwargs.get('order_by', None) if order_by: self.order_by(*order_by) - def hash_key(self): - return "CompoundSelect(%s)" % string.join( - [util.hash_key(s) for s in self.selects] + - ["%s=%s" % (k, repr(getattr(self, k))) for k in ['use_labels', 'keyword']], - ",") def _exportable_columns(self): return self.selects[0].columns def _proxy_column(self, column): @@ -1271,6 +1184,8 @@ class Select(SelectBaseMixin, FromClause): self.is_where = is_where def visit_compound_select(self, cs): self.visit_select(cs) + def visit_column(self, c):pass + def visit_table(self, c):pass def visit_select(self, select): if select is self.select: return @@ -1288,7 +1203,6 @@ class Select(SelectBaseMixin, FromClause): for f in column._get_from_objects(): f.accept_visitor(self._correlator) column._process_from_dict(self._froms, False) - def _exportable_columns(self): return self._raw_columns def _proxy_column(self, column): @@ -1313,24 +1227,6 @@ class Select(SelectBaseMixin, FromClause): _hash_recursion = util.RecursionStack() - def hash_key(self): - # selects call alot of stuff so we do some "recursion checking" - # to eliminate loops - if Select._hash_recursion.push(self): - return "recursive_select()" - try: - return "Select(%s)" % string.join( - [ - "columns=" + string.join([util.hash_key(c) for c in self._raw_columns],','), - "where=" + util.hash_key(self.whereclause), - "from=" + string.join([util.hash_key(f) for f in self.froms],','), - "having=" + util.hash_key(self.having), - "clauses=" + string.join([util.hash_key(c) for c in self.clauses], ',') - ] + ["%s=%s" % (k, repr(getattr(self, k))) for k in ['use_labels', 'distinct', 'limit', 'offset']], "," - ) - finally: - Select._hash_recursion.pop(self) - def clear_from(self, id): self.append_from(FromClause(from_name = None, from_key = id)) @@ -1342,7 +1238,7 @@ class Select(SelectBaseMixin, FromClause): fromclause._process_from_dict(self._froms, True) def _get_froms(self): - return [f for f in self._froms.values() if self._correlated is None or not self._correlated.has_key(f.id)] + return [f for f in self._froms.values() if f is not self and (self._correlated is None or not self._correlated.has_key(f.id))] froms = property(lambda s: s._get_froms()) def accept_visitor(self, visitor): @@ -1388,9 +1284,6 @@ class Select(SelectBaseMixin, FromClause): class UpdateBase(ClauseElement): """forms the base for INSERT, UPDATE, and DELETE statements.""" - def hash_key(self): - return str(id(self)) - def _process_colparams(self, parameters): """receives the "values" of an INSERT or UPDATE statement and constructs appropriate ind parameters.""" @@ -1419,6 +1312,9 @@ class UpdateBase(ClauseElement): except KeyError: del parameters[key] return parameters + + def _find_engine(self): + return self._engine class Insert(UpdateBase): @@ -1457,25 +1353,3 @@ class Delete(UpdateBase): self.whereclause.accept_visitor(visitor) visitor.visit_delete(self) -class IndexImpl(ClauseElement): - - def __init__(self, index): - self.index = index - self.name = index.name - self._engine = self.index.table.engine - - table = property(lambda s: s.index.table) - columns = property(lambda s: s.index.columns) - - def hash_key(self): - return self.index.hash_key() - def accept_visitor(self, visitor): - visitor.visit_index(self.index) - def compare(self, other): - return self.index is other - def create(self): - self._engine.create(self.index) - def drop(self): - self._engine.drop(self.index) - def execute(self): - self.create() diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 301db0ec4d..fccb2f3bd3 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -4,7 +4,7 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -__all__ = ['OrderedProperties', 'OrderedDict', 'generic_repr', 'HashSet'] +__all__ = ['OrderedProperties', 'OrderedDict', 'generic_repr', 'HashSet', 'AttrProp'] import thread, weakref, UserList,string, inspect from exceptions import * @@ -23,7 +23,21 @@ def to_set(x): return HashSet(to_list(x)) else: return x - + +class AttrProp(object): + """a quick way to stick a property accessor on an object""" + def __init__(self, key): + self.key = key + def __set__(self, obj, value): + setattr(obj, self.key, value) + def __delete__(self, obj): + delattr(obj, self.key) + def __get__(self, obj, owner): + if obj is None: + return self + else: + return getattr(obj, self.key) + def generic_repr(obj, exclude=None): L = ['%s=%s' % (a, repr(getattr(obj, a))) for a in dir(obj) if not callable(getattr(obj, a)) and not a.startswith('_') and (exclude is None or not exclude.has_key(a))] return '%s(%s)' % (obj.__class__.__name__, ','.join(L)) @@ -65,9 +79,10 @@ class OrderedProperties(object): def __setattr__(self, key, object): if not hasattr(self, key): self._list.append(key) - self.__dict__[key] = object - + def clear(self): + for key in self._list[:]: + del self[key] class RecursionStack(object): """a thread-local stack used to detect recursive object traversals.""" def __init__(self): diff --git a/test/objectstore.py b/test/objectstore.py index 687c9b1028..63a39641e8 100644 --- a/test/objectstore.py +++ b/test/objectstore.py @@ -961,6 +961,9 @@ class SaveTest2(AssertMixin): Column('email_address', String(20)), redefine=True ) + x = sql.Join(self.users, self.addresses) +# raise repr(self.users) + repr(self.users.primary_key) +# raise repr(self.addresses) + repr(self.addresses.foreign_keys) self.users.create() self.addresses.create() db.echo = testbase.echo diff --git a/test/select.py b/test/select.py index 625a1ec7cf..788e39f7b5 100644 --- a/test/select.py +++ b/test/select.py @@ -10,23 +10,26 @@ db = ansisql.engine() from testbase import PersistTest import unittest, re - -table = Table('mytable', db, - Column('myid', Integer, key = 'id'), - Column('name', String, key = 'name'), - Column('description', String, key = 'description'), +# the select test now tests almost completely with TableClause/ColumnClause objects, +# which are free-roaming table/column objects not attached to any database. +# so SQLAlchemy's SQL construction engine can be used with no database dependencies at all. + +table1 = table('mytable', + column('myid'), + column('name'), + column('description'), ) -table2 = Table( - 'myothertable', db, - Column('otherid', Integer, key='id'), - Column('othername', String, key='name'), +table2 = table( + 'myothertable', + column('otherid'), + column('othername'), ) -table3 = Table( - 'thirdtable', db, - Column('userid', Integer, key='id'), - Column('otherstuff', Integer), +table3 = table( + 'thirdtable', + column('userid'), + column('otherstuff'), ) table4 = Table( @@ -37,27 +40,27 @@ table4 = Table( schema = 'remote_owner' ) -users = Table('users', db, - Column('user_id', Integer, primary_key = True), - Column('user_name', String(40)), - Column('password', String(10)), +users = table('users', + column('user_id'), + column('user_name'), + column('password'), ) -addresses = Table('addresses', db, - Column('address_id', Integer, primary_key = True), - Column('user_id', Integer, ForeignKey("users.user_id")), - Column('street', String(100)), - Column('city', String(80)), - Column('state', String(2)), - Column('zip', String(10)) +addresses = table('addresses', + column('address_id'), + column('user_id'), + column('street'), + column('city'), + column('state'), + column('zip') ) - class SQLTest(PersistTest): def runtest(self, clause, result, engine = None, params = None, checkparams = None): + if engine is None: + engine = db c = clause.compile(engine, params) self.echo("\nSQL String:\n" + str(c) + repr(c.get_params())) - self.echo("\nHash Key:\n" + clause.hash_key()) cc = re.sub(r'\n', '', str(c)) self.assert_(cc == result, str(c) + "\n does not match \n" + result) if checkparams is not None: @@ -67,53 +70,44 @@ class SQLTest(PersistTest): self.assert_(c.get_params() == checkparams, "params dont match") class SelectTest(SQLTest): - - def testtableselect(self): - self.runtest(table.select(), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable") + self.runtest(table1.select(), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable") - self.runtest(select([table, table2]), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, \ + self.runtest(select([table1, table2]), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, \ myothertable.othername FROM mytable, myothertable") def testsubquery(self): - - # TODO: a subquery in a column clause. - #self.runtest( - # select([table, select([table2.c.id])]), - # """""" - #) - - s = select([table], table.c.name == 'jack') + s = select([table1], table1.c.name == 'jack') print [key for key in s.c.keys()] self.runtest( select( [s], - s.c.id == 7 + s.c.myid == 7 ) , - "SELECT id, name, description FROM (SELECT mytable.myid AS id, mytable.name AS name, mytable.description AS description FROM mytable WHERE mytable.name = :mytable_name) WHERE id = :id") + "SELECT myid, name, description FROM (SELECT mytable.myid AS myid, mytable.name AS name, mytable.description AS description FROM mytable WHERE mytable.name = :mytable_name) WHERE myid = :myid") - sq = select([table]) + sq = select([table1]) self.runtest( sq.select(), - "SELECT id, name, description FROM (SELECT mytable.myid AS id, mytable.name AS name, mytable.description AS description FROM mytable)" + "SELECT myid, name, description FROM (SELECT mytable.myid AS myid, mytable.name AS name, mytable.description AS description FROM mytable)" ) sq = subquery( 'sq', - [table], + [table1], ) self.runtest( - sq.select(sq.c.id == 7), - "SELECT sq.id, sq.name, sq.description FROM \ -(SELECT mytable.myid AS id, mytable.name AS name, mytable.description AS description FROM mytable) AS sq WHERE sq.id = :sq_id" + sq.select(sq.c.myid == 7), + "SELECT sq.myid, sq.name, sq.description FROM \ +(SELECT mytable.myid AS myid, mytable.name AS name, mytable.description AS description FROM mytable) AS sq WHERE sq.myid = :sq_myid" ) sq = subquery( 'sq', - [table, table2], - and_(table.c.id ==7, table2.c.id==table.c.id), + [table1, table2], + and_(table1.c.myid ==7, table2.c.otherid==table1.c.myid), use_labels = True ) @@ -140,15 +134,15 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A def testand(self): self.runtest( - select(['*'], and_(table.c.id == 12, table.c.name=='asdf', table2.c.name == 'foo', "sysdate() = today()")), + select(['*'], and_(table1.c.myid == 12, table1.c.name=='asdf', table2.c.othername == 'foo', "sysdate() = today()")), "SELECT * FROM mytable, myothertable WHERE mytable.myid = :mytable_myid AND mytable.name = :mytable_name AND myothertable.othername = :myothertable_othername AND sysdate() = today()" ) def testor(self): self.runtest( - select([table], and_( - table.c.id == 12, - or_(table2.c.name=='asdf', table2.c.name == 'foo', table2.c.id == 9), + select([table1], and_( + table1.c.myid == 12, + or_(table2.c.othername=='asdf', table2.c.othername == 'foo', table2.c.otherid == 9), "sysdate() = today()", )), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable, myothertable WHERE mytable.myid = :mytable_myid AND (myothertable.othername = :myothertable_othername OR myothertable.othername = :myothertable_othername_1 OR myothertable.otherid = :myothertable_otherid) AND sysdate() = today()", @@ -157,7 +151,7 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A def testoperators(self): self.runtest( - table.select((table.c.id != 12) & ~(table.c.name=='john')), + table1.select((table1.c.myid != 12) & ~(table1.c.name=='john')), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid != :mytable_myid AND NOT (mytable.name = :mytable_name)" ) @@ -167,35 +161,35 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A def testmultiparam(self): self.runtest( - select(["*"], or_(table.c.id == 12, table.c.id=='asdf', table.c.id == 'foo')), + select(["*"], or_(table1.c.myid == 12, table1.c.myid=='asdf', table1.c.myid == 'foo')), "SELECT * FROM mytable WHERE mytable.myid = :mytable_myid OR mytable.myid = :mytable_myid_1 OR mytable.myid = :mytable_myid_2" ) def testorderby(self): self.runtest( - table2.select(order_by = [table2.c.id, asc(table2.c.name)]), + table2.select(order_by = [table2.c.otherid, asc(table2.c.othername)]), "SELECT myothertable.otherid, myothertable.othername FROM myothertable ORDER BY myothertable.otherid, myothertable.othername ASC" ) def testgroupby(self): self.runtest( - select([table2.c.name, func.count(table2.c.id)], group_by = [table2.c.name]), + select([table2.c.othername, func.count(table2.c.otherid)], group_by = [table2.c.othername]), "SELECT myothertable.othername, count(myothertable.otherid) FROM myothertable GROUP BY myothertable.othername" ) def testgroupby_and_orderby(self): self.runtest( - select([table2.c.name, func.count(table2.c.id)], group_by = [table2.c.name], order_by = [table2.c.name]), + select([table2.c.othername, func.count(table2.c.otherid)], group_by = [table2.c.othername], order_by = [table2.c.othername]), "SELECT myothertable.othername, count(myothertable.otherid) FROM myothertable GROUP BY myothertable.othername ORDER BY myothertable.othername" ) def testalias(self): - # test the alias for a table. column names stay the same, table name "changes" to "foo". + # test the alias for a table1. column names stay the same, table name "changes" to "foo". self.runtest( - select([alias(table, 'foo')]) + select([alias(table1, 'foo')]) ,"SELECT foo.myid, foo.name, foo.description FROM mytable AS foo") # create a select for a join of two tables. use_labels means the column names will have # labels tablename_columnname, which become the column keys accessible off the Selectable object. - # also, only use one column from the second table and all columns from the first table. - q = select([table, table2.c.id], table.c.id == table2.c.id, use_labels = True) + # also, only use one column from the second table and all columns from the first table1. + q = select([table1, table2.c.otherid], table1.c.myid == table2.c.otherid, use_labels = True) # make an alias of the "selectable". column names stay the same (i.e. the labels), table name "changes" to "t2view". a = alias(q, 't2view') @@ -265,11 +259,11 @@ WHERE mytable.myid = myothertable.otherid) AS t2view WHERE t2view.mytable_myid = def testtextmix(self): self.runtest(select( - [table, table2.c.id, "sysdate()", "foo, bar, lala"], + [table1, table2.c.otherid, "sysdate()", "foo, bar, lala"], and_( "foo.id = foofoo(lala)", "datetime(foo) = Today", - table.c.id == table2.c.id, + table1.c.myid == table2.c.otherid, ) ), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, sysdate(), foo, bar, lala \ @@ -277,68 +271,68 @@ FROM mytable, myothertable WHERE foo.id = foofoo(lala) AND datetime(foo) = Today def testtextualsubquery(self): self.runtest(select( - [alias(table, 't'), "foo.f"], + [alias(table1, 't'), "foo.f"], "foo.f = t.id", from_obj = ["(select f from bar where lala=heyhey) foo"] ), "SELECT t.myid, t.name, t.description, foo.f FROM mytable AS t, (select f from bar where lala=heyhey) foo WHERE foo.f = t.id") def testliteral(self): - self.runtest(select([literal("foo") + literal("bar")], from_obj=[table]), + self.runtest(select([literal("foo") + literal("bar")], from_obj=[table1]), "SELECT :literal + :literal_1 FROM mytable", engine=db) def testfunction(self): - self.runtest(func.lala(3, 4, literal("five"), table.c.id) * table2.c.id, + self.runtest(func.lala(3, 4, literal("five"), table1.c.myid) * table2.c.otherid, "lala(:lala, :lala_1, :literal, mytable.myid) * myothertable.otherid", engine=db) def testjoin(self): self.runtest( - join(table2, table, table.c.id == table2.c.id).select(), + join(table2, table1, table1.c.myid == table2.c.otherid).select(), "SELECT myothertable.otherid, myothertable.othername, mytable.myid, mytable.name, \ mytable.description FROM myothertable JOIN mytable ON mytable.myid = myothertable.otherid" ) self.runtest( select( - [table], - from_obj = [join(table, table2, table.c.id == table2.c.id)] + [table1], + from_obj = [join(table1, table2, table1.c.myid == table2.c.otherid)] ), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable JOIN myothertable ON mytable.myid = myothertable.otherid") self.runtest( select( - [join(join(table, table2, table.c.id == table2.c.id), table3, table.c.id == table3.c.id) + [join(join(table1, table2, table1.c.myid == table2.c.otherid), table3, table1.c.myid == table3.c.userid) ]), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable JOIN myothertable ON mytable.myid = myothertable.otherid JOIN thirdtable ON mytable.myid = thirdtable.userid" ) self.runtest( - join(users, addresses).select(), + join(users, addresses, users.c.user_id==addresses.c.user_id).select(), "SELECT users.user_id, users.user_name, users.password, addresses.address_id, addresses.user_id, addresses.street, addresses.city, addresses.state, addresses.zip FROM users JOIN addresses ON users.user_id = addresses.user_id" ) def testmultijoin(self): self.runtest( - select([table, table2, table3], + select([table1, table2, table3], - from_obj = [join(table, table2, table.c.id == table2.c.id).outerjoin(table3, table.c.id==table3.c.id)] + from_obj = [join(table1, table2, table1.c.myid == table2.c.otherid).outerjoin(table3, table1.c.myid==table3.c.userid)] - #from_obj = [outerjoin(join(table, table2, table.c.id == table2.c.id), table3, table.c.id==table3.c.id)] + #from_obj = [outerjoin(join(table, table2, table1.c.myid == table2.c.otherid), table3, table1.c.myid==table3.c.userid)] ) ,"SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable JOIN myothertable ON mytable.myid = myothertable.otherid LEFT OUTER JOIN thirdtable ON mytable.myid = thirdtable.userid" ) self.runtest( - select([table, table2, table3], - from_obj = [outerjoin(table, join(table2, table3, table2.c.id == table3.c.id), table.c.id==table2.c.id)] + select([table1, table2, table3], + from_obj = [outerjoin(table1, join(table2, table3, table2.c.otherid == table3.c.userid), table1.c.myid==table2.c.otherid)] ) ,"SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable LEFT OUTER JOIN (myothertable JOIN thirdtable ON myothertable.otherid = thirdtable.userid) ON mytable.myid = myothertable.otherid" ) def testunion(self): x = union( - select([table], table.c.id == 5), - select([table], table.c.id == 12), - order_by = [table.c.id], + select([table1], table1.c.myid == 5), + select([table1], table1.c.myid == 12), + order_by = [table1.c.myid], ) self.runtest(x, "SELECT mytable.myid, mytable.name, mytable.description \ @@ -348,7 +342,7 @@ FROM mytable WHERE mytable.myid = :mytable_myid_1 ORDER BY mytable.myid") self.runtest( union( - select([table]), + select([table1]), select([table2]), select([table3]) ) @@ -365,14 +359,14 @@ FROM myothertable UNION SELECT thirdtable.userid, thirdtable.otherstuff FROM thi # parameters. query = select( - [table, table2], + [table1, table2], and_( - table.c.name == 'fred', - table.c.id == 10, - table2.c.name != 'jack', + table1.c.name == 'fred', + table1.c.myid == 10, + table2.c.othername != 'jack', "EXISTS (select yay from foo where boo = lar)" ), - from_obj = [ outerjoin(table, table2, table.c.id == table2.c.id) ] + from_obj = [ outerjoin(table1, table2, table1.c.myid == table2.c.otherid) ] ) self.runtest(query, @@ -393,9 +387,9 @@ myothertable.othername != :myothertable_othername AND EXISTS (select yay from fo def testbindparam(self): self.runtest(select( - [table, table2], - and_(table.c.id == table2.c.id, - table.c.name == bindparam('mytablename'), + [table1, table2], + and_(table1.c.myid == table2.c.otherid, + table1.c.name == bindparam('mytablename'), ) ), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername \ @@ -404,26 +398,26 @@ FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid AND mytable # check that the bind params sent along with a compile() call # get preserved when the params are retreived later - s = select([table], table.c.id == bindparam('test')) - c = s.compile(parameters = {'test' : 7}) + s = select([table1], table1.c.myid == bindparam('test')) + c = s.compile(parameters = {'test' : 7}, engine=db) self.assert_(c.get_params() == {'test' : 7}) def testcorrelatedsubquery(self): self.runtest( - table.select(table.c.id == select([table2.c.id], table.c.name == table2.c.name)), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = (SELECT myothertable.otherid AS id FROM myothertable WHERE mytable.name = myothertable.othername)" + table1.select(table1.c.myid == select([table2.c.otherid], table1.c.name == table2.c.othername)), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = (SELECT myothertable.otherid AS otherid FROM myothertable WHERE mytable.name = myothertable.othername)" ) self.runtest( - table.select(exists([1], table2.c.id == table.c.id)), + table1.select(exists([1], table2.c.otherid == table1.c.myid)), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE EXISTS (SELECT 1 FROM myothertable WHERE myothertable.otherid = mytable.myid)" ) - talias = table.alias('ta') - s = subquery('sq2', [talias], exists([1], table2.c.id == talias.c.id)) + talias = table1.alias('ta') + s = subquery('sq2', [talias], exists([1], table2.c.otherid == talias.c.myid)) self.runtest( - select([s, table]) - ,"SELECT sq2.id, sq2.name, sq2.description, mytable.myid, mytable.name, mytable.description FROM (SELECT ta.myid AS id, ta.name AS name, ta.description AS description FROM mytable AS ta WHERE EXISTS (SELECT 1 FROM myothertable WHERE myothertable.otherid = ta.myid)) AS sq2, mytable") + select([s, table1]) + ,"SELECT sq2.myid, sq2.name, sq2.description, mytable.myid, mytable.name, mytable.description FROM (SELECT ta.myid AS myid, ta.name AS name, ta.description AS description FROM mytable AS ta WHERE EXISTS (SELECT 1 FROM myothertable WHERE myothertable.otherid = ta.myid)) AS sq2, mytable") s = select([addresses.c.street], addresses.c.user_id==users.c.user_id).alias('s') self.runtest( @@ -431,81 +425,80 @@ FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid AND mytable """SELECT users.user_id, users.user_name, users.password, s.street FROM users, (SELECT addresses.street AS street FROM addresses WHERE addresses.user_id = users.user_id) AS s""") def testin(self): - self.runtest(select([table], table.c.id.in_(1, 2, 3)), + self.runtest(select([table1], table1.c.myid.in_(1, 2, 3)), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:mytable_myid, :mytable_myid_1, :mytable_myid_2)") - self.runtest(select([table], table.c.id.in_(select([table2.c.id]))), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (SELECT myothertable.otherid AS id FROM myothertable)") + self.runtest(select([table1], table1.c.myid.in_(select([table2.c.otherid]))), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (SELECT myothertable.otherid AS otherid FROM myothertable)") def testlateargs(self): """tests that a SELECT clause will have extra "WHERE" clauses added to it at compile time if extra arguments are sent""" - self.runtest(table.select(), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.name = :mytable_name AND mytable.myid = :mytable_myid", params={'id':'3', 'name':'jack'}) + self.runtest(table1.select(), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.name = :mytable_name AND mytable.myid = :mytable_myid", params={'myid':'3', 'name':'jack'}) - self.runtest(table.select(table.c.name=='jack'), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :mytable_myid AND mytable.name = :mytable_name", params={'id':'3'}) + self.runtest(table1.select(table1.c.name=='jack'), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :mytable_myid AND mytable.name = :mytable_name", params={'myid':'3'}) - self.runtest(table.select(table.c.name=='jack'), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :mytable_myid AND mytable.name = :mytable_name", params={'id':'3', 'name':'fred'}) + self.runtest(table1.select(table1.c.name=='jack'), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :mytable_myid AND mytable.name = :mytable_name", params={'myid':'3', 'name':'fred'}) class CRUDTest(SQLTest): def testinsert(self): # generic insert, will create bind params for all columns - self.runtest(insert(table), "INSERT INTO mytable (myid, name, description) VALUES (:myid, :name, :description)") + self.runtest(insert(table1), "INSERT INTO mytable (myid, name, description) VALUES (:myid, :name, :description)") # insert with user-supplied bind params for specific columns, # cols provided literally self.runtest( - insert(table, {table.c.id : bindparam('userid'), table.c.name : bindparam('username')}), + insert(table1, {table1.c.myid : bindparam('userid'), table1.c.name : bindparam('username')}), "INSERT INTO mytable (myid, name) VALUES (:userid, :username)") # insert with user-supplied bind params for specific columns, cols # provided as strings self.runtest( - insert(table, dict(id = 3, name = 'jack')), + insert(table1, dict(myid = 3, name = 'jack')), "INSERT INTO mytable (myid, name) VALUES (:myid, :name)" ) # test with a tuple of params instead of named self.runtest( - insert(table, (3, 'jack', 'mydescription')), + insert(table1, (3, 'jack', 'mydescription')), "INSERT INTO mytable (myid, name, description) VALUES (:myid, :name, :description)", checkparams = {'myid':3, 'name':'jack', 'description':'mydescription'} ) def testupdate(self): - self.runtest(update(table, table.c.id == 7), "UPDATE mytable SET name=:name WHERE mytable.myid = :mytable_myid", params = {table.c.name:'fred'}) - self.runtest(update(table, table.c.id == 7), "UPDATE mytable SET name=:name WHERE mytable.myid = :mytable_myid", params = {'name':'fred'}) - self.runtest(update(table, values = {table.c.name : table.c.id}), "UPDATE mytable SET name=mytable.myid") - self.runtest(update(table, whereclause = table.c.name == bindparam('crit'), values = {table.c.name : 'hi'}), "UPDATE mytable SET name=:name WHERE mytable.name = :crit", params = {'crit' : 'notthere'}) - self.runtest(update(table, table.c.id == 12, values = {table.c.name : table.c.id}), "UPDATE mytable SET name=mytable.myid, description=:description WHERE mytable.myid = :mytable_myid", params = {'description':'test'}) - self.runtest(update(table, table.c.id == 12, values = {table.c.id : 9}), "UPDATE mytable SET myid=:myid, description=:description WHERE mytable.myid = :mytable_myid", params = {'mytable_myid': 12, 'myid': 9, 'description': 'test'}) - s = table.update(table.c.id == 12, values = {table.c.name : 'lala'}) - print str(s) - c = s.compile(parameters = {'mytable_id':9,'name':'h0h0'}) + self.runtest(update(table1, table1.c.myid == 7), "UPDATE mytable SET name=:name WHERE mytable.myid = :mytable_myid", params = {table1.c.name:'fred'}) + self.runtest(update(table1, table1.c.myid == 7), "UPDATE mytable SET name=:name WHERE mytable.myid = :mytable_myid", params = {'name':'fred'}) + self.runtest(update(table1, values = {table1.c.name : table1.c.myid}), "UPDATE mytable SET name=mytable.myid") + self.runtest(update(table1, whereclause = table1.c.name == bindparam('crit'), values = {table1.c.name : 'hi'}), "UPDATE mytable SET name=:name WHERE mytable.name = :crit", params = {'crit' : 'notthere'}) + self.runtest(update(table1, table1.c.myid == 12, values = {table1.c.name : table1.c.myid}), "UPDATE mytable SET name=mytable.myid, description=:description WHERE mytable.myid = :mytable_myid", params = {'description':'test'}) + self.runtest(update(table1, table1.c.myid == 12, values = {table1.c.myid : 9}), "UPDATE mytable SET myid=:myid, description=:description WHERE mytable.myid = :mytable_myid", params = {'mytable_myid': 12, 'myid': 9, 'description': 'test'}) + s = table1.update(table1.c.myid == 12, values = {table1.c.name : 'lala'}) + c = s.compile(parameters = {'mytable_id':9,'name':'h0h0'}, engine=db) print str(c) self.assert_(str(s) == str(c)) def testupdateexpression(self): - self.runtest(update(table, - (table.c.id == func.hoho(4)) & - (table.c.name == literal('foo') + table.c.name + literal('lala')), + self.runtest(update(table1, + (table1.c.myid == func.hoho(4)) & + (table1.c.name == literal('foo') + table1.c.name + literal('lala')), values = { - table.c.name : table.c.name + "lala", - table.c.id : func.do_stuff(table.c.id, literal('hoho')) + table1.c.name : table1.c.name + "lala", + table1.c.myid : func.do_stuff(table1.c.myid, literal('hoho')) }), "UPDATE mytable SET myid=(do_stuff(mytable.myid, :literal_2)), name=(mytable.name + :mytable_name) WHERE mytable.myid = hoho(:hoho) AND mytable.name = :literal + mytable.name + :literal_1") def testcorrelatedupdate(self): # test against a straight text subquery - u = update(table, values = {table.c.name : text("select name from mytable where id=mytable.id")}) + u = update(table1, values = {table1.c.name : text("select name from mytable where id=mytable.id")}) self.runtest(u, "UPDATE mytable SET name=(select name from mytable where id=mytable.id)") # test against a regular constructed subquery - s = select([table2], table2.c.id == table.c.id) - u = update(table, table.c.name == 'jack', values = {table.c.name : s}) + s = select([table2], table2.c.otherid == table1.c.myid) + u = update(table1, table1.c.name == 'jack', values = {table1.c.name : s}) self.runtest(u, "UPDATE mytable SET name=(SELECT myothertable.otherid, myothertable.othername FROM myothertable WHERE myothertable.otherid = mytable.myid) WHERE mytable.name = :mytable_name") def testdelete(self): - self.runtest(delete(table, table.c.id == 7), "DELETE FROM mytable WHERE mytable.myid = :mytable_myid") + self.runtest(delete(table1, table1.c.myid == 7), "DELETE FROM mytable WHERE mytable.myid = :mytable_myid") class SchemaTest(SQLTest): def testselect(self): diff --git a/test/testbase.py b/test/testbase.py index a26b87bd4f..afdca47382 100644 --- a/test/testbase.py +++ b/test/testbase.py @@ -2,9 +2,6 @@ import unittest import StringIO import sqlalchemy.engine as engine import re, sys -import sqlalchemy.databases.sqlite as sqlite -import sqlalchemy.databases.postgres as postgres -#import sqlalchemy.databases.mysql as mysql echo = True #echo = False