from test.lib.testing import assert_raises, assert_raises_message
import sqlalchemy as sa
-from sqlalchemy import Integer, PickleType, String
+from sqlalchemy import Integer, PickleType, String, ForeignKey
import operator
from test.lib import testing
from sqlalchemy.util import OrderedSet
from sqlalchemy.orm import mapper, relationship, create_session, PropComparator, \
- synonym, comparable_property, sessionmaker, attributes
+ synonym, comparable_property, sessionmaker, attributes,\
+ Session, backref, configure_mappers
from sqlalchemy.orm.collections import attribute_mapped_collection
from sqlalchemy.orm.interfaces import MapperOption
from test.lib.testing import eq_, ne_
from test.lib import fixtures
from test.orm import _fixtures
-from sqlalchemy import event
+from sqlalchemy import event, and_
from test.lib.schema import Table, Column
class MergeTest(_fixtures.FixtureTest):
eq_(ustate.load_options, set([opt2]))
+class M2ONoUseGetLoadingTest(fixtures.MappedTest):
+ """Merge a one-to-many. The many-to-one on the other side is set up
+ so that use_get is False. See if skipping the "m2o" merge
+ vs. doing it saves on SQL calls.
+
+ """
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table('user', metadata,
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
+ Column('name', String(50)),
+ )
+ Table('address', metadata,
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
+ Column('user_id', Integer, ForeignKey('user.id')),
+ Column('email', String(50)),
+ )
+
+ @classmethod
+ def setup_classes(cls):
+ class User(cls.Comparable):
+ pass
+ class Address(cls.Comparable):
+ pass
+
+ @classmethod
+ def setup_mappers(cls):
+ User, Address = cls.classes.User, cls.classes.Address
+ user, address = cls.tables.user, cls.tables.address
+ mapper(User, user, properties={
+ 'addresses':relationship(Address, backref=
+ backref('user',
+ # needlessly complex primaryjoin so that the
+ # use_get flag is False
+ primaryjoin=and_(
+ user.c.id==address.c.user_id,
+ user.c.id==user.c.id
+ )
+ )
+ )
+ })
+ mapper(Address, address)
+ configure_mappers()
+ assert Address.user.property._use_get is False
+
+ @classmethod
+ def insert_data(cls):
+ User, Address = cls.classes.User, cls.classes.Address
+ s = Session()
+ s.add_all([
+ User(id=1, name='u1', addresses=[Address(id=1, email='a1'),
+ Address(id=2, email='a2')])
+ ])
+ s.commit()
+
+ # "persistent" - we get at an Address that was already present.
+ # With the "skip bidirectional" check removed, the "set" emits SQL
+ # for the "previous" version in any case,
+ # address.user_id is 1, you get a load.
+ def test_persistent_access_none(self):
+ User, Address = self.classes.User, self.classes.Address
+ s = Session()
+ def go():
+ u1 = User(id=1,
+ addresses =[Address(id=1), Address(id=2)]
+ )
+ u2 = s.merge(u1)
+ self.assert_sql_count(testing.db, go, 2)
+
+ def test_persistent_access_one(self):
+ User, Address = self.classes.User, self.classes.Address
+ s = Session()
+ def go():
+ u1 = User(id=1,
+ addresses =[Address(id=1), Address(id=2)]
+ )
+ u2 = s.merge(u1)
+ a1 = u2.addresses[0]
+ assert a1.user is u2
+ self.assert_sql_count(testing.db, go, 3)
+
+ def test_persistent_access_two(self):
+ User, Address = self.classes.User, self.classes.Address
+ s = Session()
+ def go():
+ u1 = User(id=1,
+ addresses =[Address(id=1), Address(id=2)]
+ )
+ u2 = s.merge(u1)
+ a1 = u2.addresses[0]
+ assert a1.user is u2
+ a2 = u2.addresses[1]
+ assert a2.user is u2
+ self.assert_sql_count(testing.db, go, 4)
+
+ # "pending" - we get at an Address that is new- user_id should be
+ # None. But in this case the set attribute on the forward side
+ # already sets the backref. commenting out the "skip bidirectional"
+ # check emits SQL again for the other two Address objects already
+ # persistent.
+ def test_pending_access_one(self):
+ User, Address = self.classes.User, self.classes.Address
+ s = Session()
+ def go():
+ u1 = User(id=1,
+ addresses =[Address(id=1), Address(id=2),
+ Address(id=3, email='a3')]
+ )
+ u2 = s.merge(u1)
+ a3 = u2.addresses[2]
+ assert a3.user is u2
+ self.assert_sql_count(testing.db, go, 3)
+
+ def test_pending_access_two(self):
+ User, Address = self.classes.User, self.classes.Address
+ s = Session()
+ def go():
+ u1 = User(id=1,
+ addresses =[Address(id=1), Address(id=2),
+ Address(id=3, email='a3')]
+ )
+ u2 = s.merge(u1)
+ a3 = u2.addresses[2]
+ assert a3.user is u2
+ a2 = u2.addresses[1]
+ assert a2.user is u2
+ self.assert_sql_count(testing.db, go, 5)
+
class MutableMergeTest(fixtures.MappedTest):
@classmethod
def define_tables(cls, metadata):