def in_transaction(self):
return self.__tcount > 0
- def begin(self):
+ def begin(self, **kwargs):
if self.__tcount == 0:
self.__transaction = self.get_connection()
- self.__trans = self.__transaction._begin()
+ self.__trans = self.__transaction._begin(**kwargs)
self.__tcount += 1
return self.__trans
def in_transaction(self):
return self.session.in_transaction()
- def begin(self):
- return self.session.begin()
+ def begin(self, **kwargs):
+ return self.session.begin(**kwargs)
def close(self):
if self.__opencount == 1:
return self.session.get_connection(**kwargs)
- def begin(self):
- return self.session.begin()
+ def begin(self, **kwargs):
+ return self.session.begin(**kwargs)
def commit(self):
self.session.commit()
def process_query(self, query):
query._extension = query._extension.copy()
- query._extension.append(self.ext)
+ query._extension.insert(self.ext)
class SynonymProperty(MapperProperty):
def __init__(self, name, proxy=False):
from sqlalchemy.orm import util as mapperutil
from sqlalchemy.orm.util import ExtensionCarrier
from sqlalchemy.orm import sync
-from sqlalchemy.orm.interfaces import MapperProperty, EXT_CONTINUE, MapperExtension, SynonymProperty
+from sqlalchemy.orm.interfaces import MapperProperty, EXT_CONTINUE, EXT_STOP, MapperExtension, SynonymProperty
deferred_load = None
__all__ = ['Mapper', 'class_mapper', 'object_mapper', 'mapper_registry']
creates a linked list of those extensions.
"""
- extlist = util.Set()
- for ext_class in global_extensions:
- if isinstance(ext_class, MapperExtension):
- extlist.add(ext_class)
- else:
- extlist.add(ext_class())
- # local MapperExtensions have already instrumented the class
- extlist[-1].instrument_class(self, self.class_)
-
+ extlist = util.OrderedSet()
+
extension = self.extension
if extension is not None:
for ext_obj in util.to_list(extension):
+ # local MapperExtensions have already instrumented the class
extlist.add(ext_obj)
+ for ext in global_extensions:
+ if isinstance(ext_class, type):
+ ext = ext()
+ extlist.add(ext)
+ ext.instrument_class(self, self.class_)
+
+ extlist.add(_DefaultExtension())
+
self.extension = ExtensionCarrier()
for ext in extlist:
self.extension.append(ext)
else:
extension = self.extension
- ret = extension.translate_row(self, context, row)
- if ret is not EXT_CONTINUE:
- row = ret
+ row = extension.translate_row(self, context, row)
if not skip_polymorphic and self.polymorphic_on is not None:
discriminator = row[self.polymorphic_on]
identitykey = self.identity_key_from_row(row)
populate_existing = context.populate_existing or self.always_refresh
- if context.session.has_key(identitykey):
+ if identitykey in context.session.identity_map:
instance = context.session._get(identitykey)
if self.__should_log_debug:
self.__log_debug("_instance(): using existing instance %s identity %s" % (mapperutil.instance_str(instance), str(identitykey)))
if identitykey not in context.identity_map:
context.identity_map[identitykey] = instance
isnew = True
- if extension.populate_instance(self, context, row, instance, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
- self.populate_instance(context, instance, row, instancekey=identitykey, isnew=isnew)
- if extension.append_result(self, context, row, instance, result, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
- if result is not None:
- result.append(instance)
+ extension.populate_instance(self, context, row, instance, instancekey=identitykey, isnew=isnew)
+ extension.append_result(self, context, row, instance, result, instancekey=identitykey, isnew=isnew)
return instance
else:
if self.__should_log_debug:
# plugin point
instance = extension.create_instance(self, context, row, self.class_)
- if instance is EXT_CONTINUE:
- instance = self._create_instance(context.session)
- else:
- instance._entity_name = self.entity_name
+ instance._entity_name = self.entity_name
if self.__should_log_debug:
self.__log_debug("_instance(): created new instance %s identity %s" % (mapperutil.instance_str(instance), str(identitykey)))
context.identity_map[identitykey] = instance
# call further mapper properties on the row, to pull further
# instances from the row and possibly populate this item.
flags = {'instancekey':identitykey, 'isnew':isnew}
- if extension.populate_instance(self, context, row, instance, **flags) is EXT_CONTINUE:
- self.populate_instance(context, instance, row, **flags)
- if extension.append_result(self, context, row, instance, result, **flags) is EXT_CONTINUE:
- if result is not None:
- result.append(instance)
+ extension.populate_instance(self, context, row, instance, **flags)
+ extension.append_result(self, context, row, instance, result, **flags)
return instance
- def _create_instance(self, session):
- obj = self.class_.__new__(self.class_)
- obj._entity_name = self.entity_name
-
- return obj
-
def _deferred_inheritance_condition(self, needs_tables):
cond = self.inherit_condition
Mapper.logger = logging.class_logger(Mapper)
+class _DefaultExtension(MapperExtension):
+ def translate_row(self, mapper, context, row):
+ return row
+ def populate_instance(self, mapper, context, row, instance, instancekey, isnew):
+ mapper.populate_instance(context, instance, row, instancekey=instancekey, isnew=isnew)
+ return EXT_STOP
+
+ def append_result(self, mapper, context, row, instance, result, instancekey, isnew):
+ if result is not None:
+ result.append(instance)
+ return EXT_STOP
+
+ def create_instance(self, mapper, context, row, class_):
+ return class_.__new__(class_)
class ClassKey(object):
"""Key a class and an entity name to a mapper, via the mapper_registry."""
def __call__(s):
return self.context.registry().query(class_)
- if not hasattr(class_, 'query'):
+ if not 'query' in class_.__dict__:
class_.query = query()
def init_instance(self, mapper, class_, oldinit, instance, args, kwargs):
"""
self.uow = unitofwork.UnitOfWork(weak_identity_map=weak_identity_map)
+ self.identity_map = self.uow.identity_map
self.bind = bind
self.__binds = {}
self._unattach(instance)
echo = self.uow.echo
self.uow = unitofwork.UnitOfWork(weak_identity_map=self.weak_identity_map)
+ self.identity_map = self.uow.identity_map
self.uow.echo = echo
def bind_mapper(self, mapper, bind, entity_name=None):
try:
key = getattr(object, '_instance_key', None)
if key is None:
- merged = mapper._create_instance(self)
+ merged = mapper.class_.__new__(mapper.class_)
else:
if key in self.identity_map:
merged = self.identity_map[key]
def _get(self, key):
return self.identity_map[key]
- def has_key(self, key):
- """return True if the given identity key is present within this Session's identity map."""
-
- return key in self.identity_map
-
dirty = property(lambda s:s.uow.locate_dirty(),
doc="A ``Set`` of all objects marked as 'dirty' within this ``Session``")
new = property(lambda s:s.uow.new,
doc="A ``Set`` of all objects marked as 'new' within this ``Session``.")
- identity_map = property(lambda s:s.uow.identity_map,
- doc="A dictionary consisting of all objects "
- "within this ``Session`` keyed to their `_instance_key` value.")
-
def import_instance(self, *args, **kwargs):
"""A synynom for ``merge()``."""
false = bp(false)
owners = Table ( 'owners', dbmeta ,
- Column ( 'id', Integer, primary_key=True, nullable=False ),
- Column('data', String(30)) )
+ Column ( 'id', Integer, primary_key=True, nullable=False ),
+ Column('data', String(30)) )
categories=Table( 'categories', dbmeta,
- Column ( 'id', Integer,primary_key=True, nullable=False ),
- Column ( 'name', VARCHAR(20), index=True ) )
+ Column ( 'id', Integer,primary_key=True, nullable=False ),
+ Column ( 'name', VARCHAR(20), index=True ) )
tests = Table ( 'tests', dbmeta ,
- Column ( 'id', Integer, primary_key=True, nullable=False ),
- Column ( 'owner_id',Integer, ForeignKey('owners.id'), nullable=False,index=True ),
- Column ( 'category_id', Integer, ForeignKey('categories.id'),nullable=False,index=True ))
+ Column ( 'id', Integer, primary_key=True, nullable=False ),
+ Column ( 'owner_id',Integer, ForeignKey('owners.id'), nullable=False,index=True ),
+ Column ( 'category_id', Integer, ForeignKey('categories.id'),nullable=False,index=True ))
options = Table ( 'options', dbmeta ,
- Column ( 'test_id', Integer, ForeignKey ( 'tests.id' ), primary_key=True, nullable=False ),
- Column ( 'owner_id', Integer, ForeignKey ( 'owners.id' ), primary_key=True, nullable=False ),
- Column ( 'someoption', Boolean, PassiveDefault(str(false)), nullable=False ) )
+ Column ( 'test_id', Integer, ForeignKey ( 'tests.id' ), primary_key=True, nullable=False ),
+ Column ( 'owner_id', Integer, ForeignKey ( 'owners.id' ), primary_key=True, nullable=False ),
+ Column ( 'someoption', Boolean, PassiveDefault(str(false)), nullable=False ) )
dbmeta.create_all()
class Owner(object):
- pass
+ pass
class Category(object):
- pass
+ pass
class Test(object):
- pass
+ pass
class Option(object):
- pass
+ pass
mapper(Owner,owners)
mapper(Category,categories)
mapper(Option,options,properties={'owner':relation(Owner),'test':relation(Test)})
s.save(c)
for i in range(3):
- t=Test()
- t.owner=o
- t.category=c
- s.save(t)
- if i==1:
- op=Option()
- op.someoption=True
- t.owner_option=op
- if i==2:
- op=Option()
- t.owner_option=op
+ t=Test()
+ t.owner=o
+ t.category=c
+ s.save(t)
+ if i==1:
+ op=Option()
+ op.someoption=True
+ t.owner_option=op
+ if i==2:
+ op=Option()
+ t.owner_option=op
s.flush()
s.close()
# not orm style correct query
print "Obtaining correct results without orm"
result = select( [tests.c.id,categories.c.name],
- and_(tests.c.owner_id==1,or_(options.c.someoption==None,options.c.someoption==False)),
- order_by=[tests.c.id],
- from_obj=[tests.join(categories).outerjoin(options,and_(tests.c.id==options.c.test_id,tests.c.owner_id==options.c.owner_id))] ).execute().fetchall()
+ and_(tests.c.owner_id==1,or_(options.c.someoption==None,options.c.someoption==False)),
+ order_by=[tests.c.id],
+ from_obj=[tests.join(categories).outerjoin(options,and_(tests.c.id==options.c.test_id,tests.c.owner_id==options.c.owner_id))] ).execute().fetchall()
print result
assert result == [(1, u'Some Category'), (3, u'Some Category')]
def test_withouteagerload(self):
s = create_session()
l=s.query(Test).select ( and_(tests.c.owner_id==1,or_(options.c.someoption==None,options.c.someoption==False)),
- from_obj=[tests.outerjoin(options,and_(tests.c.id==options.c.test_id,tests.c.owner_id==options.c.owner_id))])
+ from_obj=[tests.outerjoin(options,and_(tests.c.id==options.c.test_id,tests.c.owner_id==options.c.owner_id))])
result = ["%d %s" % ( t.id,t.category.name ) for t in l]
print result
assert result == [u'1 Some Category', u'3 Some Category']
s = create_session()
q=s.query(Test).options(eagerload('category'))
l=q.select ( and_(tests.c.owner_id==1,or_(options.c.someoption==None,options.c.someoption==False)),
- from_obj=[tests.outerjoin(options,and_(tests.c.id==options.c.test_id,tests.c.owner_id==options.c.owner_id))])
+ from_obj=[tests.outerjoin(options,and_(tests.c.id==options.c.test_id,tests.c.owner_id==options.c.owner_id))])
result = ["%d %s" % ( t.id,t.category.name ) for t in l]
print result
assert result == [u'1 Some Category', u'3 Some Category']
s = create_session()
q=s.query(Test).options(eagerload('category'))
l=q.select( (tests.c.owner_id==1) & ('options.someoption is null or options.someoption=%s' % false) & q.join_to('owner_option') )
- result = ["%d %s" % ( t.id,t.category.name ) for t in l]
+ result = ["%d %s" % ( t.id,t.category.name ) for t in l]
print result
assert result == [u'3 Some Category']
s = create_session()
q=s.query(Test).options(eagerload('category'))
l=q.select( (tests.c.owner_id==1) & ((options.c.someoption==None) | (options.c.someoption==False)) & q.join_to('owner_option') )
- result = ["%d %s" % ( t.id,t.category.name ) for t in l]
+ result = ["%d %s" % ( t.id,t.category.name ) for t in l]
print result
assert result == [u'3 Some Category']
def define_tables(self, metadata):
global designType, design, part, inheritedPart
designType = Table('design_types', metadata,
- Column('design_type_id', Integer, primary_key=True),
- )
+ Column('design_type_id', Integer, primary_key=True),
+ )
design =Table('design', metadata,
- Column('design_id', Integer, primary_key=True),
- Column('design_type_id', Integer, ForeignKey('design_types.design_type_id')))
+ Column('design_id', Integer, primary_key=True),
+ Column('design_type_id', Integer, ForeignKey('design_types.design_type_id')))
part = Table('parts', metadata,
- Column('part_id', Integer, primary_key=True),
- Column('design_id', Integer, ForeignKey('design.design_id')),
- Column('design_type_id', Integer, ForeignKey('design_types.design_type_id')))
+ Column('part_id', Integer, primary_key=True),
+ Column('design_id', Integer, ForeignKey('design.design_id')),
+ Column('design_type_id', Integer, ForeignKey('design_types.design_type_id')))
inheritedPart = Table('inherited_part', metadata,
- Column('ip_id', Integer, primary_key=True),
- Column('part_id', Integer, ForeignKey('parts.part_id')),
- Column('design_id', Integer, ForeignKey('design.design_id')),
- )
+ Column('ip_id', Integer, primary_key=True),
+ Column('part_id', Integer, ForeignKey('parts.part_id')),
+ Column('design_id', Integer, ForeignKey('design.design_id')),
+ )
def testone(self):
class Part(object):pass
mapper(Part, part)
mapper(InheritedPart, inheritedPart, properties=dict(
- part=relation(Part, lazy=False)
+ part=relation(Part, lazy=False)
))
mapper(Design, design, properties=dict(
- parts=relation(Part, private=True, backref="design"),
- inheritedParts=relation(InheritedPart, private=True, backref="design"),
+ parts=relation(Part, private=True, backref="design"),
+ inheritedParts=relation(InheritedPart, private=True, backref="design"),
))
mapper(DesignType, designType, properties=dict(
- # designs=relation(Design, private=True, backref="type"),
+ # designs=relation(Design, private=True, backref="type"),
))
class_mapper(Design).add_property("type", relation(DesignType, lazy=False, backref="designs"))
from sqlalchemy import ansisql
from sqlalchemy.orm import *
from testlib import *
-from fixtures import *
+from testlib.fixtures import *
from query import QueryTest
from sqlalchemy import *
from sqlalchemy.orm import *
from testlib import *
-from fixtures import *
+from testlib.fixtures import *
from query import QueryTest
class EagerTest(QueryTest):
from sqlalchemy import *
from sqlalchemy.orm import *
from testlib import *
+from testlib.fixtures import Base
-# TODO: refactor "fixtures" to be part of testlib, so Base is globally available
-_recursion_stack = util.Set()
-class Base(object):
- def __init__(self, **kwargs):
- for k in kwargs:
- setattr(self, k, kwargs[k])
-
- def __ne__(self, other):
- return not self.__eq__(other)
-
- def __eq__(self, other):
- """'passively' compare this object to another.
-
- only look at attributes that are present on the source object.
-
- """
- if self in _recursion_stack:
- return True
- _recursion_stack.add(self)
- try:
- # use __dict__ to avoid instrumented properties
- for attr in self.__dict__.keys():
- if attr[0] == '_':
- continue
- value = getattr(self, attr)
- if hasattr(value, '__iter__') and not isinstance(value, basestring):
- try:
- # catch AttributeError so that lazy loaders trigger
- otherattr = getattr(other, attr)
- except AttributeError:
- return False
- if len(value) != len(getattr(other, attr)):
- return False
- for (us, them) in zip(value, getattr(other, attr)):
- if us != them:
- return False
- else:
- continue
- else:
- if value is not None:
- print "KEY", attr, "COMPARING", value, "TO", getattr(other, attr, None)
- if value != getattr(other, attr, None):
- return False
- else:
- return True
- finally:
- _recursion_stack.remove(self)
class InheritingSelectablesTest(ORMTest):
def define_tables(self, metadata):
from sqlalchemy import *
from sqlalchemy.orm import *
from testlib import *
-from fixtures import *
+from testlib.fixtures import *
from query import QueryTest
class LazyTest(QueryTest):
from sqlalchemy import ansisql
from sqlalchemy.orm import *
from testlib import *
-from fixtures import *
+from testlib.fixtures import *
class QueryTest(FixtureTest):
keep_mappers = True
from sqlalchemy import *
from sqlalchemy.orm import *
from testlib import *
-from fixtures import *
+from testlib.fixtures import *
from query import QueryTest
class SelectableNoFromsTest(ORMTest):
from testlib import *
from testlib.tables import *
import testlib.tables as tables
-import fixtures
+from testlib import fixtures
class SessionTest(AssertMixin):
def setUpAll(self):
tables.create()
+
def tearDownAll(self):
tables.drop()
+
def tearDown(self):
SessionCls.close_all()
tables.delete()
clear_mappers()
+
def setUp(self):
pass
Session.mapper(MyClass, table2)
assert MyClass().expunge() == "an expunge !"
+
+class ScopedMapperTest2(ORMTest):
+ def define_tables(self, metadata):
+ global table, table2
+ table = Table('sometable', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', String(30)),
+ Column('type', String(30))
+
+ )
+ table2 = Table('someothertable', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('someid', None, ForeignKey('sometable.id')),
+ Column('somedata', String(30)),
+ )
+ def test_inheritance(self):
+ def expunge_list(l):
+ for x in l:
+ Session.expunge(x)
+ return l
+
+ class BaseClass(fixtures.Base):
+ pass
+ class SubClass(BaseClass):
+ pass
+
+ Session = scoped_session(sessionmaker())
+ Session.mapper(BaseClass, table, polymorphic_identity='base', polymorphic_on=table.c.type)
+ Session.mapper(SubClass, table2, polymorphic_identity='sub', inherits=BaseClass)
+
+ b = BaseClass(data='b1')
+ s = SubClass(data='s1', somedata='somedata')
+ Session.commit()
+ Session.clear()
+
+ assert expunge_list([BaseClass(data='b1'), SubClass(data='s1', somedata='somedata')]) == BaseClass.query.all()
+ assert expunge_list([SubClass(data='s1', somedata='somedata')]) == SubClass.query.all()
+
+
if __name__ == "__main__":
testbase.main()
from testlib import *
from testlib.tables import *
from testlib import tables
-import fixtures
+from testlib import fixtures
"""tests unitofwork operations"""
only look at attributes that are present on the source object.
"""
+
if self in _recursion_stack:
return True
_recursion_stack.add(self)
# catch AttributeError so that lazy loaders trigger
otherattr = getattr(other, attr)
except AttributeError:
+ print "Other class does not have attribute named '%s'" % attr
return False
if len(value) != len(getattr(other, attr)):
- return False
+ print "Length of collection '%s' does not match that of other" % attr
+ return False
for (us, them) in zip(value, getattr(other, attr)):
if us != them:
+ print "1. Attribute named '%s' does not match other" % attr
return False
else:
continue
else:
if value is not None:
- print "KEY", attr, "COMPARING", value, "TO", getattr(other, attr, None)
if value != getattr(other, attr, None):
+ print "2. Attribute named '%s' does not match that of other" % attr
return False
else:
return True