From c2c8b14815a7231105f2b0f322e7c1c128fb126c Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 6 Aug 2005 20:32:42 +0000 Subject: [PATCH] --- lib/sqlalchemy/ansisql.py | 20 +-- lib/sqlalchemy/mapper.py | 4 +- lib/sqlalchemy/schema.py | 89 ++++++-------- lib/sqlalchemy/sql.py | 252 ++++++++++++++++++++++++++------------ test/select.py | 148 +++++++++++----------- 5 files changed, 303 insertions(+), 210 deletions(-) diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 3f6cbb835d..93e7e737d6 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -207,13 +207,19 @@ class ANSICompiler(sql.Compiled): def visit_update(self, update_stmt): colparams = update_stmt.get_colparams(self._bindparams) - - for c in colparams: - b = c[1] - self.binds[b.key] = b - self.binds[b.shortname] = b - - text = "UPDATE " + update_stmt.table.name + " SET " + string.join(["%s=:%s" % (c[0].name, c[1].key) for c in colparams], ', ') + def create_param(p): + if isinstance(p, BindParamClause): + self.binds[p.key] = p + self.binds[p.shortname] = p + return ":" + p.key + else: + p.accept_visitor(self) + if isinstance(p, ClauseElement): + return "(" + self.get_str(p) + ")" + else: + return self.get_str(p) + + text = "UPDATE " + update_stmt.table.name + " SET " + string.join(["%s=%s" % (c[0].name, create_param(c[1])) for c in colparams], ', ') if update_stmt.whereclause: text += " WHERE " + self.get_str(update_stmt.whereclause) diff --git a/lib/sqlalchemy/mapper.py b/lib/sqlalchemy/mapper.py index 28d42792f8..561a0a4a7e 100644 --- a/lib/sqlalchemy/mapper.py +++ b/lib/sqlalchemy/mapper.py @@ -70,7 +70,7 @@ def eagerload(name): def lazyload(name): return EagerLazySwitcher(name, toeager = False) -class Mapper(object): +copy_containerclass Mapper(object): def __init__(self, class_, selectable, table = None, properties = None, identitymap = None, use_smart_properties = True, isroot = True, echo = None): self.class_ = class_ self.selectable = selectable @@ -408,7 +408,7 @@ class LazyLoader(PropertyLoader): self.lazywhere = sql.and_(self.primaryjoin, self.secondaryjoin) else: self.lazywhere = self.primaryjoin - self.lazywhere = self.lazywhere.copy_structure() + self.lazywhere = self.lazywhere.copy_container() li = LazyIzer(primarytable) self.lazywhere.accept_visitor(li) self.binds = li.binds diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 41cb166212..74bb2e3c5e 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -18,31 +18,32 @@ from sqlalchemy.util import * import copy -engine = None - - -__ALL__ = ['Table', 'Column', 'Relation', 'Sequence', +__ALL__ = ['Table', 'Column', 'Sequence', 'INT', 'CHAR', 'VARCHAR', 'TEXT', 'FLOAT', 'DECIMAL', 'TIMESTAMP', 'DATETIME', 'CLOB', 'BLOB', 'BOOLEAN' ] -class INT: pass +class INT: + """integer datatype""" + pass class CHAR: + """character datatype""" def __init__(self, length): self.length = length - + class VARCHAR: def __init__(self, length): self.length = length - -class TEXT: pass + + class FLOAT: def __init__(self, precision, length): self.precision = precision self.length = length - + +class TEXT: pass class DECIMAL: pass class TIMESTAMP: pass class DATETIME: pass @@ -56,12 +57,18 @@ class SchemaItem(object): def _init_items(self, *args): for item in args: item._set_parent(self) - - def accept_visitor(self, visitor): raise NotImplementedError() - def _set_parent(self, parent): raise NotImplementedError() + + def accept_visitor(self, visitor): + raise NotImplementedError() + + def _set_parent(self, parent): + """a child item attaches itself to its parent via this method.""" + raise NotImplementedError() + def hash_key(self): + """returns a string that identifies this SchemaItem uniquely""" return repr(self) - + def __getattr__(self, key): return getattr(self._impl, key) @@ -69,27 +76,27 @@ class SchemaItem(object): class Table(SchemaItem): """represents a relational database table.""" - def __init__(self, name, engine, *args, **params): + def __init__(self, name, engine, *args, **kwargs): self.name = name self.columns = OrderedProperties() self.c = self.columns self.relations = [] + self.primary_keys = [] self.engine = engine self._impl = self.engine.tableimpl(self) self._init_items(*args) - - if params.get('autoload', False): + + # load column definitions from the database if 'autoload' is defined + if kwargs.get('autoload', False): self.engine.reflecttable(self) def append_item(self, item): self._init_items(item) - + def _set_parent(self, schema): schema.tables[self.name] = self self.schema = schema - primary_keys = property (lambda self: [c for c in self.columns if c.primary_key]) - def accept_visitor(self, visitor): for c in self.columns: c.accept_visitor(visitor) @@ -97,29 +104,27 @@ class Table(SchemaItem): class Column(SchemaItem): """represents a column in a database table.""" - def __init__(self, name, type, reference = None, key = None, primary_key = False, *args, **params): + def __init__(self, name, type, key = None, primary_key = False, *args): self.name = name self.type = type self.sequences = OrderedProperties() - self.reference = reference self.key = key or name self.primary_key = primary_key self._items = args def _set_parent(self, table): table.columns[self.key] = self + if self.primary_key: + table.primary_keys.append(self) self.table = table self.engine = table.engine self._impl = self.engine.columnimpl(self) - self._init_items(*self._items) - - if self.reference is not None: - Relation(self.table, self.reference.table, self == self.reference) def _make_proxy(self, selectable, name = None): - # wow! using copy.copy(c) adds a full second to the select.py unittest package + """creates a copy of this Column for use in a new selectable unit""" + # using copy.copy(c) seems to add a full second to the select.py unittest package #c = copy.copy(self) #if name is not None: # c.name = name @@ -131,7 +136,7 @@ class Column(SchemaItem): selectable.columns[c.key] = c c._impl = self.engine.columnimpl(c) return c - + def accept_visitor(self, visitor): return visitor.visit_column(self) @@ -144,42 +149,24 @@ class Column(SchemaItem): def __str__(self): return self._impl.__str__() -class Relation(SchemaItem): - def __init__(self, parent, child, relationship, association = None, lazy = True): - self.parent = parent - self.child = child - self.relationship = relationship - self.lazy = lazy - self.association = association - - self._set_parent(parent) - - def _set_parent(self, table): - table.relations.append(self) - self.table = table - - def accept_visitor(self, visitor): - visitor.visit_relation(self) - class Sequence(SchemaItem): - """represents a sequence.""" + """represents a sequence, which applies to Oracle and Postgres databases.""" def set_parent(self, column, key): column.sequences[key] = self self.column = column - - def accept_visitor(self, visitor): + def accept_visitor(self, visitor): return visitor.visit_sequence(self) - + class SchemaEngine(object): + """a factory object used to create implementations for schema objects""" def tableimpl(self, table): raise NotImplementedError() - + def columnimpl(self, column): raise NotImplementedError() class SchemaVisitor(object): - """base class for an object that traverses a Schema object structure, - or sub-objects within one, and acts upon each node.""" + """base class for an object that traverses across Schema objects""" def visit_schema(self, schema):pass def visit_table(self, table):pass diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 2c133f6c16..643f9e3d33 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -16,9 +16,8 @@ # Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. -"""base sql module used by all sql implementations. defines abstract units which construct -expression trees that generate into text strings + bind parameters. -""" +"""defines the base components of SQL expression trees.""" + import sqlalchemy.schema as schema import sqlalchemy.util as util import string @@ -26,28 +25,74 @@ import string __ALL__ = ['textclause', 'select', 'join', 'and_', 'or_', 'union', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'bindparam', 'sequence'] def desc(column): + """returns a descending ORDER BY clause element""" return CompoundClause(None, column, "DESC") def asc(column): + """returns an ascending ORDER BY clause element""" return CompoundClause(None, column, "ASC") -def outerjoin(left, right, onclause, **params): - return Join(left, right, onclause, isouter = True, **params) +def outerjoin(left, right, onclause, **kwargs): + """returns an OUTER JOIN clause element, given the left and right hand expressions, + as well as the ON condition's expression. When chaining joins together, the previous JOIN + expression should be specified as the left side of this JOIN expression.""" + return Join(left, right, onclause, isouter = True, **kwargs) + +def join(left, right, onclause, **kwargs): + """returns a JOIN clause element (regular inner join), given the left and right hand expressions, + as well as the ON condition's expression. When chaining joins together, the previous JOIN + expression should be specified as the left side of this JOIN expression.""" + return Join(left, right, onclause, **kwargs) + +def select(columns, whereclause = None, from_obj = [], **kwargs): + """returns a SELECT clause element, given a list of columns and/or selectable items to select + columns from, an optional expression for the WHERE clause, an optional list of "FROM" objects + to select from, and additional parameters.""" + return Select(columns, whereclause = whereclause, from_obj = from_obj, **kwargs) + +def insert(table, values = None, **kwargs): + """returns an INSERT clause element. -def join(left, right, onclause, **params): - return Join(left, right, onclause, **params) - -def select(columns, whereclause = None, from_obj = [], **params): - return Select(columns, whereclause = whereclause, from_obj = from_obj, **params) - -def insert(table, values = None, **params): - return Insert(table, values, **params) + 'table' is the table to be inserted into. + 'values' is a dictionary which specifies the column specifications of the INSERT, and is optional. + If left as None, the + column specifications are determined from the bind parameters used during the compile phase of the + INSERT statement. If the bind parameters also are None during the compile phase, then the column + specifications will be generated from the full list of table columns. + + If both 'values' and compile-time bind parameters are present, the compile-time bind parameters + override the information specified within 'values' on a per-key basis. + + The keys within 'values' can be either Column objects or their string identifiers. + Each key may reference one of: a literal data value (i.e. string, number, etc.), a Column object, + or a SELECT statement. If a SELECT statement is specified which references this INSERT statement's + table, the statement will be correlated against the INSERT statement. + """ + return Insert(table, values, **kwargs) -def update(table, whereclause = None, values = None, **params): - return Update(table, whereclause, values, **params) +def update(table, whereclause = None, values = None, **kwargs): + """returns an UPDATE clause element. + + 'table' is the table to be updated. + 'whereclause' is a ClauseElement describing the WHERE condition of the UPDATE statement. + 'values' is a dictionary which specifies the SET conditions of the UPDATE, and is optional. + If left as None, the + SET conditions are determined from the bind parameters used during the compile phase of the + UPDATE statement. If the bind parameters also are None during the compile phase, then the SET + conditions will be generated from the full list of table columns. + + If both 'values' and compile-time bind parameters are present, the compile-time bind parameters + override the information specified within 'values' on a per-key basis. + + The keys within 'values' can be either Column objects or their string identifiers. + Each key may reference one of: a literal data value (i.e. string, number, etc.), a Column object, + or a SELECT statement. If a SELECT statement is specified which references this UPDATE statement's + table, the statement will be correlated against the UPDATE statement. + """ + return Update(table, whereclause, values, **kwargs) -def delete(table, whereclause = None, **params): - return Delete(table, whereclause, **params) +def delete(table, whereclause = None, **kwargs): + return Delete(table, whereclause, **kwargs) def and_(*clauses): return _compound_clause('AND', *clauses) @@ -55,7 +100,7 @@ def and_(*clauses): def or_(*clauses): clause = _compound_clause('OR', *clauses) return clause - + def union(*selects, **params): return _compound_select('UNION', *selects, **params) @@ -73,23 +118,25 @@ def textclause(text): def sequence(): return Sequence() - + def _compound_clause(keyword, *clauses): return CompoundClause(keyword, *clauses) def _compound_select(keyword, *selects, **params): - if len(selects) == 0: return None - + if len(selects) == 0: + return None s = selects[0] for n in selects[1:]: s.append_clause(keyword, n) - + if params.get('order_by', None) is not None: s.order_by(*params['order_by']) return s class ClauseVisitor(schema.SchemaVisitor): + """builds upon SchemaVisitor to define the visiting of SQL statement elements in + addition to Schema elements.""" def visit_columnclause(self, column):pass def visit_fromclause(self, fromclause):pass def visit_bindparam(self, bindparam):pass @@ -101,13 +148,23 @@ class ClauseVisitor(schema.SchemaVisitor): def visit_join(self, join):pass class Compiled(ClauseVisitor): - pass - + """represents a compiled SQL expression. the __str__ method of the Compiled object + should produce the actual text of the statement. Compiled objects are specific to the database + library that created them, and also may or may not be specific to the columns referenced + within a particular set of bind parameters. In no case should the Compiled object be dependent + on the actual values of those bind parameters, even though it may reference those values + as defaults.""" + def __str__(self): + raise NotImplementedError() + def get_params(self, **params): + """returns the bind params for this compiled object, with values overridden by + those given in the **params dictionary""" + raise NotImplementedError() + class ClauseElement(object): - """base class for elements of a generated SQL statement. + """base class for elements of a programmatically constructed SQL expression. - includes a parameter hash to store bind parameter key/value pairs, - as well as a list of 'from objects' which collects items to be placed + includes a list of 'from objects' which collects items to be placed in the FROM clause of a SQL statement. when many ClauseElements are attached together, the from objects and bind @@ -115,26 +172,46 @@ class ClauseElement(object): """ def hash_key(self): + """returns a string that uniquely identifies the concept this ClauseElement represents. + + two ClauseElements can have the same value for hash_key() iff they both correspond to the + exact same generated SQL. This allows the hash_key() values of a collection of ClauseElements + to be constructed into a larger identifying string for the purpose of caching a SQL expression. + + Note that since ClauseElements may be mutable, the hash_key() value is subject to change + if the underlying structure of the ClauseElement changes.""" raise NotImplementedError(repr(self)) def _get_from_objects(self): raise NotImplementedError(repr(self)) def accept_visitor(self, visitor): raise NotImplementedError(repr(self)) - def compile(self, engine, bindparams = None): - return engine.compile(self, bindparams = bindparams) - def copy_structure(self): - """allows the copying of a statement's containers, so that a modified statement - can be produced without affecting the original. containing clauseelements, - like Select, Join, CompoundClause, BinaryClause, etc., should produce a copy of - themselves, whereas "leaf-node" clauseelements should return themselves.""" + def copy_container(self): + """should return a copy of this ClauseElement, iff this ClauseElement contains other + ClauseElements. Otherwise, it should be left alone to return self. This is used to create + copies of expression trees that still reference the same "leaf nodes". The new structure + can then be restructured without affecting the original.""" return self - + def _engine(self): + """should return a SQLEngine instance that is associated with this expression tree. + this engine is usually attached to one of the underlying Table objects within the expression.""" raise NotImplementedError("Object %s has no built-in SQLEngine." % repr(self)) - + + def compile(self, engine, bindparams = None): + """compiles this SQL expression using its underlying SQLEngine to produce + a Compiled object. The actual SQL statement is the Compiled object's string representation. + bindparams is an optional dictionary representing the bind parameters to be used with + the statement. Currently, only the compilations of INSERT and UPDATE statements + use the bind parameters, in order to determine which + table columns should be used in the statement.""" + return engine.compile(self, bindparams = bindparams) + def execute(self, **params): + """compiles and executes this SQL expression using its underlying SQLEngine. + the given **params are used as bind parameters when compiling and executing the expression. + the DBAPI cursor object is returned.""" e = self._engine() c = self.compile(e, bindparams = params) # TODO: do pre-execute right here, for sequences, if the compiled object @@ -142,13 +219,14 @@ class ClauseElement(object): return e.execute(str(c), c.get_params(), echo = getattr(self, 'echo', None)) def result(self, **params): + """the same as execute(), except a RowProxy object is returned instead of a DBAPI cursor.""" e = self._engine() c = self.compile(e, bindparams = params) - return e.result(str(c), c.binds) - + return e.result(str(c), c.get_params(), echo = getattr(self, 'echo', None)) + class ColumnClause(ClauseElement): - """represents a column clause element in a SQL statement.""" - + """represents a textual column clause in a SQL statement.""" + def __init__(self, text, selectable): self.text = text self.table = selectable @@ -165,7 +243,7 @@ class ColumnClause(ClauseElement): def hash_key(self): return "ColumnClause(%s, %s)" % (self.text, self.table.hash_key()) - + def _get_from_objects(self): return [] @@ -236,8 +314,8 @@ class CompoundClause(ClauseElement): if c is None: continue self.append(c) - def copy_structure(self): - clauses = [clause.copy_structure() for clause in self.clauses] + def copy_container(self): + clauses = [clause.copy_container() for clause in self.clauses] return CompoundClause(self.operator, *clauses) def append(self, clause): @@ -279,8 +357,8 @@ class BinaryClause(ClauseElement): self.operator = operator self.parens = False - def copy_structure(self): - return BinaryClause(self.left.copy_structure(), self.right.copy_structure(), self.operator) + def copy_container(self): + return BinaryClause(self.left.copy_container(), self.right.copy_container(), self.operator) def _get_from_objects(self): return self.left._get_from_objects() + self.right._get_from_objects() @@ -394,7 +472,7 @@ class ColumnSelectable(Selectable): self.label = self.column.name self.fullname = self.column.name - def copy_structure(self): + def copy_container(self): return self.column def _get_from_objects(self): @@ -594,29 +672,59 @@ class Select(Selectable): class UpdateBase(ClauseElement): + """forms the base for INSERT, UPDATE, and DELETE statements. + Deals with the special needs of INSERT and UPDATE parameter lists - + these statements have two separate lists of parameters, those + defined when the statement is constructed, and those specified at compile time.""" + + def _process_colparams(self, parameters): + if parameters is None: + return None + + for key in parameters.keys(): + value = parameters[key] + if isinstance(value, Select): + value.append_from(FromClause(from_key=self.table.id)) + elif not isinstance(value, schema.Column) and not isinstance(value, ClauseElement): + try: + col = self.table.c[key] + parameters[key] = bindparam(col.name, value) + except KeyError: + del parameters[key] + + return parameters + def get_colparams(self, parameters): - values = [] + # case one: no parameters in the statement, no parameters in the + # compiled params - just return binds for all the table columns + if parameters is None and self.parameters is None: + return [(c, bindparam(c.name)) for c in self.table.columns] + # if we have statement parameters - set defaults in the + # compiled params if parameters is None: - parameters = self.parameters + parameters = {} - if parameters is None: - for c in self.table.columns: - values.append((c, bindparam(c.name))) - else: - d = {} - for key, value in parameters.iteritems(): - if isinstance(key, schema.Column): - d[key] = value - else: - d[self.table.columns[str(key)]] = value - - for c in self.table.columns: - if d.has_key(c): - value = d[c] - if not isinstance(value, BindParamClause): - value = bindparam(c.name, value) - values.append((c, value)) + if self.parameters is not None: + for k, v in self.parameters.iteritems(): + parameters.setdefault(k, v) + + # now go thru compiled params, get the Column object for each key + d = {} + for key, value in parameters.iteritems(): + if isinstance(key, schema.Column): + d[key] = value + else: + d[self.table.columns[str(key)]] = value + + # create a list of column assignment clauses as tuples + values = [] + for c in self.table.columns: + if d.has_key(c): + value = d[c] + if isinstance(value, str): + value = bindparam(c.name, value) + values.append((c, value)) return values def _engine(self): @@ -635,7 +743,7 @@ class Insert(UpdateBase): def __init__(self, table, parameters = None, **params): self.table = table self.select = None - self.parameters = parameters + self.parameters = self._process_colparams(parameters) self.engine = self.table._engine() def accept_visitor(self, visitor): @@ -650,41 +758,33 @@ class Insert(UpdateBase): def compile(self, engine = None, bindparams = None): if engine is None: engine = self.engine - if engine is None: raise "no engine supplied, and no engine could be located within the clauses!" - return engine.compile(self, bindparams) class Update(UpdateBase): def __init__(self, table, whereclause, parameters = None, **params): self.table = table self.whereclause = whereclause - - self.parameters = parameters + self.parameters = self._process_colparams(parameters) self.engine = self.table._engine() - def accept_visitor(self, visitor): if self.whereclause is not None: self.whereclause.accept_visitor(visitor) - visitor.visit_update(self) class Delete(UpdateBase): def __init__(self, table, whereclause, **params): self.table = table self.whereclause = whereclause - self.engine = self.table._engine() - def accept_visitor(self, visitor): if self.whereclause is not None: self.whereclause.accept_visitor(visitor) - visitor.visit_delete(self) - + class Sequence(BindParamClause): def __init__(self): BindParamClause.__init__(self, 'sequence') diff --git a/test/select.py b/test/select.py index b1e2ed13ca..2d3f23eb61 100644 --- a/test/select.py +++ b/test/select.py @@ -11,38 +11,26 @@ from sqlalchemy.schema import * from testbase import PersistTest import unittest, re -class SelectTest(PersistTest): - - def setUp(self): - - self.table = Table('mytable', db, - Column('myid', 3, key = 'id'), - Column('name', 4, key = 'name'), - Column('description', 4, key = 'description'), - ) - self.table2 = Table( - 'myothertable', db, - Column('otherid',3, key='id'), - Column('othername', 4, key='name'), - ) +table = Table('mytable', db, + Column('myid', 3, key = 'id'), + Column('name', 4, key = 'name'), + Column('description', 4, key = 'description'), +) - self.table3 = Table( - 'thirdtable', db, - Column('userid', 5, key='id'), - Column('otherstuff', 5), - ) +table2 = Table( + 'myothertable', db, + Column('otherid',3, key='id'), + Column('othername', 4, key='name'), +) - - def testoperator(self): - return - table = Table( - 'mytable', - Column('myid',3, key='id'), - Column('name', 4) - ) +table3 = Table( + 'thirdtable', db, + Column('userid', 5, key='id'), + Column('otherstuff', 5), +) - print (table.c.id == 5) +class SelectTest(PersistTest): def testtext(self): self.runtest( @@ -52,14 +40,14 @@ class SelectTest(PersistTest): ) def testtableselect(self): - self.runtest(self.table.select(), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable") + self.runtest(table.select(), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable") - self.runtest(select([self.table, self.table2]), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, \ + self.runtest(select([table, table2]), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, \ myothertable.othername FROM mytable, myothertable") def testsubquery(self): - s = select([self.table], self.table.c.name == 'jack') + s = select([table], table.c.name == 'jack') self.runtest( select( [s], @@ -68,7 +56,7 @@ myothertable.othername FROM mytable, myothertable") , "SELECT myid, name, description FROM (SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.name = :mytable_name) WHERE myid = :myid") - sq = Select([self.table]) + sq = Select([table]) self.runtest( sq.select(), "SELECT myid, name, description FROM (SELECT mytable.myid, mytable.name, mytable.description FROM mytable)" @@ -76,7 +64,7 @@ myothertable.othername FROM mytable, myothertable") sq = subquery( 'sq', - [self.table], + [table], ) self.runtest( @@ -87,8 +75,8 @@ myothertable.othername FROM mytable, myothertable") sq = subquery( 'sq', - [self.table, self.table2], - and_(self.table.c.id ==7, self.table2.c.id==self.table.c.id), + [table, table2], + and_(table.c.id ==7, table2.c.id==table.c.id), use_labels = True ) @@ -115,15 +103,15 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") s def testand(self): self.runtest( - select(['*'], and_(self.table.c.id == 12, self.table.c.name=='asdf', self.table2.c.name == 'foo', "sysdate() = today()")), + select(['*'], and_(table.c.id == 12, table.c.name=='asdf', table2.c.name == 'foo', "sysdate() = today()")), "SELECT * FROM mytable, myothertable WHERE mytable.myid = :mytable_myid AND mytable.name = :mytable_name AND myothertable.othername = :myothertable_othername AND sysdate() = today()" ) def testor(self): self.runtest( - select([self.table], and_( - self.table.c.id == 12, - or_(self.table2.c.name=='asdf', self.table2.c.name == 'foo', self.table2.c.id == 9), + select([table], and_( + table.c.id == 12, + or_(table2.c.name=='asdf', table2.c.name == 'foo', table2.c.id == 9), "sysdate() = today()", )), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable, myothertable WHERE mytable.myid = :mytable_myid AND (myothertable.othername = :myothertable_othername OR myothertable.othername = :myothertable_othername_1 OR myothertable.otherid = :myothertable_otherid) AND sysdate() = today()" @@ -132,25 +120,25 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") s def testmultiparam(self): self.runtest( - select(["*"], or_(self.table.c.id == 12, self.table.c.id=='asdf', self.table.c.id == 'foo')), + select(["*"], or_(table.c.id == 12, table.c.id=='asdf', table.c.id == 'foo')), "SELECT * FROM mytable WHERE mytable.myid = :mytable_myid OR mytable.myid = :mytable_myid_1 OR mytable.myid = :mytable_myid_2" ) def testorderby(self): self.runtest( - self.table2.select(order_by = [self.table2.c.id, asc(self.table2.c.name)]), + table2.select(order_by = [table2.c.id, asc(table2.c.name)]), "SELECT myothertable.otherid, myothertable.othername FROM myothertable ORDER BY myothertable.otherid, myothertable.othername ASC" ) def testalias(self): # test the alias for a table. column names stay the same, table name "changes" to "foo". self.runtest( - select([alias(self.table, 'foo')]) + select([alias(table, 'foo')]) ,"SELECT foo.myid, foo.name, foo.description FROM mytable foo") # create a select for a join of two tables. use_labels means the column names will have # labels tablename_columnname, which become the column keys accessible off the Selectable object. # also, only use one column from the second table and all columns from the first table. - q = select([self.table, self.table2.c.id], self.table.c.id == self.table2.c.id, use_labels = True) + q = select([table, table2.c.id], table.c.id == table2.c.id, use_labels = True) # make an alias of the "selectable". column names stay the same (i.e. the labels), table name "changes" to "t2view". a = alias(q, 't2view') @@ -177,11 +165,11 @@ WHERE mytable.myid = myothertable.otherid) t2view WHERE t2view.mytable_myid = :t def testliteralmix(self): self.runtest(select( - [self.table, self.table2.c.id, "sysdate()", "foo, bar, lala"], + [table, table2.c.id, "sysdate()", "foo, bar, lala"], and_( "foo.id = foofoo(lala)", "datetime(foo) = Today", - self.table.c.id == self.table2.c.id, + table.c.id == table2.c.id, ) ), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, sysdate(), foo, bar, lala \ @@ -189,7 +177,7 @@ FROM mytable, myothertable WHERE foo.id = foofoo(lala) AND datetime(foo) = Today def testliteralsubquery(self): self.runtest(select( - [alias(self.table, 't'), "foo.f"], + [alias(table, 't'), "foo.f"], "foo.f = t.id", from_obj = ["(select f from bar where lala=heyhey) foo"] ), @@ -197,38 +185,38 @@ FROM mytable, myothertable WHERE foo.id = foofoo(lala) AND datetime(foo) = Today def testjoin(self): self.runtest( - join(self.table2, self.table, self.table.c.id == self.table2.c.id).select(), + join(table2, table, table.c.id == table2.c.id).select(), "SELECT myothertable.otherid, myothertable.othername, mytable.myid, mytable.name, mytable.description \ FROM myothertable, mytable WHERE mytable.myid = myothertable.otherid" ) self.runtest( select( - [self.table], - from_obj = [join(self.table, self.table2, self.table.c.id == self.table2.c.id)] + [table], + from_obj = [join(table, table2, table.c.id == table2.c.id)] ), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable JOIN myothertable ON mytable.myid = myothertable.otherid") self.runtest( select( - [join(join(self.table, self.table2, self.table.c.id == self.table2.c.id), self.table3, self.table.c.id == self.table3.c.id) + [join(join(table, table2, table.c.id == table2.c.id), table3, table.c.id == table3.c.id) ]), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable JOIN myothertable ON mytable.myid = myothertable.otherid JOIN thirdtable ON mytable.myid = thirdtable.userid" ) def testmultijoin(self): self.runtest( - select([self.table, self.table2, self.table3], - from_obj = [outerjoin(join(self.table, self.table2, self.table.c.id == self.table2.c.id), self.table3, self.table.c.id==self.table3.c.id)] + select([table, table2, table3], + from_obj = [outerjoin(join(table, table2, table.c.id == table2.c.id), table3, table.c.id==table3.c.id)] ) ,"SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable JOIN myothertable ON mytable.myid = myothertable.otherid LEFT OUTER JOIN thirdtable ON mytable.myid = thirdtable.userid" ) def testunion(self): x = union( - select([self.table], self.table.c.id == 5), - select([self.table], self.table.c.id == 12), - order_by = [self.table.c.id], + select([table], table.c.id == 5), + select([table], table.c.id == 12), + order_by = [table.c.id], ) self.runtest(x, "SELECT mytable.myid, mytable.name, mytable.description \ @@ -238,9 +226,9 @@ FROM mytable WHERE mytable.myid = :mytable_myid_1 ORDER BY mytable.myid") self.runtest( union( - select([self.table]), - select([self.table2]), - select([self.table3]) + select([table]), + select([table2]), + select([table3]) ) , "SELECT mytable.myid, mytable.name, mytable.description \ @@ -255,14 +243,14 @@ FROM myothertable UNION SELECT thirdtable.userid, thirdtable.otherstuff FROM thi # parameters. query = select( - [self.table, self.table2], + [table, table2], and_( - self.table.c.name == 'fred', - self.table.c.id == 10, - self.table2.c.name != 'jack', + table.c.name == 'fred', + table.c.id == 10, + table2.c.name != 'jack', "EXISTS (select yay from foo where boo = lar)" ), - from_obj = [ outerjoin(self.table, self.table2, self.table.c.id == self.table2.c.id) ] + from_obj = [ outerjoin(table, table2, table.c.id == table2.c.id) ] ) self.runtest(query, @@ -286,9 +274,9 @@ myothertable.othername != :myothertable_othername AND EXISTS (select yay from fo def testbindparam(self): #return self.runtest(select( - [self.table, self.table2], - and_(self.table.c.id == self.table2.c.id, - self.table.c.name == bindparam('mytablename'), + [table, table2], + and_(table.c.id == table2.c.id, + table.c.name == bindparam('mytablename'), ) ), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername \ @@ -298,33 +286,45 @@ FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid AND mytable def testinsert(self): # generic insert, will create bind params for all columns - self.runtest(insert(self.table), "INSERT INTO mytable (myid, name, description) VALUES (:myid, :name, :description)") + self.runtest(insert(table), "INSERT INTO mytable (myid, name, description) VALUES (:myid, :name, :description)") # insert with user-supplied bind params for specific columns, # cols provided literally self.runtest( - insert(self.table, {self.table.c.id : bindparam('userid'), self.table.c.name : bindparam('username')}), + insert(table, {table.c.id : bindparam('userid'), table.c.name : bindparam('username')}), "INSERT INTO mytable (myid, name) VALUES (:userid, :username)") # insert with user-supplied bind params for specific columns, cols # provided as strings self.runtest( - insert(self.table, dict(id = 3, name = 'jack')), + insert(table, dict(id = 3, name = 'jack')), "INSERT INTO mytable (myid, name) VALUES (:myid, :name)" ) # insert with a subselect provided #self.runtest( - # insert(self.table, select([self.table2])), + # insert(table, select([table2])), # "" #) def testupdate(self): - self.runtest(update(self.table, self.table.c.id == 7), "UPDATE mytable SET name=:name WHERE mytable.myid = :mytable_myid", params = {self.table.c.name:'fred'}) - self.runtest(update(self.table, self.table.c.id == 7), "UPDATE mytable SET name=:name WHERE mytable.myid = :mytable_myid", params = {'name':'fred'}) - + self.runtest(update(table, table.c.id == 7), "UPDATE mytable SET name=:name WHERE mytable.myid = :mytable_myid", params = {table.c.name:'fred'}) + self.runtest(update(table, table.c.id == 7), "UPDATE mytable SET name=:name WHERE mytable.myid = :mytable_myid", params = {'name':'fred'}) + self.runtest(update(table, values = {table.c.name : table.c.id}), "UPDATE mytable SET name=mytable.myid") + self.runtest(update(table, table.c.id == 12, values = {table.c.name : table.c.id}), "UPDATE mytable SET name=mytable.myid, description=:description WHERE mytable.myid = :mytable_myid", params = {'description':'test'}) + + def testcorrelatedupdate(self): + # test against a straight text subquery + u = update(table, values = {table.c.name : TextClause("select name from mytable where id=mytable.id")}) + self.runtest(u, "UPDATE mytable SET name=(select name from mytable where id=mytable.id)") + + # test against a regular constructed subquery + s = select([table2], table2.c.id == table.c.id) + u = update(table, table.c.name == 'jack', values = {table.c.name : s}) + self.runtest(u, "UPDATE mytable SET name=(SELECT myothertable.otherid, myothertable.othername FROM myothertable WHERE myothertable.otherid = mytable.myid) WHERE mytable.name = :mytable_name") + def testdelete(self): - self.runtest(delete(self.table, self.table.c.id == 7), "DELETE FROM mytable WHERE mytable.myid = :mytable_myid") + self.runtest(delete(table, table.c.id == 7), "DELETE FROM mytable WHERE mytable.myid = :mytable_myid") def runtest(self, clause, result, engine = None, params = None): -- 2.47.2