]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
(no commit message)
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 17 Jul 2005 07:17:57 +0000 (07:17 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 17 Jul 2005 07:17:57 +0000 (07:17 +0000)
lib/sqlalchemy/mapper.py
test/mapper.py

index 6e2f62d3abd082db50c0dc2eea85b64c67c49f78..8be2c4e98fb226e284585cdecc027a221a517931 100644 (file)
@@ -1,10 +1,5 @@
 """
-# create a mapper from a class and table object
-usermapper = Mapper(User, users)
-
-
-# get primary key
-usermapper.get(10)
+usermapper = mapper(User, users)
 
 userlist = usermapper.select(usermapper.table.user_id == 10)
 
@@ -14,34 +9,46 @@ userlist = usermapper.select(
 
 userlist = usermapper.select("user_id =12 and foo=bar", from_obj=["foo"])
 
-usermapper = Mapper(
+addressmapper = mapper(Address, addresses)
+
+usermapper = mapper(
     User, 
     users, 
     properties = {
-        'addresses' : Relation(addressmapper, lazy = False),
-        'permissions' : Relation(permissions, 
-        
-                # one or the other
-                associationtable = userpermissions, 
-                criterion = and_(users.user_id == userpermissions.user_id, userpermissions.permission_id=permissions.permission_id), 
-                lazy = True),
-        '*' : [users, userinfo]
+        'addresses' : eagerloader(addressmapper, users.c.user_id == addresses.c.user_id),
+        'permissions' : lazymapper(Permissions, permissions, users.c.user_id == permissions.c.user_id)
     },
     )
 
-addressmapper = Mapper(Address, addresses, properties = {
-    'street': addresses.address_1,
-})
 """
 
 import sqlalchemy.sql as sql
 import sqlalchemy.schema as schema
 import sqlalchemy.engine as engine
+import weakref
+
+__ALL__ = ['eagermapper', 'eagerloader', 'mapper', 'lazyloader', 'lazymapper', 'identitymap', 'globalidentity']
+
+def eagermapper(class_, selectable, whereclause, table = None, properties = None):
+    return eagerloader(mapper(class_, selectable, table = table, properties = properties, isroot = False), whereclause)
+
+def eagerloader(mapper, whereclause):
+    return EagerLoader(mapper, whereclause)
+
+def mapper(class_, selectable, table = None, properties = None, identitymap = None, use_smart_properties = True, isroot = True):
+    return Mapper(class_, selectable, table = table, properties = properties, identitymap = identitymap, use_smart_properties = use_smart_properties, isroot = isroot)
+
+def identitymap():
+    return IdentityMap()
+
+def globalidentity():
+    return _global_identitymap
 
 class Mapper(object):
-    def __init__(self, class_, selectable, table = None, properties = None, identitymap = None):
+    def __init__(self, class_, selectable, table = None, properties = None, identitymap = None, use_smart_properties = True, isroot = True):
         self.class_ = class_
         self.selectable = selectable
+        self.use_smart_properties = use_smart_properties
         if table is None:
             self.table = self._find_table(selectable)
         else:
@@ -51,14 +58,21 @@ class Mapper(object):
         for column in self.selectable.columns:
             self.props[column.key] = ColumnProperty(column)
 
-        if properties is not None:
-            for key, value in properties.iteritems():
-                self.props[key] = value
-                
         if identitymap is not None:
             self.identitymap = identitymap
         else:
             self.identitymap = _global_identitymap
+
+        if properties is not None:
+            for key, value in properties.iteritems():
+                self.props[key] = value
+
+        if isroot:
+            self.init(self)
+    
+    def init(self, root):
+        self.identitymap = root.identitymap
+        [prop.init(key, self, root) for key, prop in self.props.iteritems()]
             
     def instances(self, cursor):
         result = []
@@ -92,10 +106,29 @@ class Mapper(object):
         else:
             return self._select_whereclause(arg, **params)
         
-    def save(self, object):
+    def save(self, object, traverse = True, refetch = False):
+        """saves the object.  based on the existence of its primary key, either inserts or updates.
+        primary key is determined by the underlying database engine's sequence methodology.
+        traverse indicates attached objects should be saved as well."""
+        pass
+    
+    def remove(self, object, traverse = True):
+        """removes the object.  traverse indicates attached objects should be removed as well."""
+        pass
+    
+    def insert(self, object):
+        """inserts the object into its table, regardless of primary key being set.  this is a 
+        lower-level operation than save."""
+        pass
+
+    def update(self, object):
+        """inserts the object into its table, regardless of primary key being set.  this is a 
+        lower-level operation than save."""
         pass
         
-    def delete(self, whereclause = None, **params):
+    def delete(self, object):
+        """deletes the object's row from its table unconditionally. this is a lower-level
+        operation than remove."""
         pass
 
     class TableFinder(sql.ClauseVisitor):
@@ -159,25 +192,44 @@ class Mapper(object):
 
 
 
-
 class MapperProperty:
     def execute(self, instance, key, row, isduplicate):
+        """called when the mapper receives a row.  instance is the parent instance corresponding
+        to the row. """
         raise NotImplementedError()
     def setup(self, key, primarytable, statement):
+        """called when a statement is being constructed."""
+        pass
+    def init(self, key, parent, root):
+        """called when the MapperProperty is first attached to a new parent Mapper."""
         pass
 
 class ColumnProperty(MapperProperty):
     def __init__(self, column):
         self.column = column
-        
+
+    def init(self, key, parent, root):
+        if root.use_smart_properties:
+            self.use_smart = True
+            if not hasattr(parent.class_, key):
+                setattr(parent.class_, key, SmartProperty(key).property())
+        else:
+            self.use_smart = False
+            
     def execute(self, instance, key, row, identitykey, localmap, isduplicate):
         if not isduplicate:
+            if self.use_smart:
+                key = "_" + key
             setattr(instance, key, row[self.column.label])
 
+    
 class EagerLoader(MapperProperty):
     def __init__(self, mapper, whereclause):
         self.mapper = mapper
         self.whereclause = whereclause
+    
+    def init(self, key, parent, root):
+        self.mapper.init(root)
         
     def setup(self, key, primarytable, statement):
         """add a left outer join to the statement thats being constructed"""
@@ -192,7 +244,7 @@ class EagerLoader(MapperProperty):
             value.setup(key, self.mapper.selectable, statement) 
         
     def execute(self, instance, key, row, identitykey, localmap, isduplicate):
-        """a row.  tell our mapper to look for a new object instance in the row, and attach
+        """receive a row.  tell our mapper to look for a new object instance in the row, and attach
         it to a list on the parent instance."""
         try:
             list = getattr(instance, key)
@@ -201,7 +253,14 @@ class EagerLoader(MapperProperty):
             setattr(instance, key, list)
         self.mapper._instance(row, localmap, list)
 
-
+class SmartProperty(object):
+    def __init__(self, key):
+        self.key = key
+    def property(self):
+        def set_property(self, value):
+            setattr(self, "_" + self.key, value)
+            self.dirty = True
+        return property(lambda s: getattr(s, "_" + self.key), set_property)
         
 class IdentityMap(dict):
     def get_key(self, row, class_, table, selectable):
@@ -209,5 +268,3 @@ class IdentityMap(dict):
         
 _global_identitymap = IdentityMap()
 
-def clear_identity():
-    _global_identitymap.clear()
index 48b4c5109f6db09cfe168d115837e09d28527807..6ba0262034bd7e039057be716b73e085f3ce76c9 100644 (file)
@@ -9,8 +9,7 @@ db = sqllite.engine('querytest.db', echo = True)
 
 from sqlalchemy.sql import *
 from sqlalchemy.schema import *
-
-import sqlalchemy.mapper as mapper
+from sqlalchemy.mapper import *
 
 users = Table('users', db,
     Column('user_id', INT, primary_key = True),
@@ -124,47 +123,47 @@ class Keyword:
 class MapperTest(PersistTest):
     
     def setUp(self):
-        mapper.clear_identity()
+        globalidentity().clear()
     
         
     def testmapper(self):
-        m = mapper.Mapper(User, users)
+        m = mapper(User, users)
         l = m.select()
         print repr(l)
 
     def testeager(self):
-        m = mapper.Mapper(User, users, properties = dict(
-            addresses = mapper.EagerLoader(mapper.Mapper(Address, addresses), users.c.user_id==addresses.c.user_id)
+        m = mapper(User, users, properties = dict(
+            addresses = eagermapper(Address, addresses, users.c.user_id==addresses.c.user_id)
         ))
         l = m.select()
         print repr(l)
 
     def testmultieager(self):
-        m = mapper.Mapper(User, users, properties = dict(
-            addresses = mapper.EagerLoader(mapper.Mapper(Address, addresses), users.c.user_id==addresses.c.user_id),
-            orders = mapper.EagerLoader(mapper.Mapper(Order, orders), users.c.user_id==orders.c.user_id),
-        ), identitymap = mapper.IdentityMap())
+        m = mapper(User, users, properties = dict(
+            addresses = eagermapper(Address, addresses, users.c.user_id==addresses.c.user_id),
+            orders = eagermapper(Order, orders, users.c.user_id==orders.c.user_id),
+        ), identitymap = identitymap())
         l = m.select()
         print repr(l)
 
     def testdoubleeager(self):
         openorders = alias(orders, 'openorders')
         closedorders = alias(orders, 'closedorders')
-        m = mapper.Mapper(User, users, properties = dict(
-            orders_open = mapper.EagerLoader(mapper.Mapper(Order, openorders), and_(openorders.c.isopen == 1, users.c.user_id==openorders.c.user_id)),
-            orders_closed = mapper.EagerLoader(mapper.Mapper(Order, closedorders), and_(closedorders.c.isopen == 0, users.c.user_id==closedorders.c.user_id))
-        ), identitymap = mapper.IdentityMap())
+        m = mapper(User, users, properties = dict(
+            orders_open = eagermapper(Order, openorders, and_(openorders.c.isopen == 1, users.c.user_id==openorders.c.user_id)),
+            orders_closed = eagermapper(Order, closedorders, and_(closedorders.c.isopen == 0, users.c.user_id==closedorders.c.user_id))
+        ), identitymap = identitymap())
         l = m.select()
         print repr(l)
 
     def testnestedeager(self):
-        ordermapper = mapper.Mapper(Order, orders, properties = dict(
-                items = mapper.EagerLoader(mapper.Mapper(Item, orderitems), orders.c.order_id == orderitems.c.order_id)
+        ordermapper = mapper(Order, orders, properties = dict(
+                items = eagermapper(Item, orderitems, orders.c.order_id == orderitems.c.order_id)
             ))
 
-        m = mapper.Mapper(User, users, properties = dict(
-            addresses = mapper.EagerLoader(mapper.Mapper(Address, addresses), users.c.user_id==addresses.c.user_id),
-            orders = mapper.EagerLoader(ordermapper, users.c.user_id==orders.c.user_id),
+        m = mapper(User, users, properties = dict(
+            addresses = eagermapper(Address, addresses, users.c.user_id==addresses.c.user_id),
+            orders = eagerloader(ordermapper, users.c.user_id==orders.c.user_id),
         ))
         l = m.select()
         print repr(l)
@@ -172,8 +171,8 @@ class MapperTest(PersistTest):
     def testmanytomanyeager(self):
         items = orderitems
         
-        m = mapper.Mapper(Item, items, properties = dict(
-                keywords = mapper.EagerLoader(mapper.Mapper(Keyword, keywords),
+        m = mapper(Item, items, properties = dict(
+                keywords = eagermapper(Keyword, keywords,
                     and_(items.c.item_id == itemkeywords.c.item_id, keywords.c.keyword_id == itemkeywords.c.keyword_id))
             ))
         l = m.select()