From: Mike Bayer Date: Sun, 26 Nov 2006 02:36:27 +0000 (+0000) Subject: - made kwargs parsing to Table strict; removed various obsoluete "redefine=True"... X-Git-Tag: rel_0_3_2~38 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=b6b0130646b677e507d2fb461829ed5d72658000;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - made kwargs parsing to Table strict; removed various obsoluete "redefine=True" kw's from the unit tests - documented instance variables in ANSICompiler - fixed [ticket:120], adds "inline_params" set to ANSICompiler which DefaultDialect picks up on when determining defaults. added unittests to query.py - additionally fixed up the behavior of the "values" parameter on _Insert/_Update - more cleanup to sql/Select - more succinct organization of FROM clauses, removed silly _process_from_dict methods and JoinMarker object --- diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 2e0fe6e347..e470a2101c 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -70,20 +70,64 @@ class ANSICompiler(sql.Compiled): actual compilation, as in the case of an INSERT where the actual columns inserted will correspond to the keys present in the parameters.""" sql.Compiled.__init__(self, dialect, statement, parameters, **kwargs) + + # a dictionary of bind parameter keys to _BindParamClause instances. self.binds = {} + + # a dictionary which stores the string representation for every ClauseElement + # processed by this compiler. + self.strings = {} + + # a dictionary which stores the string representation for ClauseElements + # processed by this compiler, which are to be used in the FROM clause + # of a select. items are often placed in "froms" as well as "strings" + # and sometimes with different representations. self.froms = {} + + # slightly hacky. maps FROM clauses to WHERE clauses, and used in select + # generation to modify the WHERE clause of the select. currently a hack + # used by the oracle module. self.wheres = {} - self.strings = {} + + # when the compiler visits a SELECT statement, the clause object is appended + # to this stack. various visit operations will check this stack to determine + # additional choices (TODO: it seems to be all typemap stuff. shouldnt this only + # apply to the topmost-level SELECT statement ?) self.select_stack = [] + + # a dictionary of result-set column names (strings) to TypeEngine instances, + # which will be passed to a ResultProxy and used for resultset-level value conversion self.typemap = {} + + # True if this compiled represents an INSERT self.isinsert = False + + # True if this compiled represents an UPDATE self.isupdate = False + + # default formatting style for bind parameters self.bindtemplate = ":%s" + + # paramstyle from the dialect (comes from DBAPI) self.paramstyle = dialect.paramstyle + + # true if the paramstyle is positional self.positional = dialect.positional + + # a list of the compiled's bind parameter names, used to help + # formulate a positional argument list self.positiontup = [] + + # an ANSIIdentifierPreparer that formats the quoting of identifiers self.preparer = dialect.identifier_preparer + # for UPDATE and INSERT statements, a set of columns whos values are being set + # from a SQL expression (i.e., not one of the bind parameter values). if present, + # default-value logic in the Dialect knows not to fire off column defaults + # and also knows postfetching will be needed to get the values represented by these + # parameters. + self.inline_params = None + def after_compile(self): # this re will search for params like :param # it has a negative lookbehind for an extra ':' so that it doesnt match @@ -295,13 +339,10 @@ class ANSICompiler(sql.Compiled): def visit_select(self, select): # the actual list of columns to print in the SELECT column list. - # its an ordered dictionary to insure that the actual labeled column name - # is unique. inner_columns = util.OrderedDict() self.select_stack.append(select) for c in select._raw_columns: - # TODO: make this polymorphic? if isinstance(c, sql.Select) and c.is_scalar: c.accept_visitor(self) inner_columns[self.get_str(c)] = c @@ -431,7 +472,6 @@ class ANSICompiler(sql.Compiled): self.strings[table] = "" def visit_join(self, join): - # TODO: ppl are going to want RIGHT, FULL OUTER and NATURAL joins. righttext = self.get_from_text(join.right) if join.right._group_parenthesized(): righttext = "(" + righttext + ")" @@ -488,13 +528,15 @@ class ANSICompiler(sql.Compiled): self.isinsert = True colparams = self._get_colparams(insert_stmt, default_params) - def create_param(p): + self.inline_params = util.Set() + def create_param(col, p): if isinstance(p, sql._BindParamClause): self.binds[p.key] = p if p.shortname is not None: self.binds[p.shortname] = p return self.bindparam_string(p.key) else: + self.inline_params.add(col) p.accept_visitor(self) if isinstance(p, sql.ClauseElement) and not isinstance(p, sql.ColumnElement): return "(" + self.get_str(p) + ")" @@ -502,7 +544,7 @@ class ANSICompiler(sql.Compiled): return self.get_str(p) text = ("INSERT INTO " + self.preparer.format_table(insert_stmt.table) + " (" + string.join([self.preparer.format_column(c[0]) for c in colparams], ', ') + ")" + - " VALUES (" + string.join([create_param(c[1]) for c in colparams], ', ') + ")") + " VALUES (" + string.join([create_param(*c) for c in colparams], ', ') + ")") self.strings[insert_stmt] = text @@ -520,19 +562,22 @@ class ANSICompiler(sql.Compiled): self.isupdate = True colparams = self._get_colparams(update_stmt, default_params) - def create_param(p): + + self.inline_params = util.Set() + def create_param(col, p): if isinstance(p, sql._BindParamClause): self.binds[p.key] = p self.binds[p.shortname] = p return self.bindparam_string(p.key) else: p.accept_visitor(self) + self.inline_params.add(col) if isinstance(p, sql.ClauseElement) and not isinstance(p, sql.ColumnElement): return "(" + self.get_str(p) + ")" else: return self.get_str(p) - text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.format_column(c[0]), create_param(c[1])) for c in colparams], ', ') + text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.format_column(c[0]), create_param(*c)) for c in colparams], ', ') if update_stmt.whereclause: text += " WHERE " + self.get_str(update_stmt.whereclause) @@ -541,55 +586,51 @@ class ANSICompiler(sql.Compiled): def _get_colparams(self, stmt, default_params): - """determines the VALUES or SET clause for an INSERT or UPDATE - clause based on the arguments specified to this ANSICompiler object - (i.e., the execute() or compile() method clause object): - - insert(mytable).execute(col1='foo', col2='bar') - mytable.update().execute(col2='foo', col3='bar') - - in the above examples, the insert() and update() methods have no "values" sent to them - at all, so compiling them with no arguments would yield an insert for all table columns, - or an update with no SET clauses. but the parameters sent indicate a set of per-compilation - arguments that result in a differently compiled INSERT or UPDATE object compared to the - original. The "values" parameter to the insert/update is figured as well if present, - but the incoming "parameters" sent here take precedence. + """organize UPDATE/INSERT SET/VALUES parameters into a list of tuples, + each tuple containing the Column and a ClauseElement representing the + value to be set (usually a _BindParamClause, but could also be other + SQL expressions.) + + the list of tuples will determine the columns that are actually rendered + into the SET/VALUES clause of the rendered UPDATE/INSERT statement. It will + also determine how to generate the list/dictionary of bind parameters at + execution time (i.e. get_params()). + + this list takes into account the "values" keyword specified to the statement, + the parameters sent to this Compiled instance, and the default bind parameter + values corresponding to the dialect's behavior for otherwise unspecified + primary key columns. """ - # case one: no parameters in the statement, no parameters in the - # compiled params - just return binds for all the table columns + # no parameters in the statement, no parameters in the + # compiled params - return binds for all columns if self.parameters is None and stmt.parameters is None: return [(c, sql.bindparam(c.key, type=c.type)) for c in stmt.table.columns] + def to_col(key): + if not isinstance(key, sql._ColumnClause): + return stmt.table.columns.get(str(key), key) + else: + return key + # if we have statement parameters - set defaults in the # compiled params if self.parameters is None: parameters = {} else: - parameters = self.parameters.copy() + parameters = dict([(to_col(k), v) for k, v in self.parameters.iteritems()]) if stmt.parameters is not None: for k, v in stmt.parameters.iteritems(): - parameters.setdefault(k, v) + parameters.setdefault(to_col(k), v) for k, v in default_params.iteritems(): - parameters.setdefault(k, v) - - # now go thru compiled params, get the Column object for each key - d = {} - for key, value in parameters.iteritems(): - if isinstance(key, sql._ColumnClause): - d[key] = value - else: - try: - d[stmt.table.columns[str(key)]] = value - except KeyError: - pass + parameters.setdefault(to_col(k), v) # create a list of column assignment clauses as tuples values = [] for c in stmt.table.columns: - if d.has_key(c): - value = d[c] + if parameters.has_key(c): + value = parameters[c] if sql._is_literal(value): value = sql.bindparam(c.key, value, type=c.type) values.append((c, value)) diff --git a/lib/sqlalchemy/databases/__init__.py b/lib/sqlalchemy/databases/__init__.py index f009034c75..45d6e2cbc7 100644 --- a/lib/sqlalchemy/databases/__init__.py +++ b/lib/sqlalchemy/databases/__init__.py @@ -5,4 +5,4 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php -__all__ = ['oracle', 'postgres', 'sqlite', 'mysql', 'mssql'] +__all__ = ['oracle', 'postgres', 'sqlite', 'mysql', 'mssql', 'firebird'] diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 4af539e784..d5cb0cc9f7 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -175,6 +175,7 @@ class DefaultExecutionContext(base.ExecutionContext): visit_update methods that add the appropriate column clauses to the statement when its being compiled, so that these parameters can be bound to the statement.""" if compiled is None: return + if getattr(compiled, "isinsert", False): if isinstance(parameters, list): plist = parameters @@ -185,8 +186,20 @@ class DefaultExecutionContext(base.ExecutionContext): for param in plist: last_inserted_ids = [] need_lastrowid=False + # check the "default" status of each column in the table for c in compiled.statement.table.c: - if not param.has_key(c.key) or param[c.key] is None: + # check if it will be populated by a SQL clause - we'll need that + # after execution. + if c in compiled.inline_params: + self._lastrow_has_defaults = True + if c.primary_key: + need_lastrowid = True + # check if its not present at all. see if theres a default + # and fire it off, and add to bind parameters. if + # its a pk, add the value to our last_inserted_ids list, + # or, if its a SQL-side default, dont do any of that, but we'll need + # the SQL-generated value after execution. + elif not param.has_key(c.key) or param[c.key] is None: if isinstance(c.default, schema.PassiveDefault): self._lastrow_has_defaults = True newid = drunner.get_column_default(c) @@ -196,13 +209,14 @@ class DefaultExecutionContext(base.ExecutionContext): last_inserted_ids.append(param[c.key]) elif c.primary_key: need_lastrowid = True + # its an explicitly passed pk value - add it to + # our last_inserted_ids list. elif c.primary_key: last_inserted_ids.append(param[c.key]) if need_lastrowid: self._last_inserted_ids = None else: self._last_inserted_ids = last_inserted_ids - #print "LAST INSERTED PARAMS", param self._last_inserted_params = param elif getattr(compiled, 'isupdate', False): if isinstance(parameters, list): @@ -212,8 +226,15 @@ class DefaultExecutionContext(base.ExecutionContext): drunner = self.dialect.defaultrunner(engine, proxy) self._lastrow_has_defaults = False for param in plist: + # check the "onupdate" status of each column in the table for c in compiled.statement.table.c: - if c.onupdate is not None and (not param.has_key(c.key) or param[c.key] is None): + # it will be populated by a SQL clause - we'll need that + # after execution. + if c in compiled.inline_params: + pass + # its not in the bind parameters, and theres an "onupdate" defined for the column; + # execute it and add to bind params + elif c.onupdate is not None and (not param.has_key(c.key) or param[c.key] is None): value = drunner.get_column_onupdate(c) if value is not None: param[c.key] = value diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index d9a7684e72..8ebeaea27d 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -14,7 +14,7 @@ structure with its own clause-specific objects as well as the visitor interface, the schema package "plugs in" to the SQL package. """ -from sqlalchemy import sql, types, exceptions,util +from sqlalchemy import sql, types, exceptions,util, databases import sqlalchemy import copy, re, string @@ -125,7 +125,7 @@ class _TableSingleton(type): table = metadata.tables[key] if len(args): if not useexisting: - raise exceptions.ArgumentError("Table '%s.%s' is already defined for this MetaData instance." % (schema, name)) + raise exceptions.ArgumentError("Table '%s' is already defined for this MetaData instance." % key) return table except KeyError: if mustexist: @@ -183,8 +183,7 @@ class Table(SchemaItem, sql.TableClause): else an exception is raised. useexisting=False : indicates that if this Table was already defined elsewhere in the application, disregard - the rest of the constructor arguments. If this flag and the "redefine" flag are not set, constructing - the same table twice will result in an exception. + the rest of the constructor arguments. owner=None : optional owning user of this table. useful for databases such as Oracle to aid in table reflection. @@ -207,8 +206,8 @@ class Table(SchemaItem, sql.TableClause): self.indexes = util.Set() self.constraints = util.Set() self.primary_key = PrimaryKeyConstraint() - self.quote = kwargs.get('quote', False) - self.quote_schema = kwargs.get('quote_schema', False) + self.quote = kwargs.pop('quote', False) + self.quote_schema = kwargs.pop('quote_schema', False) if self.schema is not None: self.fullname = "%s.%s" % (self.schema, self.name) else: @@ -217,8 +216,13 @@ class Table(SchemaItem, sql.TableClause): self._set_casing_strategy(name, kwargs) self._set_casing_strategy(self.schema or '', kwargs, keyname='case_sensitive_schema') + + if len([k for k in kwargs if not re.match(r'^(?:%s)_' % '|'.join(databases.__all__), k)]): + raise TypeError("Invalid argument(s) for Table: %s" % repr(kwargs.keys())) + + # store extra kwargs, which should only contain db-specific options self.kwargs = kwargs - + def _get_case_sensitive_schema(self): try: return getattr(self, '_case_sensitive_schema') diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 8605d5c0c5..b3d61dc7e8 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -423,14 +423,9 @@ class ClauseElement(object): FROM list of a query, when this ClauseElement is placed in the column clause of a Select statement.""" raise NotImplementedError(repr(self)) - def _process_from_dict(self, data, asfrom): - """given a dictionary attached to a Select object, places the appropriate - FROM objects in the dictionary corresponding to this ClauseElement, - and possibly removes or modifies others.""" - for f in self._get_from_objects(): - data.setdefault(f, f) - if asfrom: - data[self] = self + def _hide_froms(self): + """return a list of FROM clause elements which this ClauseElement replaces.""" + return [] def compare(self, other): """compare this ClauseElement to the given ClauseElement. @@ -832,8 +827,9 @@ class _BindParamClause(ClauseElement, _CompareMixin): return isinstance(other, _BindParamClause) and other.type.__class__ == self.type.__class__ def _make_proxy(self, selectable, name = None): return self -# return self.obj._make_proxy(selectable, name=self.name) - + def __repr__(self): + return "_BindParamClause(%s, %s, type=%s)" % (repr(self.key), repr(self.value), repr(self.type)) + class _TypeClause(ClauseElement): """handles a type keyword in a SQL statement. used by the Case statement.""" def __init__(self, type): @@ -966,11 +962,6 @@ class _CalculatedClause(ClauseList, ColumnElement): self._engine = kwargs.get('engine', None) ClauseList.__init__(self, *clauses) key = property(lambda self:self.name or "_calc_") - def _process_from_dict(self, data, asfrom): - super(_CalculatedClause, self)._process_from_dict(data, asfrom) - # this helps a Select object get the engine from us - if asfrom: - data.setdefault(self, self) def copy_container(self): clauses = [clause.copy_container() for clause in self.clauses] return _CalculatedClause(type=self.type, engine=self._engine, *clauses) @@ -1156,25 +1147,13 @@ class Join(FromClause): engine = property(lambda s:s.left.engine or s.right.engine) - class JoinMarker(FromClause): - def __init__(self, join): - FromClause.__init__(self) - self.join = join - def _exportable_columns(self): - return [] - def alias(self, name=None): """creates a Select out of this Join clause and returns an Alias of it. The Select is not correlating.""" return self.select(use_labels=True, correlate=False).alias(name) - def _process_from_dict(self, data, asfrom): - for f in self.onclause._get_from_objects(): - data[f] = f - for f in self.left._get_from_objects() + self.right._get_from_objects(): - # mark the object as a "blank" "from" that wont be printed - data[f] = Join.JoinMarker(self) - # a JOIN always impacts the final FROM list of a select statement - data[self] = self - + + def _hide_froms(self): + return self.left._get_from_objects() + self.right._get_from_objects() + def _get_from_objects(self): return [self] + self.onclause._get_from_objects() + self.left._get_from_objects() + self.right._get_from_objects() @@ -1323,11 +1302,6 @@ class TableClause(FromClause): raise NotImplementedError() def _group_parenthesized(self): return False - def _process_from_dict(self, data, asfrom): - for f in self._get_from_objects(): - data.setdefault(f, f) - if asfrom: - data[self] = self def count(self, whereclause=None, **params): if len(self.primary_key): col = list(self.primary_key)[0] @@ -1443,7 +1417,8 @@ class Select(_SelectBaseMixin, FromClause): the ability to execute itself and return a result set.""" def __init__(self, columns=None, whereclause = None, from_obj = [], order_by = None, group_by=None, having=None, use_labels = False, distinct=False, for_update=False, engine=None, limit=None, offset=None, scalar=False, correlate=True): _SelectBaseMixin.__init__(self) - self.__froms = util.OrderedDict() + self.__froms = util.OrderedSet() + self.__hide_froms = util.Set([self]) self.use_labels = use_labels self.whereclause = None self.having = None @@ -1526,7 +1501,7 @@ class Select(_SelectBaseMixin, FromClause): # visit the FROM objects of the column looking for more Selects for f in column._get_from_objects(): f.accept_visitor(self.__correlator) - column._process_from_dict(self.__froms, False) + self._process_froms(column, False) def _exportable_columns(self): return self._raw_columns @@ -1535,6 +1510,15 @@ class Select(_SelectBaseMixin, FromClause): return column._make_proxy(self, name=column._label) else: return column._make_proxy(self, name=column.name) + + def _process_froms(self, elem, asfrom): + for f in elem._get_from_objects(): + self.__froms.add(f) + if asfrom: + self.__froms.add(elem) + for f in elem._hide_froms(): + self.__hide_froms.add(f) + def append_whereclause(self, whereclause): self._append_condition('whereclause', whereclause) def append_having(self, having): @@ -1543,7 +1527,7 @@ class Select(_SelectBaseMixin, FromClause): if type(condition) == str: condition = _TextClause(condition) condition.accept_visitor(self.__wherecorrelator) - condition._process_from_dict(self.__froms, False) + self._process_froms(condition, False) if getattr(self, attribute) is not None: setattr(self, attribute, and_(getattr(self, attribute), condition)) else: @@ -1560,9 +1544,10 @@ class Select(_SelectBaseMixin, FromClause): if type(fromclause) == str: fromclause = _TextClause(fromclause) fromclause.accept_visitor(self.__correlator) - fromclause._process_from_dict(self.__froms, True) + self._process_froms(fromclause, True) + def _locate_oid_column(self): - for f in self.__froms.values(): + for f in self.__froms: if f is self: # we might be in our own _froms list if a column with us as the parent is attached, # which includes textual columns. @@ -1572,16 +1557,11 @@ class Select(_SelectBaseMixin, FromClause): return oid else: return None - def _get_froms(self): - return [f for f in self.__froms.values() if f is not self and (f not in self.__correlated)] - froms = property(lambda s: s._get_froms(), doc="""a list containing all elements of the FROM clause""") + + froms = property(lambda self: self.__froms.difference(self.__hide_froms).difference(self.__correlated), doc="""a collection containing all elements of the FROM clause""") def accept_visitor(self, visitor): - # TODO: add contextual visit_ methods - # visit_select_whereclause, visit_select_froms, visit_select_orderby, etc. - # which will allow the compiler to set contextual flags before traversing - # into each thing. - for f in self._get_froms(): + for f in self.froms: f.accept_visitor(visitor) if self.whereclause is not None: self.whereclause.accept_visitor(visitor) @@ -1601,7 +1581,7 @@ class Select(_SelectBaseMixin, FromClause): if self._engine is not None: return self._engine - for f in self.__froms.values(): + for f in self.__froms: if f is self: continue e = f.engine diff --git a/test/ext/activemapper.py b/test/ext/activemapper.py index f87cbb46ed..75bc34f502 100644 --- a/test/ext/activemapper.py +++ b/test/ext/activemapper.py @@ -11,6 +11,7 @@ import sqlalchemy.ext.activemapper as activemapper class testcase(testbase.PersistTest): def setUpAll(self): sqlalchemy.clear_mappers() + objectstore.clear() global Person, Preferences, Address class Person(ActiveMapper): @@ -260,6 +261,8 @@ class testcase(testbase.PersistTest): class testmanytomany(testbase.PersistTest): def setUpAll(self): + sqlalchemy.clear_mappers() + objectstore.clear() global secondarytable, foo, baz secondarytable = Table("secondarytable", activemapper.metadata, @@ -315,6 +318,8 @@ class testmanytomany(testbase.PersistTest): class testselfreferential(testbase.PersistTest): def setUpAll(self): + sqlalchemy.clear_mappers() + objectstore.clear() global TreeNode class TreeNode(activemapper.ActiveMapper): class mapping: diff --git a/test/orm/unitofwork.py b/test/orm/unitofwork.py index e6a1060aa1..0034b31b18 100644 --- a/test/orm/unitofwork.py +++ b/test/orm/unitofwork.py @@ -1341,32 +1341,27 @@ class SaveTest2(UnitOfWorkTest): def setUp(self): ctx.current.clear() clear_mappers() - self.users = Table('users', db, + global meta, users, addresses + meta = BoundMetaData(db) + users = Table('users', meta, Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True), Column('user_name', String(20)), - redefine=True ) - self.addresses = Table('email_addresses', db, + addresses = Table('email_addresses', meta, Column('address_id', Integer, Sequence('address_id_seq', optional=True), primary_key = True), - Column('rel_user_id', Integer, ForeignKey(self.users.c.user_id)), + Column('rel_user_id', Integer, ForeignKey(users.c.user_id)), Column('email_address', String(20)), - redefine=True ) - x = sql.join(self.users, self.addresses) -# raise repr(self.users) + repr(self.users.primary_key) -# raise repr(self.addresses) + repr(self.addresses.foreign_keys) - self.users.create() - self.addresses.create() + meta.create_all() def tearDown(self): - self.addresses.drop() - self.users.drop() + meta.drop_all() UnitOfWorkTest.tearDown(self) def testbackwardsnonmatch(self): - m = mapper(Address, self.addresses, properties = dict( - user = relation(mapper(User, self.users), lazy = True, uselist = False) + m = mapper(Address, addresses, properties = dict( + user = relation(mapper(User, users), lazy = True, uselist = False) )) data = [ {'user_name' : 'thesub' , 'email_address' : 'bar@foo.com'}, diff --git a/test/sql/query.py b/test/sql/query.py index d88b2bf83f..a19b8cf25f 100644 --- a/test/sql/query.py +++ b/test/sql/query.py @@ -5,18 +5,17 @@ import unittest, sys, datetime import sqlalchemy.databases.sqlite as sqllite import tables -db = testbase.db from sqlalchemy import * from sqlalchemy.engine import ResultProxy, RowProxy class QueryTest(PersistTest): def setUpAll(self): - global users - users = Table('query_users', db, + global users, metadata + metadata = BoundMetaData(testbase.db) + users = Table('query_users', metadata, Column('user_id', INT, primary_key = True), Column('user_name', VARCHAR(20)), - redefine = True ) users.create() @@ -71,16 +70,16 @@ class QueryTest(PersistTest): default_metadata.drop_all() default_metadata.clear() + @testbase.supported('postgres') def testpassiveoverride(self): """primarily for postgres, tests that when we get a primary key column back from reflecting a table which has a default value on it, we pre-execute that PassiveDefault upon insert, even though PassiveDefault says "let the database execute this", because in postgres we must have all the primary key values in memory before insert; otherwise we cant locate the just inserted row.""" - if db.engine.name != 'postgres': - return try: - db.execute(""" + meta = BoundMetaData(testbase.db) + testbase.db.execute(""" CREATE TABLE speedy_users ( speedy_user_id SERIAL PRIMARY KEY, @@ -90,19 +89,17 @@ class QueryTest(PersistTest): ); """, None) - t = Table("speedy_users", db, autoload=True) + t = Table("speedy_users", meta, autoload=True) t.insert().execute(user_name='user', user_password='lala') l = t.select().execute().fetchall() - print l self.assert_(l == [(1, 'user', 'lala')]) finally: - db.execute("drop table speedy_users", None) + testbase.db.execute("drop table speedy_users", None) + @testbase.supported('postgres') def testschema(self): - if not db.engine.__module__.endswith('postgres'): - return - - test_table = Table('my_table', db, + meta1 = BoundMetaData(testbase.db) + test_table = Table('my_table', meta1, Column('id', Integer, primary_key=True), Column('data', String(20), nullable=False), schema='alt_schema' @@ -112,9 +109,8 @@ class QueryTest(PersistTest): # plain insert test_table.insert().execute(data='test') - # try with a PassiveDefault - test_table.deregister() - test_table = Table('my_table', db, autoload=True, redefine=True, schema='alt_schema') + meta2 = BoundMetaData(testbase.db) + test_table = Table('my_table', meta2, autoload=True, schema='alt_schema') test_table.insert().execute(data='test') finally: @@ -187,10 +183,10 @@ class QueryTest(PersistTest): r = self.users.select().execute().fetchone() self.assertEqual(len(r), 2) r.close() - r = db.execute('select user_name, user_id from query_users', {}).fetchone() + r = testbase.db.execute('select user_name, user_id from query_users', {}).fetchone() self.assertEqual(len(r), 2) r.close() - r = db.execute('select user_name from query_users', {}).fetchone() + r = testbase.db.execute('select user_name from query_users', {}).fetchone() self.assertEqual(len(r), 1) r.close() @@ -200,6 +196,56 @@ class QueryTest(PersistTest): z = testbase.db.func.current_date().scalar() assert x == y == z + def test_update_functions(self): + """test sending functions and SQL expressions to the VALUES and SET clauses of INSERT/UPDATE instances, + and that column-level defaults get overridden""" + meta = BoundMetaData(testbase.db) + t = Table('t1', meta, + Column('id', Integer, primary_key=True), + Column('value', Integer) + ) + t2 = Table('t2', meta, + Column('id', Integer, primary_key=True), + Column('value', Integer, default="7"), + Column('stuff', String(20), onupdate="thisisstuff") + ) + meta.create_all() + try: + t.insert().execute(value=func.length("one")) + assert t.select().execute().fetchone()['value'] == 3 + t.update().execute(value=func.length("asfda")) + assert t.select().execute().fetchone()['value'] == 5 + + r = t.insert(values=dict(value=func.length("sfsaafsda"))).execute() + id = r.last_inserted_ids()[0] + assert t.select(t.c.id==id).execute().fetchone()['value'] == 9 + t.update(values={t.c.value:func.length("asdf")}).execute() + assert t.select().execute().fetchone()['value'] == 4 + + t2.insert().execute() + t2.insert().execute(value=func.length("one")) + t2.insert().execute(value=func.length("asfda") + -19, stuff="hi") + + assert select([t2.c.value, t2.c.stuff]).execute().fetchall() == [(7,None), (3,None), (-14,"hi")] + + t2.update().execute(value=func.length("asdsafasd"), stuff="some stuff") + assert select([t2.c.value, t2.c.stuff]).execute().fetchall() == [(9,"some stuff"), (9,"some stuff"), (9,"some stuff")] + + t2.delete().execute() + + t2.insert(values=dict(value=func.length("one") + 8)).execute() + assert t2.select().execute().fetchone()['value'] == 11 + + t2.update(values=dict(value=func.length("asfda"))).execute() + assert select([t2.c.value, t2.c.stuff]).execute().fetchone() == (5, "thisisstuff") + + t2.update(values={t2.c.value:func.length("asfdaasdf"), t2.c.stuff:"foo"}).execute() + print "HI", select([t2.c.value, t2.c.stuff]).execute().fetchone() + assert select([t2.c.value, t2.c.stuff]).execute().fetchone() == (9, "foo") + + finally: + meta.drop_all() + @testbase.supported('postgres') def test_functions_with_cols(self): x = testbase.db.func.current_date().execute().scalar() @@ -226,7 +272,7 @@ class QueryTest(PersistTest): def test_column_order_with_text_query(self): # should return values in query order self.users.insert().execute(user_id=1, user_name='foo') - r = db.execute('select user_name, user_id from query_users', {}).fetchone() + r = testbase.db.execute('select user_name, user_id from query_users', {}).fetchone() self.assertEqual(r[0], 'foo') self.assertEqual(r[1], 1) self.assertEqual(r.keys(), ['user_name', 'user_id']) @@ -234,14 +280,14 @@ class QueryTest(PersistTest): @testbase.unsupported('oracle', 'firebird') def test_column_accessor_shadow(self): - shadowed = Table('test_shadowed', db, + meta = BoundMetaData(testbase.db) + shadowed = Table('test_shadowed', meta, Column('shadow_id', INT, primary_key = True), Column('shadow_name', VARCHAR(20)), Column('parent', VARCHAR(20)), Column('row', VARCHAR(40)), Column('__parent', VARCHAR(20)), Column('__row', VARCHAR(20)), - redefine = True ) shadowed.create() try: diff --git a/test/sql/selectable.py b/test/sql/selectable.py index 0c2aa1b56d..cd434a1845 100755 --- a/test/sql/selectable.py +++ b/test/sql/selectable.py @@ -16,7 +16,7 @@ table = Table('table1', db, Column('col2', String(20)), Column('col3', Integer), Column('colx', Integer), - redefine=True + ) table2 = Table('table2', db, @@ -24,7 +24,6 @@ table2 = Table('table2', db, Column('col2', Integer, ForeignKey('table1.col1')), Column('col3', String(20)), Column('coly', Integer), - redefine=True ) class SelectableTest(testbase.AssertMixin): diff --git a/test/sql/testtypes.py b/test/sql/testtypes.py index 1d70558114..2700ec6c79 100644 --- a/test/sql/testtypes.py +++ b/test/sql/testtypes.py @@ -265,7 +265,7 @@ class DateTest(AssertMixin): collist = [Column('user_id', INT, primary_key = True), Column('user_name', VARCHAR(20)), Column('user_datetime', DateTime(timezone=False)), Column('user_date', Date), Column('user_time', Time)] - users_with_date = Table('query_users_with_date', db, redefine = True, *collist) + users_with_date = Table('query_users_with_date', db, *collist) users_with_date.create() insert_dicts = [dict(zip(fnames, d)) for d in insert_data]