]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- column label and bind param "truncation" also generate
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 28 Mar 2007 07:19:14 +0000 (07:19 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 28 Mar 2007 07:19:14 +0000 (07:19 +0000)
deterministic names now, based on their ordering within the
full statement being compiled.  this means the same statement
will produce the same string across application restarts and
allowing DB query plan caching to work better.
- cleanup to sql.ClauseParameters since it was just falling
apart, API made more explicit
- many unit test tweaks to adjust for bind params not being
"pre" truncated, changes to ClauseParameters

12 files changed:
CHANGES
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/ext/sqlsoup.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql.py
test/orm/relationships.py
test/sql/labels.py
test/sql/select.py
test/testbase.py

diff --git a/CHANGES b/CHANGES
index fd8564e784c8199756a51b66d4b623f949d8dea2..dfb5f08eebaa8cfd3d2a35e0e5fadea3193f0aaf 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -6,6 +6,11 @@
       on postgres.  Also, the true labelname is always attached as the
       accessor on the parent Selectable so theres no need to be aware
       of the genrerated label names [ticket:512].
+    - column label and bind param "truncation" also generate 
+      deterministic names now, based on their ordering within the 
+      full statement being compiled.  this means the same statement
+      will produce the same string across application restarts and
+      allowing DB query plan caching to work better.
     - preliminary support for unicode table and column names added.
     - fix for fetchmany() "size" argument being positional in most
       dbapis [ticket:505]
index 050e605ebf6b78a223b6f6ed71fa8ceccb8f34d4..a75263d915b44d60c3bd61ab1275907819b30839 100644 (file)
@@ -132,6 +132,11 @@ class ANSICompiler(sql.Compiled):
         # a dictionary of select columns labels mapped to their "generated" label
         self.column_labels = {}
 
+        # a dictionary of ClauseElement subclasses to counters, which are used to
+        # generate truncated identifier names or "anonymous" identifiers such as
+        # for aliases
+        self.generated_ids = {}
+        
         # True if this compiled represents an INSERT
         self.isinsert = False
 
@@ -242,24 +247,27 @@ class ANSICompiler(sql.Compiled):
         return ""
 
     def visit_label(self, label):
-        labelname = label.name
-        if len(labelname) >= self.dialect.max_identifier_length():
-            labelname = labelname[0:self.dialect.max_identifier_length() - 6] + "_" + hex(random.randint(0, 65535))[2:]
+        labelname = self._truncated_identifier("colident", label.name)
         
         if len(self.select_stack):
             self.typemap.setdefault(labelname.lower(), label.obj.type)
             if isinstance(label.obj, sql._ColumnClause):
-                self.column_labels[label.obj._label] = labelname.lower()
+                self.column_labels[label.obj._label] = labelname
         self.strings[label] = self.strings[label.obj] + " AS "  + self.preparer.format_label(label, labelname)
         
     def visit_column(self, column):
-        if len(self.select_stack):
-            # if we are within a visit to a Select, set up the "typemap"
-            # for this column which is used to translate result set values
-            self.typemap.setdefault(column.name.lower(), column.type)
-            self.column_labels.setdefault(column._label, column.name.lower())
+        # there is actually somewhat of a ruleset when you would *not* necessarily
+        # want to truncate a column identifier, if its mapped to the name of a 
+        # physical column.  but thats very hard to identify at this point, and 
+        # the identifier length should be greater than the id lengths of any physical
+        # columns so should not matter.
+        if not column.is_literal:
+            name = self._truncated_identifier("colident", column.name)
+        else:
+            name = column.name
+                
         if column.table is None or not column.table.named_with_column():
-            self.strings[column] = self.preparer.format_column(column)
+            self.strings[column] = self.preparer.format_column(column, name=name)
         else:
             if column.table.oid_column is column:
                 n = self.dialect.oid_column_name(column)
@@ -270,7 +278,13 @@ class ANSICompiler(sql.Compiled):
                 else:
                     self.strings[column] = None
             else:
-                self.strings[column] = self.preparer.format_column_with_table(column)
+                self.strings[column] = self.preparer.format_column_with_table(column, column_name=name)
+
+        if len(self.select_stack):
+            # if we are within a visit to a Select, set up the "typemap"
+            # for this column which is used to translate result set values
+            self.typemap.setdefault(name.lower(), column.type)
+            self.column_labels.setdefault(column._label, name.lower())
 
     def visit_fromclause(self, fromclause):
         self.froms[fromclause] = fromclause.name
@@ -394,11 +408,23 @@ class ANSICompiler(sql.Compiled):
             
         bind_name = bindparam.key
         if len(bind_name) >= self.dialect.max_identifier_length():
-            bind_name = bind_name[0:self.dialect.max_identifier_length() - 6] + "_" + hex(random.randint(0, 65535))[2:]
+            bind_name = self._truncated_identifier("bindparam", bind_name)
             # add to bind_names for translation
             self.bind_names[bindparam] = bind_name
         return bind_name
-        
+    
+    def _truncated_identifier(self, ident_class, name):
+        if (ident_class, name) in self.generated_ids:
+            return self.generated_ids[(ident_class, name)]
+        if len(name) >= self.dialect.max_identifier_length():
+            counter = self.generated_ids.get(ident_class, 1)
+            truncname = name[0:self.dialect.max_identifier_length() - 6] + "_" + hex(counter)[2:]
+            self.generated_ids[ident_class] = counter + 1
+        else:
+            truncname = name
+        self.generated_ids[(ident_class, name)] = truncname
+        return truncname
+            
     def bindparam_string(self, name):
         return self.bindtemplate % name
 
@@ -1043,30 +1069,33 @@ class ANSIIdentifierPreparer(object):
     def format_alias(self, alias):
         return self.__generic_obj_format(alias, alias.name)
 
-    def format_table(self, table, use_schema=True):
+    def format_table(self, table, use_schema=True, name=None):
         """Prepare a quoted table and schema name."""
 
-        result = self.__generic_obj_format(table, table.name)
+        if name is None:
+            name = table.name
+        result = self.__generic_obj_format(table, name)
         if use_schema and getattr(table, "schema", None):
             result = self.__generic_obj_format(table, table.schema) + "." + result
         return result
 
-    def format_column(self, column, use_table=False):
+    def format_column(self, column, use_table=False, name=None):
         """Prepare a quoted column name."""
-
+        if name is None:
+            name = column.name
         if not getattr(column, 'is_literal', False):
             if use_table:
-                return self.format_table(column.table, use_schema=False) + "." + self.__generic_obj_format(column, column.name)
+                return self.format_table(column.table, use_schema=False) + "." + self.__generic_obj_format(column, name)
             else:
-                return self.__generic_obj_format(column, column.name)
+                return self.__generic_obj_format(column, name)
         else:
             # literal textual elements get stuck into ColumnClause alot, which shouldnt get quoted
             if use_table:
-                return self.format_table(column.table, use_schema=False) + "." + column.name
+                return self.format_table(column.table, use_schema=False) + "." + name
             else:
-                return column.name
+                return name
 
-    def format_column_with_table(self, column):
+    def format_column_with_table(self, column, column_name=None):
         """Prepare a quoted column name with table name."""
-
-        return self.format_column(column, use_table=True)
+        
+        return self.format_column(column, use_table=True, name=column_name)
index 84ad6478fd17ceeffb4f835eb454ca2fd86f2d77..c2ae272f512a63417def67c3655f8368ffad860f 100644 (file)
@@ -883,7 +883,7 @@ class ResultProxy(object):
             elif isinstance(key, basestring) and key.lower() in self.props:
                 rec = self.props[key.lower()]
             elif isinstance(key, sql.ColumnElement):
-                label = self.column_labels.get(key._label, key.name)
+                label = self.column_labels.get(key._label, key.name).lower()
                 if label in self.props:
                     rec = self.props[label]
                         
index 798d02d32b7fee44b97a6d049b213e30b5b1bff3..e9ea6c1498d3a635851c1eb0763e3dd6b12b6769 100644 (file)
@@ -255,20 +255,20 @@ class DefaultExecutionContext(base.ExecutionContext):
                     # its a pk, add the value to our last_inserted_ids list,
                     # or, if its a SQL-side default, dont do any of that, but we'll need
                     # the SQL-generated value after execution.
-                    elif not param.has_key(c.key) or param[c.key] is None:
+                    elif not c.key in param or param.get_original(c.key) is None:
                         if isinstance(c.default, schema.PassiveDefault):
                             self._lastrow_has_defaults = True
                         newid = drunner.get_column_default(c)
                         if newid is not None:
-                            param[c.key] = newid
+                            param.set_value(c.key, newid)
                             if c.primary_key:
-                                last_inserted_ids.append(param[c.key])
+                                last_inserted_ids.append(param.get_processed(c.key))
                         elif c.primary_key:
                             need_lastrowid = True
                     # its an explicitly passed pk value - add it to
                     # our last_inserted_ids list.
                     elif c.primary_key:
-                        last_inserted_ids.append(param[c.key])
+                        last_inserted_ids.append(param.get_processed(c.key))
                 if need_lastrowid:
                     self._last_inserted_ids = None
                 else:
@@ -290,8 +290,8 @@ class DefaultExecutionContext(base.ExecutionContext):
                         pass
                     # its not in the bind parameters, and theres an "onupdate" defined for the column;
                     # execute it and add to bind params
-                    elif c.onupdate is not None and (not param.has_key(c.key) or param[c.key] is None):
+                    elif c.onupdate is not None and (not c.key in param or param.get_original(c.key) is None):
                         value = drunner.get_column_onupdate(c)
                         if value is not None:
-                            param[c.key] = value
+                            param.set_value(c.key, value)
                 self._last_updated_params = param
index b899c043d4912bdaa10eae935c9a475c5bb3a249..21c1fac51b7ac48d1de4a84a6b18b0facbeb7d98 100644 (file)
@@ -187,13 +187,13 @@ If you join tables that have an identical column name, wrap your join
 with `with_labels`, to disambiguate columns with their table name::
 
     >>> db.with_labels(join1).c.keys()
-    ['users_name', 'users_email', 'users_password', 'users_classname', 'users_admin', 'loans_book_id', 'loans_user_name', 'loans_loan_date']
+    [u'users_name', u'users_email', u'users_password', u'users_classname', u'users_admin', u'loans_book_id', u'loans_user_name', u'loans_loan_date']
 
 You can also join directly to a labeled object::
 
     >>> labeled_loans = db.with_labels(db.loans)
     >>> db.join(db.users, labeled_loans, isouter=True).c.keys()
-    ['name', 'email', 'password', 'classname', 'admin', 'loans_book_id', 'loans_user_name', 'loans_loan_date']
+    [u'name', u'email', u'password', u'classname', u'admin', u'loans_book_id', u'loans_user_name', u'loans_loan_date']
 
 
 Advanced Use
index 3d7ddb5d69a9277be67d71889ce748caf853a0f2..0279cca53bf031726b052cbaa741d64777032978 100644 (file)
@@ -1237,13 +1237,13 @@ class Mapper(object):
                     self.set_attr_by_column(obj, c, row[c])
         else:
             for c in table.c:
-                if c.primary_key or not params.has_key(c.name):
+                if c.primary_key or not c.key in params:
                     continue
                 v = self.get_attr_by_column(obj, c, False)
                 if v is NO_ATTRIBUTE:
                     continue
-                elif v != params.get_original(c.name):
-                    self.set_attr_by_column(obj, c, params.get_original(c.name))
+                elif v != params.get_original(c.key):
+                    self.set_attr_by_column(obj, c, params.get_original(c.key))
 
     def delete_obj(self, objects, uowtransaction):
         """Issue ``DELETE`` statements for a list of objects.
index 5ed95fabb5bf6f81fdcb33b2183d76bd75efbf67..ff835cec935ce88f18b0b9fe6d47dfe18699e253 100644 (file)
@@ -613,6 +613,7 @@ class Column(SchemaItem, sql._ColumnClause):
         [c._init_items(f) for f in fk]
         return c
 
+
     def _case_sens(self):
         """Redirect the `case_sensitive` accessor to use the ultimate
         parent column which created this one."""
index be43bb21b5f44015706ab08137dda2b0ba75bb90..78d07bec8584be775ecc1b5589b8c5b1edde2710 100644 (file)
@@ -442,7 +442,7 @@ class AbstractDialect(object):
     Used by ``Compiled`` objects."""
     pass
 
-class ClauseParameters(dict):
+class ClauseParameters(object):
     """Represent a dictionary/iterator of bind parameter key names/values.
 
     Tracks the original ``BindParam`` objects as well as the
@@ -453,39 +453,54 @@ class ClauseParameters(dict):
 
     def __init__(self, dialect, positional=None):
         super(ClauseParameters, self).__init__(self)
-        self.dialect=dialect
+        self.dialect = dialect
         self.binds = {}
+        self.binds_to_names = {}
+        self.binds_to_values = {}
         self.positional = positional or []
 
     def set_parameter(self, bindparam, value, name):
-        self[name] = value
+        self.binds[bindparam.key] = bindparam
         self.binds[name] = bindparam
-
+        self.binds_to_names[bindparam] = name
+        self.binds_to_values[bindparam] = value
+        
     def get_original(self, key):
         """Return the given parameter as it was originally placed in
         this ``ClauseParameters`` object, without any ``Type``
         conversion."""
+        return self.binds_to_values[self.binds[key]]
 
-        return super(ClauseParameters, self).__getitem__(key)
-
+    def get_processed(self, key):
+        bind = self.binds[key]
+        value = self.binds_to_values[bind]
+        return bind.typeprocess(value, self.dialect)
+    
     def __getitem__(self, key):
-        v = super(ClauseParameters, self).__getitem__(key)
-        if self.binds.has_key(key):
-            v = self.binds[key].typeprocess(v, self.dialect)
-        return v
-
+        return self.get_processed(key)
+        
+    def __contains__(self, key):
+        return key in self.binds
+    
+    def set_value(self, key, value):
+        bind = self.binds[key]
+        self.binds_to_values[bind] = value
+            
     def get_original_dict(self):
-        return self.copy()
+        return dict([(self.binds_to_names[b], self.binds_to_values[b]) for b in self.binds_to_names.keys()])
 
     def get_raw_list(self):
-        return [self[key] for key in self.positional]
+        return [self.get_processed(key) for key in self.positional]
 
     def get_raw_dict(self):
         d = {}
-        for k in self:
-            d[k] = self[k]
+        for k in self.binds_to_names.values():
+            d[k] = self.get_processed(k)
         return d
 
+    def __repr__(self):
+        return repr(self.get_original_dict())
+
 class ClauseVisitor(object):
     """A class that knows how to traverse and visit
     ``ClauseElements``.
@@ -1012,6 +1027,7 @@ class ColumnElement(Selectable, _CompareMixin):
         with Selectable objects.
         """)
 
+
     def _one_fkey(self):
         if len(self._foreign_keys):
             return list(self._foreign_keys)[0]
@@ -1037,7 +1053,7 @@ class ColumnElement(Selectable, _CompareMixin):
         for a column proxied from a Union (i.e. CompoundSelect), this 
         set will be just one element.
         """)
-
+    
     def shares_lineage(self, othercolumn):
         """Return True if the given ``ColumnElement`` has a common ancestor to this ``ColumnElement``."""
 
@@ -1929,6 +1945,8 @@ class _ColumnClause(ColumnElement):
             self.__label = "".join([x for x in self.__label if x in legal_characters])
         return self.__label
 
+    is_labeled = property(lambda self:self.name != list(self.orig_set)[0].name)
+
     _label = property(_get_label)
 
     def label(self, name):
index 7c2ec45728c1152f61476cad2fa926d9fe4ce7b9..e2ca39c5116cd86a5728fd7cfc2c57c199ec6b1a 100644 (file)
@@ -516,7 +516,7 @@ class RelationTest5(testbase.ORMTest):
         
         container_select = select(
             [items.c.policyNum, items.c.policyEffDate, items.c.type],
-            distinct=True,
+            distinct=True, 
             ).alias('container_select')
 
         mapper(LineItem, items)
index 0302fee7845ac9f3949d82308c4bf158edbcb4c4..a2e899ed6e793523f8cd98cda3eb75cc2b460741 100644 (file)
@@ -2,27 +2,34 @@ import testbase
 
 from sqlalchemy import *
 
+# TODO: either create a mock dialect with named paramstyle and a short identifier length,
+# or find a way to just use sqlite dialect and make those changes
+
 class LongLabelsTest(testbase.PersistTest):
     def setUpAll(self):
         global metadata, table1
         metadata = MetaData(engine=testbase.db)
         table1 = Table("some_large_named_table", metadata,
-            Column("this_is_the_primary_key_column", Integer, primary_key=True),
+            Column("this_is_the_primarykey_column", Integer, primary_key=True),
             Column("this_is_the_data_column", String(30))
             )
         metadata.create_all()
-        table1.insert().execute(**{"this_is_the_primary_key_column":1, "this_is_the_data_column":"data1"})
-        table1.insert().execute(**{"this_is_the_primary_key_column":2, "this_is_the_data_column":"data2"})
-        table1.insert().execute(**{"this_is_the_primary_key_column":3, "this_is_the_data_column":"data3"})
-        table1.insert().execute(**{"this_is_the_primary_key_column":4, "this_is_the_data_column":"data4"})
+    def tearDown(self):
+        table1.delete().execute()
+        
     def tearDownAll(self):
         metadata.drop_all()
         
     def test_result(self):
+        table1.insert().execute(**{"this_is_the_primarykey_column":1, "this_is_the_data_column":"data1"})
+        table1.insert().execute(**{"this_is_the_primarykey_column":2, "this_is_the_data_column":"data2"})
+        table1.insert().execute(**{"this_is_the_primarykey_column":3, "this_is_the_data_column":"data3"})
+        table1.insert().execute(**{"this_is_the_primarykey_column":4, "this_is_the_data_column":"data4"})
+
         r = table1.select(use_labels=True).execute()
         result = []
         for row in r:
-            result.append((row[table1.c.this_is_the_primary_key_column], row[table1.c.this_is_the_data_column]))
+            result.append((row[table1.c.this_is_the_primarykey_column], row[table1.c.this_is_the_data_column]))
         assert result == [
             (1, "data1"),
             (2, "data2"),
@@ -31,14 +38,30 @@ class LongLabelsTest(testbase.PersistTest):
         ]
     
     def test_colbinds(self):
-        r = table1.select(table1.c.this_is_the_primary_key_column == 4).execute()
+        table1.insert().execute(**{"this_is_the_primarykey_column":1, "this_is_the_data_column":"data1"})
+        table1.insert().execute(**{"this_is_the_primarykey_column":2, "this_is_the_data_column":"data2"})
+        table1.insert().execute(**{"this_is_the_primarykey_column":3, "this_is_the_data_column":"data3"})
+        table1.insert().execute(**{"this_is_the_primarykey_column":4, "this_is_the_data_column":"data4"})
+
+        r = table1.select(table1.c.this_is_the_primarykey_column == 4).execute()
         assert r.fetchall() == [(4, "data4")]
 
         r = table1.select(or_(
-            table1.c.this_is_the_primary_key_column == 4,
-            table1.c.this_is_the_primary_key_column == 2
+            table1.c.this_is_the_primarykey_column == 4,
+            table1.c.this_is_the_primarykey_column == 2
         )).execute()
         assert r.fetchall() == [(2, "data2"), (4, "data4")]
+    
+    def test_insert_no_pk(self):
+        table1.insert().execute(**{"this_is_the_data_column":"data1"})
+        table1.insert().execute(**{"this_is_the_data_column":"data2"})
+        table1.insert().execute(**{"this_is_the_data_column":"data3"})
+        table1.insert().execute(**{"this_is_the_data_column":"data4"})
+        
+    def test_subquery(self):
+        q = table1.select(table1.c.this_is_the_primarykey_column == 4, use_labels=True)
+        x = select([q])
+        print str(x)
         
 if __name__ == '__main__':
     testbase.main()
\ No newline at end of file
index 5fcf88fd1ee92cc71e31630607c02ee6066b38fa..f71d5366b24109b85996b31836a6bc4aa5081865 100644 (file)
@@ -59,9 +59,9 @@ class SQLTest(PersistTest):
         self.assert_(cc == result, "\n'" + cc + "'\n does not match \n'" + result + "'")
         if checkparams is not None:
             if isinstance(checkparams, list):
-                self.assert_(c.get_params().values() == checkparams, "params dont match ")
+                self.assert_(c.get_params().get_raw_list() == checkparams, "params dont match ")
             else:
-                self.assert_(c.get_params() == checkparams, "params dont match" + repr(c.get_params()))
+                self.assert_(c.get_params().get_original_dict() == checkparams, "params dont match" + repr(c.get_params()))
             
 class SelectTest(SQLTest):
     def testtableselect(self):
@@ -201,7 +201,7 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
                          order_by = ['dist', places.c.nm]
                          )
 
-        self.runtest(q,"SELECT places.id, places.nm, zips.zipcode, latlondist((SELECT zips.latitude AS latitude FROM zips WHERE zips.zipcode = :zips_zipco_1), (SELECT zips.longitude AS longitude FROM zips WHERE zips.zipcode = :zips_zipco_2)) AS dist FROM places, zips WHERE zips.zipcode = :zips_zipcode ORDER BY dist, places.nm")
+        self.runtest(q,"SELECT places.id, places.nm, zips.zipcode, latlondist((SELECT zips.latitude AS latitude FROM zips WHERE zips.zipcode = :zips_zipcode_1), (SELECT zips.longitude AS longitude FROM zips WHERE zips.zipcode = :zips_zipcode_2)) AS dist FROM places, zips WHERE zips.zipcode = :zips_zipcode ORDER BY dist, places.nm")
         
         zalias = zips.alias('main_zip')
         qlat = select([zips.c.latitude], zips.c.zipcode == zalias.c.zipcode, scalar=True)
@@ -224,8 +224,8 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
                 or_(table2.c.othername=='asdf', table2.c.othername == 'foo', table2.c.otherid == 9),
                 "sysdate() = today()", 
             )),
-            "SELECT mytable.myid, mytable.name, mytable.description FROM mytable, myothertable WHERE mytable.myid = :mytable_myid AND (myothertable.othername = :myothertable_othername OR myothertable.othername = :myothertable_otherna_1 OR myothertable.otherid = :myothertable_otherid) AND sysdate() = today()",
-            checkparams = {'myothertable_othername': 'asdf', 'myothertable_otherna_1':'foo', 'myothertable_otherid': 9, 'mytable_myid': 12}
+            "SELECT mytable.myid, mytable.name, mytable.description FROM mytable, myothertable WHERE mytable.myid = :mytable_myid AND (myothertable.othername = :myothertable_othername OR myothertable.othername = :myothertable_othername_1 OR myothertable.otherid = :myothertable_otherid) AND sysdate() = today()",
+            checkparams = {'myothertable_othername': 'asdf', 'myothertable_othername_1':'foo', 'myothertable_otherid': 9, 'mytable_myid': 12}
         )
 
     def testoperators(self):
@@ -235,7 +235,7 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
         )
         
         self.runtest(
-            literal("a") + literal("b") * literal("c"), ":literal + (:liter_1 * :liter_2)"
+            literal("a") + literal("b") * literal("c"), ":literal + (:literal_1 * :literal_2)"
         )
 
         # test the op() function, also that its results are further usable in expressions
@@ -255,7 +255,7 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
     def testmultiparam(self):
         self.runtest(
             select(["*"], or_(table1.c.myid == 12, table1.c.myid=='asdf', table1.c.myid == 'foo')), 
-            "SELECT * FROM mytable WHERE mytable.myid = :mytable_myid OR mytable.myid = :mytable_my_1 OR mytable.myid = :mytable_my_2"
+            "SELECT * FROM mytable WHERE mytable.myid = :mytable_myid OR mytable.myid = :mytable_myid_1 OR mytable.myid = :mytable_myid_2"
         )
 
     def testorderby(self):
@@ -419,7 +419,7 @@ FROM mytable, myothertable WHERE foo.id = foofoo(lala) AND datetime(foo) = Today
 
     def testliteral(self):
         self.runtest(select([literal("foo") + literal("bar")], from_obj=[table1]), 
-            "SELECT :literal + :liter_1 FROM mytable")
+            "SELECT :literal + :literal_1 FROM mytable")
 
     def testcalculatedcolumns(self):
          value_tbl = table('values',
@@ -449,7 +449,7 @@ FROM mytable, myothertable WHERE foo.id = foofoo(lala) AND datetime(foo) = Today
         """tests the generation of functions using the func keyword"""
         # test an expression with a function
         self.runtest(func.lala(3, 4, literal("five"), table1.c.myid) * table2.c.otherid, 
-            "lala(:lala, :la_1, :literal, mytable.myid) * myothertable.otherid")
+            "lala(:lala, :lala_1, :literal, mytable.myid) * myothertable.otherid")
 
         # test it in a SELECT
         self.runtest(select([func.count(table1.c.myid)]), 
@@ -471,7 +471,7 @@ FROM mytable, myothertable WHERE foo.id = foofoo(lala) AND datetime(foo) = Today
         """test the EXTRACT function"""
         self.runtest(select([extract("month", table3.c.otherstuff)]), "SELECT extract(month FROM thirdtable.otherstuff) FROM thirdtable")
         
-        self.runtest(select([extract("day", func.to_date("03/20/2005", "MM/DD/YYYY"))]), "SELECT extract(day FROM to_date(:to_date, :to_da_1))")
+        self.runtest(select([extract("day", func.to_date("03/20/2005", "MM/DD/YYYY"))]), "SELECT extract(day FROM to_date(:to_date, :to_date_1))")
         
     def testjoin(self):
         self.runtest(
@@ -526,7 +526,7 @@ mytable.description FROM myothertable JOIN mytable ON mytable.myid = myothertabl
             self.runtest(x, "SELECT mytable.myid, mytable.name, mytable.description \
 FROM mytable WHERE mytable.myid = :mytable_myid UNION \
 SELECT mytable.myid, mytable.name, mytable.description \
-FROM mytable WHERE mytable.myid = :mytable_my_1 ORDER BY mytable.myid")
+FROM mytable WHERE mytable.myid = :mytable_myid_1 ORDER BY mytable.myid")
   
             self.runtest(
                     union(
@@ -636,17 +636,17 @@ myothertable.othername != :myothertable_othername AND EXISTS (select yay from fo
              ),
              (
                  select([table1], or_(table1.c.myid==bindparam('myid', unique=True), table2.c.otherid==bindparam('myid', unique=True))),
-                 "SELECT mytable.myid, mytable.name, mytable.description FROM mytable, myothertable WHERE mytable.myid = :myid OR myothertable.otherid = :my_1",
+                 "SELECT mytable.myid, mytable.name, mytable.description FROM mytable, myothertable WHERE mytable.myid = :myid OR myothertable.otherid = :myid_1",
                  "SELECT mytable.myid, mytable.name, mytable.description FROM mytable, myothertable WHERE mytable.myid = ? OR myothertable.otherid = ?",
-                 {'myid':None, 'my_1':None}, [None, None],
-                 {'myid':5, 'my_1': 6}, {'myid':5, 'my_1':6}, [5,6]
+                 {'myid':None, 'myid_1':None}, [None, None],
+                 {'myid':5, 'myid_1': 6}, {'myid':5, 'myid_1':6}, [5,6]
              ),
              (
                  select([table1], or_(table1.c.myid==bindparam('myid', value=7, unique=True), table2.c.otherid==bindparam('myid', value=8, unique=True))),
-                 "SELECT mytable.myid, mytable.name, mytable.description FROM mytable, myothertable WHERE mytable.myid = :myid OR myothertable.otherid = :my_1",
+                 "SELECT mytable.myid, mytable.name, mytable.description FROM mytable, myothertable WHERE mytable.myid = :myid OR myothertable.otherid = :myid_1",
                  "SELECT mytable.myid, mytable.name, mytable.description FROM mytable, myothertable WHERE mytable.myid = ? OR myothertable.otherid = ?",
-                 {'myid':7, 'my_1':8}, [7,8],
-                 {'myid':5, 'my_1':6}, {'myid':5, 'my_1':6}, [5,6]
+                 {'myid':7, 'myid_1':8}, [7,8],
+                 {'myid':5, 'myid_1':6}, {'myid':5, 'myid_1':6}, [5,6]
              ),
              ][2:3]:
              
@@ -666,18 +666,18 @@ myothertable.othername != :myothertable_othername AND EXISTS (select yay from fo
         except exceptions.CompileError, err:
             assert str(err) == "Bind parameter 'mytable_myid' conflicts with unique bind parameter of the same name"
 
-        s = select([table1], or_(table1.c.myid==7, table1.c.myid==8, table1.c.myid==bindparam('mytable_my_1')))
+        s = select([table1], or_(table1.c.myid==7, table1.c.myid==8, table1.c.myid==bindparam('mytable_myid_1')))
         try:
             str(s)
             assert False
         except exceptions.CompileError, err:
-            assert str(err) == "Bind parameter 'mytable_my_1' conflicts with unique bind parameter of the same name"
+            assert str(err) == "Bind parameter 'mytable_myid_1' conflicts with unique bind parameter of the same name"
             
         # check that the bind params sent along with a compile() call
         # get preserved when the params are retreived later
         s = select([table1], table1.c.myid == bindparam('test'))
         c = s.compile(parameters = {'test' : 7})
-        self.assert_(c.get_params() == {'test' : 7})
+        self.assert_(c.get_params().get_original_dict() == {'test' : 7})
 
     def testbindascol(self):
         t = table('foo', column('id'))
@@ -688,7 +688,7 @@ myothertable.othername != :myothertable_othername AND EXISTS (select yay from fo
         
     def testin(self):
         self.runtest(select([table1], table1.c.myid.in_(1, 2, 3)),
-        "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:mytable_myid, :mytable_my_1, :mytable_my_2)")
+        "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:mytable_myid, :mytable_myid_1, :mytable_myid_2)")
 
         self.runtest(select([table1], table1.c.myid.in_(select([table2.c.otherid]))),
         "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (SELECT myothertable.otherid AS otherid FROM myothertable)")
@@ -752,9 +752,9 @@ myothertable.othername != :myothertable_othername AND EXISTS (select yay from fo
         import datetime
         table = Table('dt', metadata, 
             Column('date', Date))
-        self.runtest(table.select(table.c.date.between(datetime.date(2006,6,1), datetime.date(2006,6,5))), "SELECT dt.date FROM dt WHERE dt.date BETWEEN :dt_date AND :dt_da_1", checkparams={'dt_date':datetime.date(2006,6,1), 'dt_da_1':datetime.date(2006,6,5)})
+        self.runtest(table.select(table.c.date.between(datetime.date(2006,6,1), datetime.date(2006,6,5))), "SELECT dt.date FROM dt WHERE dt.date BETWEEN :dt_date AND :dt_date_1", checkparams={'dt_date':datetime.date(2006,6,1), 'dt_date_1':datetime.date(2006,6,5)})
 
-        self.runtest(table.select(sql.between(table.c.date, datetime.date(2006,6,1), datetime.date(2006,6,5))), "SELECT dt.date FROM dt WHERE dt.date BETWEEN :literal AND :liter_1", checkparams={'literal':datetime.date(2006,6,1), 'liter_1':datetime.date(2006,6,5)})
+        self.runtest(table.select(sql.between(table.c.date, datetime.date(2006,6,1), datetime.date(2006,6,5))), "SELECT dt.date FROM dt WHERE dt.date BETWEEN :literal AND :literal_1", checkparams={'literal':datetime.date(2006,6,1), 'literal_1':datetime.date(2006,6,5)})
 
 class CRUDTest(SQLTest):
     def testinsert(self):
@@ -803,7 +803,7 @@ class CRUDTest(SQLTest):
             values = {
             table1.c.name : table1.c.name + "lala",
             table1.c.myid : func.do_stuff(table1.c.myid, literal('hoho'))
-            }), "UPDATE mytable SET myid=do_stuff(mytable.myid, :liter_2), name=mytable.name + :mytable_name WHERE mytable.myid = hoho(:hoho) AND mytable.name = ((:literal + mytable.name) + :liter_1)")
+            }), "UPDATE mytable SET myid=do_stuff(mytable.myid, :literal_2), name=mytable.name + :mytable_name WHERE mytable.myid = hoho(:hoho) AND mytable.name = ((:literal + mytable.name) + :literal_1)")
         
     def testcorrelatedupdate(self):
         # test against a straight text subquery
index 88519ef41e5e0730bbb25fc0ad3773797295e905..8a1d9ee59a827b7ebbf04ffaf4b26265f2376c2f 100644 (file)
@@ -3,12 +3,10 @@ sys.path.insert(0, './lib/')
 import os
 import unittest
 import StringIO
-import sqlalchemy.engine as engine
 import sqlalchemy.ext.proxy as proxy
-import sqlalchemy.pool as pool
-#import sqlalchemy.schema as schema
 import re
 import sqlalchemy
+from sqlalchemy import sql, engine, pool
 import optparse
 from sqlalchemy.schema import BoundMetaData
 from sqlalchemy.orm import clear_mappers
@@ -294,6 +292,11 @@ class EngineAssert(proxy.BaseProxyEngine):
                     params = params(ctx)
                 if params is not None and isinstance(params, list) and len(params) == 1:
                     params = params[0]
+                
+                if isinstance(parameters, sql.ClauseParameters):
+                    parameters = parameters.get_original_dict()
+                elif isinstance(parameters, list):
+                    parameters = [p.get_original_dict() for p in parameters]
                         
                 query = self.convert_statement(query)
                 self.unittest.assert_(statement == query and (params is None or params == parameters), "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters)))