from sqlalchemy import util, exceptions
import types
-from sqlalchemy.orm import mapper, Query
-
-def monkeypatch_query_method(ctx, class_, name):
- def do(self, *args, **kwargs):
- query = Query(class_, session=ctx.current)
- return getattr(query, name)(*args, **kwargs)
- try:
- do.__name__ = name
- except:
- pass
- setattr(class_, name, classmethod(do))
-
-def monkeypatch_objectstore_method(ctx, class_, name):
+from sqlalchemy.orm import mapper
+
+def _monkeypatch_session_method(name, ctx, class_):
def do(self, *args, **kwargs):
session = ctx.current
return getattr(session, name)(self, *args, **kwargs)
except:
pass
setattr(class_, name, do)
-
-
+
def assign_mapper(ctx, class_, *args, **kwargs):
- validate = kwargs.pop('validate', False)
- if not isinstance(getattr(class_, '__init__'), types.MethodType):
- def __init__(self, **kwargs):
- if validate:
- keys = [p.key for p in self.mapper.iterate_properties]
- for key, value in kwargs.items():
- if validate and key not in keys:
- raise exceptions.ArgumentError("Invalid __init__ argument: '%s'" % key)
- setattr(self, key, value)
- class_.__init__ = __init__
extension = kwargs.pop('extension', None)
if extension is not None:
extension = util.to_list(extension)
extension.append(ctx.mapper_extension)
else:
extension = ctx.mapper_extension
+
+ validate = kwargs.pop('validate', False)
+
+ if not isinstance(getattr(class_, '__init__'), types.MethodType):
+ def __init__(self, **kwargs):
+ for key, value in kwargs.items():
+ if validate:
+ if not self.mapper.get_property(key, resolve_synonyms=False, raiseerr=False):
+ raise exceptions.ArgumentError("Invalid __init__ argument: '%s'" % key)
+ setattr(self, key, value)
+ class_.__init__ = __init__
+
+ class query(object):
+ def __getattr__(self, key):
+ return getattr(ctx.current.query(class_), key)
+ def __call__(self):
+ return ctx.current.query(class_)
+ class_.query = query()
+
+ for name in ['refresh', 'expire', 'delete', 'expunge', 'update']:
+ _monkeypatch_session_method(name, ctx, class_)
+
m = mapper(class_, extension=extension, *args, **kwargs)
class_.mapper = m
- class_.query = classmethod(lambda cls: Query(class_, session=ctx.current))
- for name in ['get', 'filter', 'filter_by', 'select', 'select_by', 'selectfirst', 'selectfirst_by', 'selectone', 'selectone_by', 'get_by', 'join', 'count', 'count_by', 'options', 'instances']:
- monkeypatch_query_method(ctx, class_, name)
- for name in ['delete', 'expire', 'refresh', 'expunge', 'save', 'update', 'save_or_update']:
- monkeypatch_objectstore_method(ctx, class_, name)
return m
from sqlalchemy.ext.sessioncontext import SessionContext
from testbase import Table, Column
-class OverrideAttributesTest(PersistTest):
+class AssignMapperTest(PersistTest):
def setUpAll(self):
global metadata, table, table2
metadata = MetaData(testbase.db)
Column('someid', None, ForeignKey('sometable.id'))
)
metadata.create_all()
- def tearDownAll(self):
- metadata.drop_all()
- def tearDown(self):
- clear_mappers()
+
def setUp(self):
- pass
- def test_override_attributes(self):
+ global SomeObject, SomeOtherObject, ctx
class SomeObject(object):pass
class SomeOtherObject(object):pass
'options':relation(SomeOtherObject)
})
assign_mapper(ctx, SomeOtherObject, table2)
- class_mapper(SomeObject)
+
s = SomeObject()
s.id = 1
s.data = 'hello'
s.options.append(sso)
ctx.current.flush()
ctx.current.clear()
+
+ def tearDownAll(self):
+ metadata.drop_all()
+ def tearDown(self):
+ clear_mappers()
+
+ def test_override_attributes(self):
- assert SomeObject.get_by(id=s.id).options[0].id == sso.id
+ sso = SomeOtherObject.query().first()
+
+ assert SomeObject.query.filter_by(id=1).one().options[0].id == sso.id
s2 = SomeObject(someid=12)
s3 = SomeOtherObject(someid=123, bogus=345)
assert False
except exceptions.ArgumentError:
pass
+
+
if __name__ == '__main__':
testbase.main()