]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
dev
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 28 Aug 2005 01:27:19 +0000 (01:27 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 28 Aug 2005 01:27:19 +0000 (01:27 +0000)
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/databases/sqlite.py
lib/sqlalchemy/engine.py
lib/sqlalchemy/mapper.py
test/mapper.py

index 42014ca21df287074684ae6d985136467d8d21be..e0dcc58bd7d772a5fa81ba1c0a80052255a38b12 100644 (file)
@@ -58,6 +58,7 @@ class ANSICompiler(sql.Compiled):
         self.froms = {}
         self.wheres = {}
         self.strings = {}
+        self.isinsert = False
         
     def get_from_text(self, obj):
         return self.froms[obj]
@@ -200,6 +201,7 @@ class ANSICompiler(sql.Compiled):
             " ON " + self.get_str(join.onclause))
 
     def visit_insert(self, insert_stmt):
+        self.isinsert = True
         colparams = insert_stmt.get_colparams(self.bindparams)
         for c in colparams:
             b = c[1]
index fa5124ed11d4932b6e589ff8c2b71a4f41dcb5fa..f6d56d207ece94ef86f89dd39c081b7aed2a1578 100644 (file)
@@ -52,15 +52,22 @@ class PGSQLEngine(ansisql.ANSISQLEngine):
         statement.accept_visitor(compiler)
         return compiler
 
+    def last_inserted_ids(self):
+        return self.context.last_inserted_ids
+
     def pre_exec(self, connection, cursor, statement, parameters, echo = None, compiled = None, **kwargs):
         if compiled is None: return
         if getattr(compiled, "isinsert", False):
+            last_inserted_ids = []
             for primary_key in compiled.statement.table.primary_keys:
                 # pseudocode
                 if echo is True or self._echo:
                     self.log(primary_key.sequence.text)
                 res = cursor.execute(primary_key.sequence.text)
-                parameters[primary_key.key] = res.fetchrow()[0]
+                newid = res.fetchrow()[0]
+                parameters[primary_key.key] = newid
+                last_inserted_ids.append(newid)
+            self.context.last_inserted_ids = last_inserted_ids
 
     def dbapi(self):
         return None
@@ -73,10 +80,8 @@ class PGSQLEngine(ansisql.ANSISQLEngine):
         raise NotImplementedError()
 
 class PGCompiler(ansisql.ANSICompiler):
-    def visit_insert(self, insert):
-        self.isinsert = True
-        super(self).visit_insert(insert)
-    
+    pass
+
 class PGColumnImpl(sql.ColumnSelectable):
     def get_specification(self):
         coltype = self.column.type
index 315374a6d4aa640d4a23f7c1a71248629230452d..475f687378c58bad5c4cf572eb8f5ed9dfe12049 100644 (file)
@@ -49,8 +49,13 @@ class SQLiteSQLEngine(ansisql.ANSISQLEngine):
         self.opts = opts or {}
         ansisql.ANSISQLEngine.__init__(self, **params)
 
+    def post_exec(self, connection, cursor, statement, parameters, echo = None, compiled = None, **kwargs):
+        if compiled is None: return
+        if getattr(compiled, "isinsert", False):
+            self.context.last_inserted_ids = [cursor.lastrowid]
+
     def last_inserted_ids(self):
-        pass
+        return self.context.last_inserted_ids
 
     def connect_args(self):
         return ([self.filename], self.opts)
@@ -81,5 +86,7 @@ class SQLiteColumnImpl(sql.ColumnSelectable):
         else:
             key = coltype.__class__
 
-        return self.name + " " + colspecs[key] % {'precision': getattr(coltype, 'precision', None), 'length' : getattr(coltype, 'length', None)}
-
+        colspec = self.name + " " + colspecs[key] % {'precision': getattr(coltype, 'precision', None), 'length' : getattr(coltype, 'length', None)}
+        if self.column.primary_key:
+            colspec += " PRIMARY KEY"
+        return colspec
index 22cc0343488545cd1722ded0aca9b4ac8c76f9aa..e6183c328fbcc6ebfef92a3829057cea15b4c893 100644 (file)
@@ -150,8 +150,9 @@ class SQLEngine(schema.SchemaEngine):
 
 
 class ResultProxy:
-    def __init__(self, cursor):
+    def __init__(self, cursor, echo = False):
         self.cursor = cursor
+        self.echo = echo
         metadata = cursor.description
         self.props = {}
         i = 0
@@ -164,7 +165,7 @@ class ResultProxy:
     def fetchone(self):
         row = self.cursor.fetchone()
         if row is not None:
-            #print repr(row)
+            if self.echo: print repr(row)
             return RowProxy(self, row)
         else:
             return None
index 4d506f8e2335a94b8defa900fb7e0df8040c60fd..d8b6caf465796dde4e71250f7977d9d397a7c1f4 100644 (file)
@@ -52,7 +52,7 @@ def relation_mapper(class_, selectable, secondary = None, primaryjoin = None, se
 _mappers = {}
 def mapper(*args, **params):
     hashkey = mapper_hash_key(*args, **params)
-    print "HASHKEY: " + hashkey
+    #print "HASHKEY: " + hashkey
     try:
         return _mappers[hashkey]
     except KeyError:
@@ -121,7 +121,7 @@ class Mapper(object):
 
     def instances(self, cursor):
         result = []
-        cursor = engine.ResultProxy(cursor)
+        cursor = engine.ResultProxy(cursor, echo = self.echo)
 
         localmap = {}
         while True:
@@ -131,9 +131,30 @@ class Mapper(object):
             self._instance(row, localmap, result)
         return result
 
-    def get(self, id):
-        """returns an instance of the object based on the given ID."""
-        pass
+    def get(self, *ident):
+        """returns an instance of the object based on the given identifier, or None
+        if not found.  The *ident argument is a 
+        list of primary keys in the order of the table def's primary keys."""
+        key = self.identitymap.get_id_key(ident, self.class_, self.table, self.selectable)
+        try:
+            return self.identitymap[key]
+        except KeyError:
+            clause = sql.and_()
+            i = 0
+            for primary_key in self.selectable.primary_keys:
+                # appending to the and_'s clause list directly to skip
+                # typechecks etc.
+                clause.clauses.append(primary_key == ident[i])
+                i += 2
+            try:
+                return self.select(clause)[0]
+            except IndexError:
+                return None
+
+    def put(self, instance):
+        key = self.identitymap.get_instance_key(instance, self.class_, self.table, self.selectable)
+        self.identitymap[key] = instance
+        return key
 
     def compile(self, whereclause = None, **options):
         """works like select, except returns the SQL statement object without 
@@ -188,29 +209,45 @@ class Mapper(object):
         """removes the object.  traverse indicates attached objects should be removed as well."""
         pass
     
-    def insert(self, object):
+    def insert(self, obj):
         """inserts the object into its table, regardless of primary key being set.  this is a 
         lower-level operation than save."""
         params = {}
         for col in self.table.columns:
-            params[col.key] = getattr(object, col.key)
+            params[col.key] = getattr(obj, col.key, None)
         ins = self.table.insert()
         ins.execute(**params)
+
+        # TODO: unset dirty flag
+
+        # populate new primary keys
         primary_keys = self.table.engine.last_inserted_ids()
-        # TODO: put the primary keys into the object props
+        index = 0
+        for pk in self.table.primary_keys:
+            newid = primary_keys[index]
+            index += 1
+            # TODO: do this via the ColumnProperty objects
+            setattr(obj, pk.key, newid)
 
-    def update(self, object):
+        self.put(obj)
+
+    def update(self, obj):
         """inserts the object into its table, regardless of primary key being set.  this is a 
         lower-level operation than save."""
         params = {}
         for col in self.table.columns:
-            params[col.key] = getattr(object, col.key)
+            params[col.key] = getattr(obj, col.key)
         upd = self.table.update()
         upd.execute(**params)
-        
-    def delete(self, object):
+        # TODO: unset dirty flag
+
+    def delete(self, obj):
         """deletes the object's row from its table unconditionally. this is a lower-level
         operation than remove."""
+        # delete dependencies ?
+        # delete row
+        # remove primary keys
+        # unset dirty flag
         pass
 
     class TableFinder(sql.ClauseVisitor):
@@ -234,7 +271,7 @@ class Mapper(object):
     def _select_whereclause(self, whereclause = None, **params):
         statement = self._compile(whereclause)
         return self._select_statement(statement, **params)
-    
+
     def _select_statement(self, statement, **params):
         statement.use_labels = True
         statement.echo = self.echo
@@ -243,12 +280,13 @@ class Mapper(object):
     def _identity_key(self, row):
         return self.identitymap.get_key(row, self.class_, self.table, self.selectable)
 
+
     def _instance(self, row, localmap, result):
         """pulls an object instance from the given row and appends it to the given result list.
         if the instance already exists in the given identity map, its not added.  in either
         case, executes all the property loaders on the instance to also process extra information
         in the row."""
-            
+
         # create the instance if its not in the identity map,
         # else retrieve it
         identitykey = self._identity_key(row)
@@ -323,7 +361,7 @@ class ColumnProperty(MapperProperty):
 
     def hash_key(self):
         return "ColumnProperty(%s)" % hash_key(self.column)
-        
+
     def init(self, key, parent, root):
         self.key = key
         if root.use_smart_properties:
@@ -363,8 +401,6 @@ def mapper_hash_key(class_, selectable, table = None, properties = None, identit
         )
     )
 
-
-        
 class PropertyLoader(MapperProperty):
     def __init__(self, mapper, secondary, primaryjoin, secondaryjoin):
         self.mapper = mapper
@@ -373,10 +409,10 @@ class PropertyLoader(MapperProperty):
         self.primaryjoin = primaryjoin
         self.secondaryjoin = secondaryjoin
         self._hash_key = "%s(%s, %s, %s, %s)" % (self.__class__.__name__, hash_key(mapper), hash_key(secondary), hash_key(primaryjoin), hash_key(secondaryjoin))
-        
+
     def hash_key(self):
         return self._hash_key
-        
+
     def init(self, key, parent, root):
         self.key = key
         self.mapper.init(root)
@@ -397,12 +433,15 @@ class PropertyLoader(MapperProperty):
         # if a mapping table exists, determine the two foreign key columns 
         # in the mapping table, set the two values, and insert that row, for
         # each row in the list
-        pass
+        if self.secondary is None:
+            self.mapper.save(object)
+        else:
+            # TODO: crap, we dont have a simple version of what object props/cols match to which
+            pass
 
     def delete(self):
         self.mapper.delete()
 
-        
 class LazyLoader(PropertyLoader):
 
     def init(self, key, parent, root):
@@ -424,14 +463,22 @@ class LazyLoader(PropertyLoader):
 
     def execute(self, instance, row, identitykey, localmap, isduplicate):
         if not isduplicate:
-            def load():
-                m = {}
-                for key, value in self.binds.iteritems():
-                    m[key] = row[key]
-                return self.mapper.select(self.lazywhere, **m)
+            setattr(instance, self.key, LazyLoadInstance(self, row))
+
+class LazyLoadInstance(object):
+    """attached to a specific object instance to load related rows.  this is implemetned
+    as a callable object, rather than a closure, to allow serialization of the target object"""
+    def __init__(self, lazyloader, row):
+        self.params = {}
+        for key, value in lazyloader.binds.iteritems():
+            self.params[key] = row[key]
+        # TODO: dont attach to the mapper, its huge.
+        # figure out some way to shrink this.
+        self.mapper = lazyloader.mapper
+
+    def __call__(self):
+        return self.mapper.select(self.lazywhere, **self.params)
 
-            setattr(instance, self.key, load)
-        
 class EagerLoader(PropertyLoader):
     def init(self, key, parent, root):
         PropertyLoader.init(self, key, parent, root)
@@ -440,8 +487,7 @@ class EagerLoader(PropertyLoader):
         if self.secondaryjoin is not None:
             [self.to_alias.append(f) for f in self.secondaryjoin._get_from_objects()]
         del self.to_alias[parent.selectable]
-    
-            
+
     def setup(self, key, primarytable, statement, **options):
         """add a left outer join to the statement thats being constructed"""
 
@@ -565,10 +611,14 @@ def match_primaries(primary, secondary):
         return sql.and_([pk == secondary.c[pk.name] for pk in primary.primary_keys])
 
 class IdentityMap(dict):
+    def get_id_key(self, ident, class_, table, selectable):
+        return (class_, table, tuple(ident))
+    def get_instance_key(self, object, class_, table, selectable):
+        return (class_, table, tuple([getattr(object, column.key, None) for column in selectable.primary_keys]))
     def get_key(self, row, class_, table, selectable):
         return (class_, table, tuple([row[column.label] for column in selectable.primary_keys]))
     def hash_key(self):
         return "IdentityMap(%s)" % id(self)
-        
+
 _global_identitymap = IdentityMap()
 
index ee5a2d4591c39a0fc3ab3444e42326568fc5023a..1e2eabd92686b86c42fe9b09680668f879540c91 100644 (file)
@@ -7,6 +7,7 @@ class User(object):
     def __repr__(self):
         return (
 """
+objid: %d
 User ID: %s
 User Name: %s
 Addresses: %s
@@ -14,7 +15,7 @@ Orders: %s
 Open Orders %s
 Closed Orderss %s
 ------------------
-""" % tuple([self.user_id, repr(self.user_name)] + [repr(getattr(self, attr, None)) for attr in ('addresses', 'orders', 'orders_open', 'orders_closed')])
+""" % tuple([id(self), self.user_id, repr(self.user_name)] + [repr(getattr(self, attr, None)) for attr in ('addresses', 'orders', 'orders_open', 'orders_closed')])
 )
 
 class Address(object):
@@ -52,7 +53,14 @@ class MapperTest(AssertMixin):
     
     def setUp(self):
         globalidentity().clear()
-        
+
+    def testget(self):
+        m = mapper(User, users, echo = True)
+        self.assert_(m.get(19) is None)
+        u = m.get(7)
+        u2 = m.get(7)
+        self.assert_(u is u2)
+
     def testload(self):
         """tests loading rows with a mapper and producing object instances"""
         m = mapper(User, users)
@@ -67,7 +75,7 @@ class MapperTest(AssertMixin):
             addresses = relation(Address, addresses, lazy = True)
         ), echo = True)
         l = m.options(eagerload('addresses')).select()
-        self.assert_result(l, User, 
+        self.assert_result(l, User,
             {'user_id' : 7, 'addresses' : (Address, [{'address_id' : 1}])},
             {'user_id' : 8, 'addresses' : (Address, [{'address_id' : 2}, {'address_id' : 3}])},
             {'user_id' : 9, 'addresses' : (Address, [])}
@@ -79,7 +87,7 @@ class MapperTest(AssertMixin):
             addresses = relation(Address, addresses, lazy = False)
         ), echo = True)
         l = m.options(lazyload('addresses')).select()
-        self.assert_result(l, User, 
+        self.assert_result(l, User,
             {'user_id' : 7, 'addresses' : (Address, [{'address_id' : 1}])},
             {'user_id' : 8, 'addresses' : (Address, [{'address_id' : 2}, {'address_id' : 3}])},
             {'user_id' : 9, 'addresses' : (Address, [])}
@@ -216,13 +224,14 @@ class EagerTest(PersistTest):
         print repr(l)
 
 class SaveTest(PersistTest):
-    def _testinsert(self):
+        
+    def testinsert(self):
         u = User()
         u.user_name = 'inserttester'
-        m = mapper(User, users)
+        m = mapper(User, users, echo=True)
         m.insert(u)
-
-        nu = m.select(users.c.user_id == u.user_id)
+        nu = m.get(u.user_id)
+    #    nu = m.select(users.c.user_id == u.user_id)[0]
         self.assert_(u is nu)
 
 if __name__ == "__main__":