]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
changed assignmapper API per [ticket:636]
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 20 Jul 2007 16:52:00 +0000 (16:52 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 20 Jul 2007 16:52:00 +0000 (16:52 +0000)
lib/sqlalchemy/ext/assignmapper.py
test/ext/assignmapper.py

index f78e464933d4144dd59da220623887bba0f1e327..03c6cc1f83da460b997ecfe19ac97b4ca0bbab17 100644 (file)
@@ -1,18 +1,8 @@
 from sqlalchemy import util, exceptions
 import types
-from sqlalchemy.orm import mapper, Query
-
-def monkeypatch_query_method(ctx, class_, name):
-    def do(self, *args, **kwargs):
-        query = Query(class_, session=ctx.current)
-        return getattr(query, name)(*args, **kwargs)
-    try:
-        do.__name__ = name
-    except:
-        pass
-    setattr(class_, name, classmethod(do))
-
-def monkeypatch_objectstore_method(ctx, class_, name):
+from sqlalchemy.orm import mapper
+    
+def _monkeypatch_session_method(name, ctx, class_):
     def do(self, *args, **kwargs):
         session = ctx.current
         return getattr(session, name)(self, *args, **kwargs)
@@ -21,31 +11,37 @@ def monkeypatch_objectstore_method(ctx, class_, name):
     except:
         pass
     setattr(class_, name, do)
-
-    
+        
 def assign_mapper(ctx, class_, *args, **kwargs):
-    validate = kwargs.pop('validate', False)
-    if not isinstance(getattr(class_, '__init__'), types.MethodType):
-        def __init__(self, **kwargs):
-            if validate:
-                keys = [p.key for p in self.mapper.iterate_properties]
-            for key, value in kwargs.items():
-                if validate and key not in keys:
-                    raise exceptions.ArgumentError("Invalid __init__ argument: '%s'" % key)
-                setattr(self, key, value)
-        class_.__init__ = __init__
     extension = kwargs.pop('extension', None)
     if extension is not None:
         extension = util.to_list(extension)
         extension.append(ctx.mapper_extension)
     else:
         extension = ctx.mapper_extension
+
+    validate = kwargs.pop('validate', False)
+    
+    if not isinstance(getattr(class_, '__init__'), types.MethodType):
+        def __init__(self, **kwargs):
+             for key, value in kwargs.items():
+                 if validate:
+                     if not self.mapper.get_property(key, resolve_synonyms=False, raiseerr=False):
+                         raise exceptions.ArgumentError("Invalid __init__ argument: '%s'" % key)
+                 setattr(self, key, value)
+        class_.__init__ = __init__
+    
+    class query(object):
+        def __getattr__(self, key):
+            return getattr(ctx.current.query(class_), key)
+        def __call__(self):
+            return ctx.current.query(class_)
+    class_.query = query()
+    
+    for name in ['refresh', 'expire', 'delete', 'expunge', 'update']:
+        _monkeypatch_session_method(name, ctx, class_)
+
     m = mapper(class_, extension=extension, *args, **kwargs)
     class_.mapper = m
-    class_.query = classmethod(lambda cls: Query(class_, session=ctx.current))
-    for name in ['get', 'filter', 'filter_by', 'select', 'select_by', 'selectfirst', 'selectfirst_by', 'selectone', 'selectone_by', 'get_by', 'join', 'count', 'count_by', 'options', 'instances']:
-        monkeypatch_query_method(ctx, class_, name)
-    for name in ['delete', 'expire', 'refresh', 'expunge', 'save', 'update', 'save_or_update']:
-        monkeypatch_objectstore_method(ctx, class_, name)
     return m
 
index 8562a016a2489e566887d607e25364afc72c85ba..dbad8de9d14a0689fd212edf8d93284b101f8325 100644 (file)
@@ -8,7 +8,7 @@ from sqlalchemy.ext.assignmapper import assign_mapper
 from sqlalchemy.ext.sessioncontext import SessionContext
 from testbase import Table, Column
 
-class OverrideAttributesTest(PersistTest):
+class AssignMapperTest(PersistTest):
     def setUpAll(self):
         global metadata, table, table2
         metadata = MetaData(testbase.db)
@@ -20,13 +20,9 @@ class OverrideAttributesTest(PersistTest):
             Column('someid', None, ForeignKey('sometable.id'))
             )
         metadata.create_all()
-    def tearDownAll(self):
-        metadata.drop_all()
-    def tearDown(self):
-        clear_mappers()
+
     def setUp(self):
-        pass
-    def test_override_attributes(self):
+        global SomeObject, SomeOtherObject, ctx
         class SomeObject(object):pass
         class SomeOtherObject(object):pass
         
@@ -35,7 +31,7 @@ class OverrideAttributesTest(PersistTest):
             'options':relation(SomeOtherObject)
         })
         assign_mapper(ctx, SomeOtherObject, table2)
-        class_mapper(SomeObject)
+
         s = SomeObject()
         s.id = 1
         s.data = 'hello'
@@ -43,8 +39,17 @@ class OverrideAttributesTest(PersistTest):
         s.options.append(sso)
         ctx.current.flush()
         ctx.current.clear()
+
+    def tearDownAll(self):
+        metadata.drop_all()
+    def tearDown(self):
+        clear_mappers()
+
+    def test_override_attributes(self):
         
-        assert SomeObject.get_by(id=s.id).options[0].id == sso.id
+        sso = SomeOtherObject.query().first()
+        
+        assert SomeObject.query.filter_by(id=1).one().options[0].id == sso.id
 
         s2 = SomeObject(someid=12)
         s3 = SomeOtherObject(someid=123, bogus=345)
@@ -58,6 +63,8 @@ class OverrideAttributesTest(PersistTest):
             assert False
         except exceptions.ArgumentError:
             pass
+    
+
         
 if __name__ == '__main__':
     testbase.main()