]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
oids rows insert sort orders galore
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 22 Oct 2005 22:57:32 +0000 (22:57 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 22 Oct 2005 22:57:32 +0000 (22:57 +0000)
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/engine.py
lib/sqlalchemy/mapper.py
lib/sqlalchemy/objectstore.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql.py
lib/sqlalchemy/util.py
test/mapper.py
test/objectstore.py
test/tables.py

index cfaf63b57fa8ec46d8b98b788cd8e1f5eb9c896c..74c2b53666031ec330e87d17d33569ebc571ddd2 100644 (file)
@@ -92,8 +92,13 @@ class PGSQLEngine(ansisql.ANSISQLEngine):
         raise "not implemented"
         
     def last_inserted_ids(self):
-        return self.context.last_inserted_ids
-
+        table = self.context.last_inserted_table
+        if self.context.lastrowid is not None and table is not None and len(table.primary_keys):
+            row = sql.select(table.primary_keys, table.rowid_column == self.context.lastrowid).execute().fetchone()
+            return [v for v in row]
+        else:
+            return None
+            
     def pre_exec(self, connection, cursor, statement, parameters, echo = None, compiled = None, **kwargs):
         if True: return
         # if a sequence was explicitly defined we do it here
@@ -123,18 +128,10 @@ class PGSQLEngine(ansisql.ANSISQLEngine):
     def post_exec(self, connection, cursor, statement, parameters, echo = None, compiled = None, **kwargs):
         if compiled is None: return
         if getattr(compiled, "isinsert", False):
-            # psycopg wants to return internal rowids, which I guess is what DBAPI2 really 
-            # specifies.
-            # well then post exec to get the row.  I guess this could be genericised to 
-            # be for all inserts somehow if the "rowid" col could be gotten off a table.
             table = compiled.statement.table
-            if len(table.primary_keys):
-                # TODO: cache this statement against the table to avoid multiple re-compiles
-                # TODO: instead of "oid" have the Table object have a "rowid_col" property
-                # that gives this col generically
-                row = sql.select(table.primary_keys, sql.ColumnClause("oid",table) == bindparam('oid', cursor.lastrowid) ).execute().fetchone()
-                self.context.last_inserted_ids = [v for v in row]
-
+            self.context.last_inserted_table = table
+            self.context.lastrowid = cursor.lastrowid
+            
     def dbapi(self):
         return self.module
 
index 8be019ea53bf738471d1fd093eccd65640ecdc57..4e01d1684d5672bb0986b976c22ef375ebed83ca 100644 (file)
@@ -78,6 +78,10 @@ class SQLEngine(schema.SchemaEngine):
     def compiler(self, statement, bindparams):
         raise NotImplementedError()
 
+    def rowid_column_name(self):
+        """returns the ROWID column name for this engine."""
+        return "oid"
+
     def create(self, table, **params):
         table.accept_visitor(self.schemagenerator(self.proxy(), **params))
 
index 42e0945bb65f9f2d9ecf00a95ebb6882f0fb7546..5e39e10f6ee50cb4433dfeea56c51155e33bdc31 100644 (file)
@@ -90,8 +90,16 @@ def mapper(class_, table = None, engine = None, autoload = False, *args, **param
         return _mappers[hashkey]
 
 def clear_mappers():
+    """removes all mappers that have been created thus far.  when new mappers are 
+    created, they will be assigned to their classes as their primary mapper."""
     _mappers.clear()
-        
+
+def clear_mapper(m):
+    """removes the given mapper from the storage of mappers.  when a new mapper is 
+    created for the previous mapper's class, it will be used as that classes' 
+    new primary mapper."""
+    del _mappers[m.hash_key]
+    
 def eagerload(name):
     """returns a MapperOption that will convert the property of the given name
     into an eager load.  Used with mapper.options()"""
@@ -261,16 +269,20 @@ class Mapper(object):
         return self.hashkey
 
     def _init_class(self):
+        """sets up our classes' overridden __init__ method, this mappers hash key as its '_mapper' property,
+        and our columns as its 'c' property.  if the class already had a mapper, the old __init__ method
+        is kept the same."""
+        if not hasattr(self.class_, '_mapper'):
+            oldinit = self.class_.__init__
+            def init(self, *args, **kwargs):
+                nohist = kwargs.pop('_mapper_nohistory', False)
+                if oldinit is not None:
+                    oldinit(self, *args, **kwargs)
+                if not nohist:
+                    objectstore.uow().register_new(self)
+            self.class_.__init__ = init
         self.class_._mapper = self.hashkey
         self.class_.c = self.c
-        oldinit = self.class_.__init__
-        def init(self, *args, **kwargs):
-            nohist = kwargs.pop('_mapper_nohistory', False)
-            if oldinit is not None:
-                oldinit(self, *args, **kwargs)
-            if not nohist:
-                objectstore.uow().register_new(self)
-        self.class_.__init__ = init
         
     def set_property(self, key, prop):
         self.props[key] = prop
@@ -938,7 +950,11 @@ class LazyLoader(PropertyLoader):
                 params = {}
                 for key in self.lazybinds.keys():
                     params[key] = row[key]
-                result = self.mapper.select(self.lazywhere, **params)
+                if self.secondary is not None:
+                    order_by = [self.secondary.rowid_column]
+                else:
+                    order_by = [self.target.rowid_column]
+                result = self.mapper.select(self.lazywhere, order_by=order_by,**params)
                 if self.uselist:
                     return result
                 else:
index 7cc9a69fb7f377105e5427dc2ef6914d0acd6e2a..3e216a768b219582f34c84b1152508be2fce041d 100644 (file)
@@ -54,20 +54,42 @@ def get_row_key(row, class_, table, primary_keys):
     """
     return (class_, table, tuple([row[column.label] for column in primary_keys]))
 
+def begin():
+    """begins a new UnitOfWork transaction.  the next commit will affect only
+    objects that are created, modified, or deleted following the begin statement."""
+    uow().begin()
+    
 def commit(*obj):
+    """commits the current UnitOfWork transaction.  if a transaction was begun 
+    via begin(), commits only those objects that were created, modified, or deleted
+    since that begin statement.  otherwise commits all objects that have been
+    changed."""
     uow().commit(*obj)
     
 def clear():
+    """removes all current UnitOfWorks and IdentityMaps for this thread and 
+    establishes a new one.  It is probably a good idea to discard all
+    current mapped object instances, as they are no longer in the Identity Map."""
     uow.set(UnitOfWork())
 
 def delete(*obj):
+    """registers the given objects as to be deleted upon the next commit"""
     uw = uow()
     for o in obj:
         uw.register_deleted(o)
     
 def has_key(key):
+    """returns True if the current thread-local IdentityMap contains the given instance key"""
     return uow().identity_map.has_key(key)
 
+def has_instance(instance):
+    """returns True if the current thread-local IdentityMap contains the given instance"""
+    return uow().identity_map.has_key(instance_key(instance))
+
+def instance_key(instance):
+    """returns the IdentityMap key for the given instance"""
+    return object_mapper(instance).instance_key(instance)
+    
 class UOWListElement(attributes.ListElement):
     def __init__(self, obj, key, data=None, deleteremoved=False):
         attributes.ListElement.__init__(self, obj, key, data=data)
index 9fad004b13d8461b3bf3e889b904994fbd294f9b..47db789f4b1a10c384ae670271f883b0390fb809 100644 (file)
@@ -130,7 +130,6 @@ class Column(SchemaItem):
         
     original = property(lambda s: s._orig or s)
     engine = property(lambda s: s.table.engine)
-    
         
     def _set_parent(self, table):
         table.columns[self.key] = self
index 3ab1e5ec29db848e9b3f4e2c002748c2e88498d3..b2272541d1e58b60190c835e028efb13d9f52cc1 100644 (file)
@@ -401,7 +401,6 @@ class CompoundClause(ClauseElement):
     """represents a list of clauses joined by an operator"""
     def __init__(self, operator, *clauses):
         self.operator = operator
-        self.fromobj = []
         self.clauses = []
         self.parens = False
         for c in clauses:
@@ -418,7 +417,6 @@ class CompoundClause(ClauseElement):
         elif isinstance(clause, CompoundClause):
             clause.parens = True
         self.clauses.append(clause)
-        self.fromobj += clause._get_from_objects()
 
     def accept_visitor(self, visitor):
         for c in self.clauses:
@@ -426,7 +424,10 @@ class CompoundClause(ClauseElement):
         visitor.visit_compound(self)
 
     def _get_from_objects(self):
-        return self.fromobj
+        f = []
+        for c in self.clauses:
+            f += c._get_from_objects()
+        return f
         
     def hash_key(self):
         return string.join([c.hash_key() for c in self.clauses], self.operator)
@@ -621,9 +622,13 @@ class TableImpl(Selectable):
     def __init__(self, table):
         self.table = table
         self.id = self.table.name
+        self.rowid_column = schema.Column(self.table.engine.rowid_column_name(), types.Integer)
+        self.rowid_column._set_parent(table)
+        del self.table.c[self.rowid_column.key]
         
     def get_from_text(self):
         return self.table.name
+    
         
     def group_parenthesized(self):
         return False
index 2498316c1d1eafe63757a91dfa8d4884a2b85ff4..1425d231bb854f83f2d1ddbfdf039c163a90146b 100644 (file)
@@ -24,19 +24,17 @@ class OrderedProperties(object):
     """
     def __init__(self):
         self.__dict__['_list'] = []
-            
     def keys(self):
         return self._list
-        
     def __iter__(self):
         return iter([self[x] for x in self._list])
-    
     def __setitem__(self, key, object):
         setattr(self, key, object)
-        
     def __getitem__(self, key):
         return getattr(self, key)
-        
+    def __delitem__(self, key):
+        delattr(self, key)
+        del self._list[self._list.index(key)]
     def __setattr__(self, key, object):
         if not hasattr(self, key):
             self._list.append(key)
index 001225f9f340d508909cf48f1596cc8717fd27b6..02b476c2d4c61f24c854f2145ee6ee7a2e09f081 100644 (file)
@@ -144,8 +144,8 @@ class LazyTest(MapperSuperTest):
         l = Item.mapper.select()
         self.assert_result(l, Item, 
             {'item_id' : 1, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 4}, {'keyword_id' : 6}])},
-            {'item_id' : 2, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 5}, {'keyword_id' : 7}])},
-            {'item_id' : 3, 'keywords' : (Keyword, [{'keyword_id' : 3}, {'keyword_id' : 4}, {'keyword_id' : 6}])},
+            {'item_id' : 2, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 7}, {'keyword_id' : 5}])},
+            {'item_id' : 3, 'keywords' : (Keyword, [{'keyword_id' : 6}, {'keyword_id' : 3}, {'keyword_id' : 4}])},
             {'item_id' : 4, 'keywords' : (Keyword, [])},
             {'item_id' : 5, 'keywords' : (Keyword, [])}
         )
@@ -153,7 +153,7 @@ class LazyTest(MapperSuperTest):
         l = Item.mapper.select(and_(keywords.c.name == 'red', keywords.c.keyword_id == itemkeywords.c.keyword_id, Item.c.item_id==itemkeywords.c.item_id))
         self.assert_result(l, Item, 
             {'item_id' : 1, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 4}, {'keyword_id' : 6}])},
-            {'item_id' : 2, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 5}, {'keyword_id' : 7}])},
+            {'item_id' : 2, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 7}, {'keyword_id' : 5}])},
         )
 
 class EagerTest(MapperSuperTest):
index 852b1492d182a3d3c620ac41b99647e8f3734289..b3ae4b115feb7506fd155523a1b3b520c52e1ff1 100644 (file)
@@ -5,16 +5,10 @@ import StringIO
 import sqlalchemy.objectstore as objectstore
 import testbase
 
-echo = testbase.echo
-#testbase.echo = False
 from tables import *
+import tables
 
 
-db.connection().commit()
-
-db.echo = echo
-testbase.echo = echo
-
 class HistoryTest(AssertMixin):
     def testattr(self):
         m = mapper(User, users, properties = dict(addresses = relation(Address, addresses)))
@@ -32,18 +26,21 @@ class HistoryTest(AssertMixin):
         
 class SaveTest(AssertMixin):
 
+    def setUpAll(self):
+        db.echo = False
+        tables.create()
+        db.echo = testbase.echo
+    def tearDownAll(self):
+        db.echo = False
+        tables.drop()
+        db.echo = testbase.echo
+        
     def setUp(self):
-        e = db.echo
         db.echo = False
+        # remove all history/identity maps etc.
         objectstore.clear()
+        # remove all mapperes
         clear_mappers()
-        
-        itemkeywords.delete().execute()
-        orderitems.delete().execute()
-        orders.delete().execute()
-        addresses.delete().execute()
-        users.delete().execute()
-        keywords.delete().execute()
         keywords.insert().execute(
             dict(name='blue'),
             dict(name='red'),
@@ -53,8 +50,13 @@ class SaveTest(AssertMixin):
             dict(name='round'),
             dict(name='square')
         )
-        db.connection().commit()        
-        db.echo = e
+        db.commit()        
+        db.echo = testbase.echo
+
+    def tearDown(self):
+        db.echo = False
+        tables.delete()
+        db.echo = testbase.echo
         
     def testbasic(self):
         # save two users
index 1dee54236b46002784b2fdb8a2e6ffd89198ef8d..845077716cfa24191189707d3e39d0fb30a221bf 100644 (file)
@@ -10,16 +10,14 @@ __ALL__ = ['db', 'users', 'addresses', 'orders', 'orderitems', 'keywords', 'item
 
 ECHO = testbase.echo
 
-DBTYPE = 'sqlite_memory'
-#DBTYPE = 'postgres'
+#DBTYPE = 'sqlite_memory'
+DBTYPE = 'postgres'
 #DBTYPE = 'sqlite_file'
 
 if DBTYPE == 'sqlite_memory':
     db = sqlalchemy.engine.create_engine('sqlite', ':memory:', {}, echo = testbase.echo)
 elif DBTYPE == 'sqlite_file':
     import sqlalchemy.databases.sqlite as sqllite
-#    if os.access('querytest.db', os.F_OK):
- #       os.remove('querytest.db')
     db = sqlalchemy.engine.create_engine('sqlite', 'querytest.db', {}, echo = testbase.echo)
 elif DBTYPE == 'postgres':
     db = sqlalchemy.engine.create_engine('postgres', {'database':'test', 'host':'127.0.0.1', 'user':'scott', 'password':'tiger'}, echo=testbase.echo)
@@ -67,6 +65,7 @@ def create():
     orderitems.create()
     keywords.create()
     itemkeywords.create()
+    db.commit()
     
 def drop():
     itemkeywords.drop()
@@ -75,14 +74,22 @@ def drop():
     orders.drop()
     addresses.drop()
     users.drop()
-
-def data():
+    db.commit()
+    
+def delete():
     itemkeywords.delete().execute()
     keywords.delete().execute()
     orderitems.delete().execute()
     orders.delete().execute()
     addresses.delete().execute()
     users.delete().execute()
+    
+def data():
+    delete()
+    
+    # with SQLITE, the OID column of a table defaults to the primary key, if it has one.
+    # so to database-neutrally get rows back in "insert order" based on OID, we
+    # have to also put the primary keys in order for the purpose of these tests
     users.insert().execute(
         dict(user_id = 7, user_name = 'jack'),
         dict(user_id = 8, user_name = 'ed'),
@@ -102,10 +109,10 @@ def data():
     )
     orderitems.insert().execute(
         dict(item_id=1, order_id=2, item_name='item 1'),
-        dict(item_id=3, order_id=3, item_name='item 3'),
         dict(item_id=2, order_id=2, item_name='item 2'),
+        dict(item_id=3, order_id=3, item_name='item 3'),
+        dict(item_id=4, order_id=3, item_name='item 4'),
         dict(item_id=5, order_id=3, item_name='item 5'),
-        dict(item_id=4, order_id=3, item_name='item 4')
     )
     keywords.insert().execute(
         dict(keyword_id=1, name='blue'),