]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- made kwargs parsing to Table strict; removed various obsoluete "redefine=True"...
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 26 Nov 2006 02:36:27 +0000 (02:36 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 26 Nov 2006 02:36:27 +0000 (02:36 +0000)
- documented instance variables in ANSICompiler
- fixed [ticket:120], adds "inline_params" set to ANSICompiler which DefaultDialect picks up on when
determining defaults.  added unittests to query.py
- additionally fixed up the behavior of the "values" parameter on _Insert/_Update
- more cleanup to sql/Select - more succinct organization of FROM clauses, removed silly _process_from_dict
methods and JoinMarker object

lib/sqlalchemy/ansisql.py
lib/sqlalchemy/databases/__init__.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql.py
test/ext/activemapper.py
test/orm/unitofwork.py
test/sql/query.py
test/sql/selectable.py
test/sql/testtypes.py

index 2e0fe6e34739a336a131e281ef049e571c9bfc2b..e470a2101c55282c1f539640f7ad4646ac1e344c 100644 (file)
@@ -70,20 +70,64 @@ class ANSICompiler(sql.Compiled):
         actual compilation, as in the case of an INSERT where the actual columns
         inserted will correspond to the keys present in the parameters."""
         sql.Compiled.__init__(self, dialect, statement, parameters, **kwargs)
+        
+        # a dictionary of bind parameter keys to _BindParamClause instances.
         self.binds = {}
+
+        # a dictionary which stores the string representation for every ClauseElement
+        # processed by this compiler.
+        self.strings = {}
+        
+        # a dictionary which stores the string representation for ClauseElements
+        # processed by this compiler, which are to be used in the FROM clause
+        # of a select.  items are often placed in "froms" as well as "strings"
+        # and sometimes with different representations.
         self.froms = {}
+        
+        # slightly hacky.  maps FROM clauses to WHERE clauses, and used in select 
+        # generation to modify the WHERE clause of the select.  currently a hack
+        # used by the oracle module.
         self.wheres = {}
-        self.strings = {}
+        
+        # when the compiler visits a SELECT statement, the clause object is appended
+        # to this stack.  various visit operations will check this stack to determine
+        # additional choices (TODO: it seems to be all typemap stuff.  shouldnt this only
+        # apply to the topmost-level SELECT statement ?)
         self.select_stack = []
+        
+        # a dictionary of result-set column names (strings) to TypeEngine instances,
+        # which will be passed to a ResultProxy and used for resultset-level value conversion
         self.typemap = {}
+        
+        # True if this compiled represents an INSERT
         self.isinsert = False
+        
+        # True if this compiled represents an UPDATE
         self.isupdate = False
+        
+        # default formatting style for bind parameters
         self.bindtemplate = ":%s"
+        
+        # paramstyle from the dialect (comes from DBAPI)
         self.paramstyle = dialect.paramstyle
+        
+        # true if the paramstyle is positional
         self.positional = dialect.positional
+        
+        # a list of the compiled's bind parameter names, used to help
+        # formulate a positional argument list
         self.positiontup = []
+        
+        # an ANSIIdentifierPreparer that formats the quoting of identifiers
         self.preparer = dialect.identifier_preparer
         
+        # for UPDATE and INSERT statements, a set of columns whos values are being set
+        # from a SQL expression (i.e., not one of the bind parameter values).  if present,
+        # default-value logic in the Dialect knows not to fire off column defaults
+        # and also knows postfetching will be needed to get the values represented by these
+        # parameters.
+        self.inline_params = None
+        
     def after_compile(self):
         # this re will search for params like :param
         # it has a negative lookbehind for an extra ':' so that it doesnt match
@@ -295,13 +339,10 @@ class ANSICompiler(sql.Compiled):
     def visit_select(self, select):
         
         # the actual list of columns to print in the SELECT column list.
-        # its an ordered dictionary to insure that the actual labeled column name
-        # is unique.
         inner_columns = util.OrderedDict()
 
         self.select_stack.append(select)
         for c in select._raw_columns:
-            # TODO: make this polymorphic?
             if isinstance(c, sql.Select) and c.is_scalar:
                 c.accept_visitor(self)
                 inner_columns[self.get_str(c)] = c
@@ -431,7 +472,6 @@ class ANSICompiler(sql.Compiled):
         self.strings[table] = ""
 
     def visit_join(self, join):
-        # TODO: ppl are going to want RIGHT, FULL OUTER and NATURAL joins.
         righttext = self.get_from_text(join.right)
         if join.right._group_parenthesized():
             righttext = "(" + righttext + ")"
@@ -488,13 +528,15 @@ class ANSICompiler(sql.Compiled):
         self.isinsert = True
         colparams = self._get_colparams(insert_stmt, default_params)
 
-        def create_param(p):
+        self.inline_params = util.Set()
+        def create_param(col, p):
             if isinstance(p, sql._BindParamClause):
                 self.binds[p.key] = p
                 if p.shortname is not None:
                     self.binds[p.shortname] = p
                 return self.bindparam_string(p.key)
             else:
+                self.inline_params.add(col)
                 p.accept_visitor(self)
                 if isinstance(p, sql.ClauseElement) and not isinstance(p, sql.ColumnElement):
                     return "(" + self.get_str(p) + ")"
@@ -502,7 +544,7 @@ class ANSICompiler(sql.Compiled):
                     return self.get_str(p)
 
         text = ("INSERT INTO " + self.preparer.format_table(insert_stmt.table) + " (" + string.join([self.preparer.format_column(c[0]) for c in colparams], ', ') + ")" +
-         " VALUES (" + string.join([create_param(c[1]) for c in colparams], ', ') + ")")
+         " VALUES (" + string.join([create_param(*c) for c in colparams], ', ') + ")")
 
         self.strings[insert_stmt] = text
 
@@ -520,19 +562,22 @@ class ANSICompiler(sql.Compiled):
 
         self.isupdate = True
         colparams = self._get_colparams(update_stmt, default_params)
-        def create_param(p):
+
+        self.inline_params = util.Set()
+        def create_param(col, p):
             if isinstance(p, sql._BindParamClause):
                 self.binds[p.key] = p
                 self.binds[p.shortname] = p
                 return self.bindparam_string(p.key)
             else:
                 p.accept_visitor(self)
+                self.inline_params.add(col)
                 if isinstance(p, sql.ClauseElement) and not isinstance(p, sql.ColumnElement):
                     return "(" + self.get_str(p) + ")"
                 else:
                     return self.get_str(p)
                 
-        text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.format_column(c[0]), create_param(c[1])) for c in colparams], ', ')
+        text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.format_column(c[0]), create_param(*c)) for c in colparams], ', ')
         
         if update_stmt.whereclause:
             text += " WHERE " + self.get_str(update_stmt.whereclause)
@@ -541,55 +586,51 @@ class ANSICompiler(sql.Compiled):
 
 
     def _get_colparams(self, stmt, default_params):
-        """determines the VALUES or SET clause for an INSERT or UPDATE
-        clause based on the arguments specified to this ANSICompiler object
-        (i.e., the execute() or compile() method clause object):
-
-        insert(mytable).execute(col1='foo', col2='bar')
-        mytable.update().execute(col2='foo', col3='bar')
-
-        in the above examples, the insert() and update() methods have no "values" sent to them
-        at all, so compiling them with no arguments would yield an insert for all table columns,
-        or an update with no SET clauses.  but the parameters sent indicate a set of per-compilation
-        arguments that result in a differently compiled INSERT or UPDATE object compared to the
-        original.  The "values" parameter to the insert/update is figured as well if present,
-        but the incoming "parameters" sent here take precedence.
+        """organize UPDATE/INSERT SET/VALUES parameters into a list of tuples, 
+        each tuple containing the Column and a ClauseElement representing the
+        value to be set (usually a _BindParamClause, but could also be other
+        SQL expressions.)
+
+        the list of tuples will determine the columns that are actually rendered
+        into the SET/VALUES clause of the rendered UPDATE/INSERT statement.  It will
+        also determine how to generate the list/dictionary of bind parameters at 
+        execution time (i.e. get_params()).
+        
+        this list takes into account the "values" keyword specified to the statement,
+        the parameters sent to this Compiled instance, and the default bind parameter
+        values corresponding to the dialect's behavior for otherwise unspecified 
+        primary key columns.
         """
-        # case one: no parameters in the statement, no parameters in the 
-        # compiled params - just return binds for all the table columns
+        # no parameters in the statement, no parameters in the 
+        # compiled params - return binds for all columns
         if self.parameters is None and stmt.parameters is None:
             return [(c, sql.bindparam(c.key, type=c.type)) for c in stmt.table.columns]
 
+        def to_col(key):
+            if not isinstance(key, sql._ColumnClause):
+                return stmt.table.columns.get(str(key), key)
+            else:
+                return key
+                
         # if we have statement parameters - set defaults in the 
         # compiled params
         if self.parameters is None:
             parameters = {}
         else:
-            parameters = self.parameters.copy()
+            parameters = dict([(to_col(k), v) for k, v in self.parameters.iteritems()])
 
         if stmt.parameters is not None:
             for k, v in stmt.parameters.iteritems():
-                parameters.setdefault(k, v)
+                parameters.setdefault(to_col(k), v)
 
         for k, v in default_params.iteritems():
-            parameters.setdefault(k, v)
-            
-        # now go thru compiled params, get the Column object for each key
-        d = {}
-        for key, value in parameters.iteritems():
-            if isinstance(key, sql._ColumnClause):
-                d[key] = value
-            else:
-                try:
-                    d[stmt.table.columns[str(key)]] = value
-                except KeyError:
-                    pass
+            parameters.setdefault(to_col(k), v)
 
         # create a list of column assignment clauses as tuples
         values = []
         for c in stmt.table.columns:
-            if d.has_key(c):
-                value = d[c]
+            if parameters.has_key(c):
+                value = parameters[c]
                 if sql._is_literal(value):
                     value = sql.bindparam(c.key, value, type=c.type)
                 values.append((c, value))
index f009034c751b6f5337243da1c042529f640b5800..45d6e2cbc736189eeaccad0afe0fe67ce38b4ddb 100644 (file)
@@ -5,4 +5,4 @@
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
 
-__all__ = ['oracle', 'postgres', 'sqlite', 'mysql', 'mssql']
+__all__ = ['oracle', 'postgres', 'sqlite', 'mysql', 'mssql', 'firebird']
index 4af539e784021a44352d0b87bfc10d4b5d246e7d..d5cb0cc9f764a4bd3cd2914b4844756105b98cf5 100644 (file)
@@ -175,6 +175,7 @@ class DefaultExecutionContext(base.ExecutionContext):
         visit_update methods that add the appropriate column clauses to the statement when its 
         being compiled, so that these parameters can be bound to the statement."""
         if compiled is None: return
+        
         if getattr(compiled, "isinsert", False):
             if isinstance(parameters, list):
                 plist = parameters
@@ -185,8 +186,20 @@ class DefaultExecutionContext(base.ExecutionContext):
             for param in plist:
                 last_inserted_ids = []
                 need_lastrowid=False
+                # check the "default" status of each column in the table
                 for c in compiled.statement.table.c:
-                    if not param.has_key(c.key) or param[c.key] is None:
+                    # check if it will be populated by a SQL clause - we'll need that
+                    # after execution.
+                    if c in compiled.inline_params:
+                        self._lastrow_has_defaults = True
+                        if c.primary_key:
+                            need_lastrowid = True
+                    # check if its not present at all.  see if theres a default
+                    # and fire it off, and add to bind parameters.  if 
+                    # its a pk, add the value to our last_inserted_ids list,
+                    # or, if its a SQL-side default, dont do any of that, but we'll need 
+                    # the SQL-generated value after execution.
+                    elif not param.has_key(c.key) or param[c.key] is None:
                         if isinstance(c.default, schema.PassiveDefault):
                             self._lastrow_has_defaults = True
                         newid = drunner.get_column_default(c)
@@ -196,13 +209,14 @@ class DefaultExecutionContext(base.ExecutionContext):
                                 last_inserted_ids.append(param[c.key])
                         elif c.primary_key:
                             need_lastrowid = True
+                    # its an explicitly passed pk value - add it to 
+                    # our last_inserted_ids list.
                     elif c.primary_key:
                         last_inserted_ids.append(param[c.key])
                 if need_lastrowid:
                     self._last_inserted_ids = None
                 else:
                     self._last_inserted_ids = last_inserted_ids
-                #print "LAST INSERTED PARAMS", param
                 self._last_inserted_params = param
         elif getattr(compiled, 'isupdate', False):
             if isinstance(parameters, list):
@@ -212,8 +226,15 @@ class DefaultExecutionContext(base.ExecutionContext):
             drunner = self.dialect.defaultrunner(engine, proxy)
             self._lastrow_has_defaults = False
             for param in plist:
+                # check the "onupdate" status of each column in the table 
                 for c in compiled.statement.table.c:
-                    if c.onupdate is not None and (not param.has_key(c.key) or param[c.key] is None):
+                    # it will be populated by a SQL clause - we'll need that
+                    # after execution.
+                    if c in compiled.inline_params:
+                        pass
+                    # its not in the bind parameters, and theres an "onupdate" defined for the column;
+                    # execute it and add to bind params
+                    elif c.onupdate is not None and (not param.has_key(c.key) or param[c.key] is None):
                         value = drunner.get_column_onupdate(c)
                         if value is not None:
                             param[c.key] = value
index d9a7684e72fc84d6d68959e2cfb3adc858ddeb32..8ebeaea27ddb7e520d8901754c814a53c5031cd8 100644 (file)
@@ -14,7 +14,7 @@ structure with its own clause-specific objects as well as the visitor interface,
 the schema package "plugs in" to the SQL package.
 
 """
-from sqlalchemy import sql, types, exceptions,util
+from sqlalchemy import sql, types, exceptions,util, databases
 import sqlalchemy
 import copy, re, string
 
@@ -125,7 +125,7 @@ class _TableSingleton(type):
             table = metadata.tables[key]
             if len(args):
                 if not useexisting:
-                    raise exceptions.ArgumentError("Table '%s.%s' is already defined for this MetaData instance." % (schema, name))
+                    raise exceptions.ArgumentError("Table '%s' is already defined for this MetaData instance." % key)
             return table
         except KeyError:
             if mustexist:
@@ -183,8 +183,7 @@ class Table(SchemaItem, sql.TableClause):
         else an exception is raised.
         
         useexisting=False : indicates that if this Table was already defined elsewhere in the application, disregard
-        the rest of the constructor arguments.  If this flag and the "redefine" flag are not set, constructing 
-        the same table twice will result in an exception.
+        the rest of the constructor arguments.  
         
         owner=None : optional owning user of this table.  useful for databases such as Oracle to aid in table
         reflection.
@@ -207,8 +206,8 @@ class Table(SchemaItem, sql.TableClause):
         self.indexes = util.Set()
         self.constraints = util.Set()
         self.primary_key = PrimaryKeyConstraint()
-        self.quote = kwargs.get('quote', False)
-        self.quote_schema = kwargs.get('quote_schema', False)
+        self.quote = kwargs.pop('quote', False)
+        self.quote_schema = kwargs.pop('quote_schema', False)
         if self.schema is not None:
             self.fullname = "%s.%s" % (self.schema, self.name)
         else:
@@ -217,8 +216,13 @@ class Table(SchemaItem, sql.TableClause):
 
         self._set_casing_strategy(name, kwargs)
         self._set_casing_strategy(self.schema or '', kwargs, keyname='case_sensitive_schema')
+        
+        if len([k for k in kwargs if not re.match(r'^(?:%s)_' % '|'.join(databases.__all__), k)]):
+            raise TypeError("Invalid argument(s) for Table: %s" % repr(kwargs.keys()))
+        
+        # store extra kwargs, which should only contain db-specific options
         self.kwargs = kwargs
-
+        
     def _get_case_sensitive_schema(self):
         try:
             return getattr(self, '_case_sensitive_schema')
index 8605d5c0c5dd5fbb3be8fb867fa267327369ae24..b3d61dc7e8e6c42bdee35f468281e2ecdde7f6a7 100644 (file)
@@ -423,14 +423,9 @@ class ClauseElement(object):
         FROM list of a query, when this ClauseElement is placed in the column clause of a Select
         statement."""
         raise NotImplementedError(repr(self))
-    def _process_from_dict(self, data, asfrom):
-        """given a dictionary attached to a Select object, places the appropriate
-        FROM objects in the dictionary corresponding to this ClauseElement,
-        and possibly removes or modifies others."""
-        for f in self._get_from_objects():
-            data.setdefault(f, f)
-        if asfrom:
-            data[self] = self
+    def _hide_froms(self):
+        """return a list of FROM clause elements which this ClauseElement replaces."""
+        return []
     def compare(self, other):
         """compare this ClauseElement to the given ClauseElement.
         
@@ -832,8 +827,9 @@ class _BindParamClause(ClauseElement, _CompareMixin):
         return isinstance(other, _BindParamClause) and other.type.__class__ == self.type.__class__
     def _make_proxy(self, selectable, name = None):
         return self
-#        return self.obj._make_proxy(selectable, name=self.name)
-
+    def __repr__(self):
+        return "_BindParamClause(%s, %s, type=%s)" % (repr(self.key), repr(self.value), repr(self.type))
+        
 class _TypeClause(ClauseElement):
     """handles a type keyword in a SQL statement.  used by the Case statement."""
     def __init__(self, type):
@@ -966,11 +962,6 @@ class _CalculatedClause(ClauseList, ColumnElement):
         self._engine = kwargs.get('engine', None)
         ClauseList.__init__(self, *clauses)
     key = property(lambda self:self.name or "_calc_")
-    def _process_from_dict(self, data, asfrom):
-        super(_CalculatedClause, self)._process_from_dict(data, asfrom)
-        # this helps a Select object get the engine from us
-        if asfrom:
-            data.setdefault(self, self)
     def copy_container(self):
         clauses = [clause.copy_container() for clause in self.clauses]
         return _CalculatedClause(type=self.type, engine=self._engine, *clauses)
@@ -1156,25 +1147,13 @@ class Join(FromClause):
 
     engine = property(lambda s:s.left.engine or s.right.engine)
 
-    class JoinMarker(FromClause):
-        def __init__(self, join):
-            FromClause.__init__(self)
-            self.join = join
-        def _exportable_columns(self):
-            return []
-    
     def alias(self, name=None):
         """creates a Select out of this Join clause and returns an Alias of it.  The Select is not correlating."""
         return self.select(use_labels=True, correlate=False).alias(name)            
-    def _process_from_dict(self, data, asfrom):
-        for f in self.onclause._get_from_objects():
-            data[f] = f
-        for f in self.left._get_from_objects() + self.right._get_from_objects():
-            # mark the object as a "blank" "from" that wont be printed
-            data[f] = Join.JoinMarker(self)
-        # a JOIN always impacts the final FROM list of a select statement
-        data[self] = self
-        
+
+    def _hide_froms(self):
+        return self.left._get_from_objects() + self.right._get_from_objects()
+            
     def _get_from_objects(self):
         return [self] + self.onclause._get_from_objects() + self.left._get_from_objects() + self.right._get_from_objects()
         
@@ -1323,11 +1302,6 @@ class TableClause(FromClause):
         raise NotImplementedError()
     def _group_parenthesized(self):
         return False
-    def _process_from_dict(self, data, asfrom):
-        for f in self._get_from_objects():
-            data.setdefault(f, f)
-        if asfrom:
-            data[self] = self
     def count(self, whereclause=None, **params):
         if len(self.primary_key):
             col = list(self.primary_key)[0]
@@ -1443,7 +1417,8 @@ class Select(_SelectBaseMixin, FromClause):
     the ability to execute itself and return a result set."""
     def __init__(self, columns=None, whereclause = None, from_obj = [], order_by = None, group_by=None, having=None, use_labels = False, distinct=False, for_update=False, engine=None, limit=None, offset=None, scalar=False, correlate=True):
         _SelectBaseMixin.__init__(self)
-        self.__froms = util.OrderedDict()
+        self.__froms = util.OrderedSet()
+        self.__hide_froms = util.Set([self])
         self.use_labels = use_labels
         self.whereclause = None
         self.having = None
@@ -1526,7 +1501,7 @@ class Select(_SelectBaseMixin, FromClause):
         # visit the FROM objects of the column looking for more Selects
         for f in column._get_from_objects():
             f.accept_visitor(self.__correlator)
-        column._process_from_dict(self.__froms, False)
+        self._process_froms(column, False)
 
     def _exportable_columns(self):
         return self._raw_columns
@@ -1535,6 +1510,15 @@ class Select(_SelectBaseMixin, FromClause):
             return column._make_proxy(self, name=column._label)
         else:
             return column._make_proxy(self, name=column.name)
+            
+    def _process_froms(self, elem, asfrom):
+        for f in elem._get_from_objects():
+            self.__froms.add(f)
+        if asfrom:
+            self.__froms.add(elem)
+        for f in elem._hide_froms():
+            self.__hide_froms.add(f)
+            
     def append_whereclause(self, whereclause):
         self._append_condition('whereclause', whereclause)
     def append_having(self, having):
@@ -1543,7 +1527,7 @@ class Select(_SelectBaseMixin, FromClause):
         if type(condition) == str:
             condition = _TextClause(condition)
         condition.accept_visitor(self.__wherecorrelator)
-        condition._process_from_dict(self.__froms, False)
+        self._process_froms(condition, False)
         if getattr(self, attribute) is not None:
             setattr(self, attribute, and_(getattr(self, attribute), condition))
         else:
@@ -1560,9 +1544,10 @@ class Select(_SelectBaseMixin, FromClause):
         if type(fromclause) == str:
             fromclause = _TextClause(fromclause)
         fromclause.accept_visitor(self.__correlator)
-        fromclause._process_from_dict(self.__froms, True)
+        self._process_froms(fromclause, True)
+        
     def _locate_oid_column(self):
-        for f in self.__froms.values():
+        for f in self.__froms:
             if f is self:
                 # we might be in our own _froms list if a column with us as the parent is attached,
                 # which includes textual columns. 
@@ -1572,16 +1557,11 @@ class Select(_SelectBaseMixin, FromClause):
                 return oid
         else:
             return None
-    def _get_froms(self):
-        return [f for f in self.__froms.values() if f is not self and (f not in self.__correlated)]
-    froms = property(lambda s: s._get_froms(), doc="""a list containing all elements of the FROM clause""")
+
+    froms = property(lambda self: self.__froms.difference(self.__hide_froms).difference(self.__correlated), doc="""a collection containing all elements of the FROM clause""")
 
     def accept_visitor(self, visitor):
-        # TODO: add contextual visit_ methods
-        # visit_select_whereclause, visit_select_froms, visit_select_orderby, etc.
-        # which will allow the compiler to set contextual flags before traversing 
-        # into each thing.  
-        for f in self._get_froms():
+        for f in self.froms:
             f.accept_visitor(visitor)
         if self.whereclause is not None:
             self.whereclause.accept_visitor(visitor)
@@ -1601,7 +1581,7 @@ class Select(_SelectBaseMixin, FromClause):
         
         if self._engine is not None:
             return self._engine
-        for f in self.__froms.values():
+        for f in self.__froms:
             if f is self:
                 continue
             e = f.engine
index f87cbb46ed5074f2a8512b98f3f3aad31f275163..75bc34f502322cafb733f8821c26f168f5ad39c6 100644 (file)
@@ -11,6 +11,7 @@ import sqlalchemy.ext.activemapper as activemapper
 class testcase(testbase.PersistTest):
     def setUpAll(self):
         sqlalchemy.clear_mappers()
+        objectstore.clear()
         global Person, Preferences, Address
         
         class Person(ActiveMapper):
@@ -260,6 +261,8 @@ class testcase(testbase.PersistTest):
 
 class testmanytomany(testbase.PersistTest):
      def setUpAll(self):
+         sqlalchemy.clear_mappers()
+         objectstore.clear()
          global secondarytable, foo, baz
          secondarytable = Table("secondarytable",
              activemapper.metadata,
@@ -315,6 +318,8 @@ class testmanytomany(testbase.PersistTest):
         
 class testselfreferential(testbase.PersistTest):
     def setUpAll(self):
+        sqlalchemy.clear_mappers()
+        objectstore.clear()
         global TreeNode
         class TreeNode(activemapper.ActiveMapper):
             class mapping:
index e6a1060aa1f4d469d91d0167a721d804c07d7cff..0034b31b18d3406169770168ed7f13d0b39a4ddb 100644 (file)
@@ -1341,32 +1341,27 @@ class SaveTest2(UnitOfWorkTest):
     def setUp(self):
         ctx.current.clear()
         clear_mappers()
-        self.users = Table('users', db,
+        global meta, users, addresses
+        meta = BoundMetaData(db)
+        users = Table('users', meta,
             Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True),
             Column('user_name', String(20)),
-            redefine=True
         )
 
-        self.addresses = Table('email_addresses', db,
+        addresses = Table('email_addresses', meta,
             Column('address_id', Integer, Sequence('address_id_seq', optional=True), primary_key = True),
-            Column('rel_user_id', Integer, ForeignKey(self.users.c.user_id)),
+            Column('rel_user_id', Integer, ForeignKey(users.c.user_id)),
             Column('email_address', String(20)),
-            redefine=True
         )
-        x = sql.join(self.users, self.addresses)
-#        raise repr(self.users) + repr(self.users.primary_key)
-#        raise repr(self.addresses) + repr(self.addresses.foreign_keys)
-        self.users.create()
-        self.addresses.create()
+        meta.create_all()
 
     def tearDown(self):
-        self.addresses.drop()
-        self.users.drop()
+        meta.drop_all()
         UnitOfWorkTest.tearDown(self)
     
     def testbackwardsnonmatch(self):
-        m = mapper(Address, self.addresses, properties = dict(
-            user = relation(mapper(User, self.users), lazy = True, uselist = False)
+        m = mapper(Address, addresses, properties = dict(
+            user = relation(mapper(User, users), lazy = True, uselist = False)
         ))
         data = [
             {'user_name' : 'thesub' , 'email_address' : 'bar@foo.com'},
index d88b2bf83fdd4c87c699883f81b14dbc97c0558d..a19b8cf25feff999612dcf744e5aaeb7613ac7c7 100644 (file)
@@ -5,18 +5,17 @@ import unittest, sys, datetime
 import sqlalchemy.databases.sqlite as sqllite
 
 import tables
-db = testbase.db
 from sqlalchemy import *
 from sqlalchemy.engine import ResultProxy, RowProxy
 
 class QueryTest(PersistTest):
     
     def setUpAll(self):
-        global users
-        users = Table('query_users', db,
+        global users, metadata
+        metadata = BoundMetaData(testbase.db)
+        users = Table('query_users', metadata,
             Column('user_id', INT, primary_key = True),
             Column('user_name', VARCHAR(20)),
-            redefine = True
         )
         users.create()
     
@@ -71,16 +70,16 @@ class QueryTest(PersistTest):
             default_metadata.drop_all()
             default_metadata.clear()
  
+    @testbase.supported('postgres')
     def testpassiveoverride(self):
         """primarily for postgres, tests that when we get a primary key column back 
         from reflecting a table which has a default value on it, we pre-execute
         that PassiveDefault upon insert, even though PassiveDefault says 
         "let the database execute this", because in postgres we must have all the primary
         key values in memory before insert; otherwise we cant locate the just inserted row."""
-        if db.engine.name != 'postgres':
-            return
         try:
-            db.execute("""
+            meta = BoundMetaData(testbase.db)
+            testbase.db.execute("""
              CREATE TABLE speedy_users
              (
                  speedy_user_id   SERIAL     PRIMARY KEY,
@@ -90,19 +89,17 @@ class QueryTest(PersistTest):
              );
             """, None)
             
-            t = Table("speedy_users", db, autoload=True)
+            t = Table("speedy_users", meta, autoload=True)
             t.insert().execute(user_name='user', user_password='lala')
             l = t.select().execute().fetchall()
-            print l
             self.assert_(l == [(1, 'user', 'lala')])
         finally:
-            db.execute("drop table speedy_users", None)
+            testbase.db.execute("drop table speedy_users", None)
 
+    @testbase.supported('postgres')
     def testschema(self):
-        if not db.engine.__module__.endswith('postgres'):
-            return 
-            
-        test_table = Table('my_table', db,
+        meta1 = BoundMetaData(testbase.db)
+        test_table = Table('my_table', meta1,
                     Column('id', Integer, primary_key=True),
                     Column('data', String(20), nullable=False),
                     schema='alt_schema'
@@ -112,9 +109,8 @@ class QueryTest(PersistTest):
             # plain insert
             test_table.insert().execute(data='test')
 
-            # try with a PassiveDefault
-            test_table.deregister()
-            test_table = Table('my_table', db, autoload=True, redefine=True, schema='alt_schema')
+            meta2 = BoundMetaData(testbase.db)
+            test_table = Table('my_table', meta2, autoload=True, schema='alt_schema')
             test_table.insert().execute(data='test')
 
         finally:
@@ -187,10 +183,10 @@ class QueryTest(PersistTest):
         r = self.users.select().execute().fetchone()
         self.assertEqual(len(r), 2)
         r.close()
-        r = db.execute('select user_name, user_id from query_users', {}).fetchone()
+        r = testbase.db.execute('select user_name, user_id from query_users', {}).fetchone()
         self.assertEqual(len(r), 2)
         r.close()
-        r = db.execute('select user_name from query_users', {}).fetchone()
+        r = testbase.db.execute('select user_name from query_users', {}).fetchone()
         self.assertEqual(len(r), 1)
         r.close()
     
@@ -200,6 +196,56 @@ class QueryTest(PersistTest):
         z = testbase.db.func.current_date().scalar()
         assert x == y == z
 
+    def test_update_functions(self):
+        """test sending functions and SQL expressions to the VALUES and SET clauses of INSERT/UPDATE instances,
+        and that column-level defaults get overridden"""
+        meta = BoundMetaData(testbase.db)
+        t = Table('t1', meta,
+            Column('id', Integer, primary_key=True),
+            Column('value', Integer)
+        )
+        t2 = Table('t2', meta,
+            Column('id', Integer, primary_key=True),
+            Column('value', Integer, default="7"),
+            Column('stuff', String(20), onupdate="thisisstuff")
+        )
+        meta.create_all()
+        try:
+            t.insert().execute(value=func.length("one"))
+            assert t.select().execute().fetchone()['value'] == 3
+            t.update().execute(value=func.length("asfda"))
+            assert t.select().execute().fetchone()['value'] == 5
+
+            r = t.insert(values=dict(value=func.length("sfsaafsda"))).execute()
+            id = r.last_inserted_ids()[0]
+            assert t.select(t.c.id==id).execute().fetchone()['value'] == 9
+            t.update(values={t.c.value:func.length("asdf")}).execute()
+            assert t.select().execute().fetchone()['value'] == 4
+
+            t2.insert().execute()
+            t2.insert().execute(value=func.length("one"))
+            t2.insert().execute(value=func.length("asfda") + -19, stuff="hi")
+
+            assert select([t2.c.value, t2.c.stuff]).execute().fetchall() == [(7,None), (3,None), (-14,"hi")]
+            
+            t2.update().execute(value=func.length("asdsafasd"), stuff="some stuff")
+            assert select([t2.c.value, t2.c.stuff]).execute().fetchall() == [(9,"some stuff"), (9,"some stuff"), (9,"some stuff")]
+            
+            t2.delete().execute()
+            
+            t2.insert(values=dict(value=func.length("one") + 8)).execute()
+            assert t2.select().execute().fetchone()['value'] == 11
+            
+            t2.update(values=dict(value=func.length("asfda"))).execute()
+            assert select([t2.c.value, t2.c.stuff]).execute().fetchone() == (5, "thisisstuff")
+
+            t2.update(values={t2.c.value:func.length("asfdaasdf"), t2.c.stuff:"foo"}).execute()
+            print "HI", select([t2.c.value, t2.c.stuff]).execute().fetchone()
+            assert select([t2.c.value, t2.c.stuff]).execute().fetchone() == (9, "foo")
+            
+        finally:
+            meta.drop_all()
+            
     @testbase.supported('postgres')
     def test_functions_with_cols(self):
         x = testbase.db.func.current_date().execute().scalar()
@@ -226,7 +272,7 @@ class QueryTest(PersistTest):
     def test_column_order_with_text_query(self):
         # should return values in query order
         self.users.insert().execute(user_id=1, user_name='foo')
-        r = db.execute('select user_name, user_id from query_users', {}).fetchone()
+        r = testbase.db.execute('select user_name, user_id from query_users', {}).fetchone()
         self.assertEqual(r[0], 'foo')
         self.assertEqual(r[1], 1)
         self.assertEqual(r.keys(), ['user_name', 'user_id'])
@@ -234,14 +280,14 @@ class QueryTest(PersistTest):
        
     @testbase.unsupported('oracle', 'firebird') 
     def test_column_accessor_shadow(self):
-        shadowed = Table('test_shadowed', db,
+        meta = BoundMetaData(testbase.db)
+        shadowed = Table('test_shadowed', meta,
                          Column('shadow_id', INT, primary_key = True),
                          Column('shadow_name', VARCHAR(20)),
                          Column('parent', VARCHAR(20)),
                          Column('row', VARCHAR(40)),
                          Column('__parent', VARCHAR(20)),
                          Column('__row', VARCHAR(20)),
-            redefine = True
         )
         shadowed.create()
         try:
index 0c2aa1b56d0103608e113f35e18e7654624c7b01..cd434a18450187a9c5ae7a9e6bf2bb3421863d41 100755 (executable)
@@ -16,7 +16,7 @@ table = Table('table1', db,
     Column('col2', String(20)),\r
     Column('col3', Integer),\r
     Column('colx', Integer),\r
-    redefine=True\r
+    \r
 )\r
 \r
 table2 = Table('table2', db,\r
@@ -24,7 +24,6 @@ table2 = Table('table2', db,
     Column('col2', Integer, ForeignKey('table1.col1')),\r
     Column('col3', String(20)),\r
     Column('coly', Integer),\r
-    redefine=True\r
 )\r
 \r
 class SelectableTest(testbase.AssertMixin):\r
index 1d705581142e3602ec11ef3f2bee9add4837399b..2700ec6c795d60346ea06ebc861c555446671be6 100644 (file)
@@ -265,7 +265,7 @@ class DateTest(AssertMixin):
             collist = [Column('user_id', INT, primary_key = True), Column('user_name', VARCHAR(20)), Column('user_datetime', DateTime(timezone=False)),
                            Column('user_date', Date), Column('user_time', Time)]
  
-        users_with_date = Table('query_users_with_date', db, redefine = True, *collist)
+        users_with_date = Table('query_users_with_date', db, *collist)
         users_with_date.create()
         insert_dicts = [dict(zip(fnames, d)) for d in insert_data]