]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
dev
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 14 Aug 2014 21:44:58 +0000 (17:44 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 15 Aug 2014 19:53:12 +0000 (15:53 -0400)
doc/build/faq.rst
lib/sqlalchemy/orm/persistence.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/state.py
lib/sqlalchemy/orm/unitofwork.py

index 3dc81026b1ad6c9b8a7c84f934f1d4dd6cc33ccb..b777f908fadf8d32469013b86f4753e19bb4669b 100644 (file)
@@ -907,10 +907,12 @@ methods of inserting rows, going from the most automated to the least.
 With cPython 2.7, runtimes observed::
 
     classics-MacBook-Pro:sqlalchemy classic$ python test.py
-    SQLAlchemy ORM: Total time for 100000 records 14.3528850079 secs
-    SQLAlchemy ORM pk given: Total time for 100000 records 10.0164160728 secs
-    SQLAlchemy Core: Total time for 100000 records 0.775382995605 secs
-    sqlite3: Total time for 100000 records 0.676795005798 sec
+    SQLAlchemy ORM: Total time for 100000 records 12.4703581333 secs
+    SQLAlchemy ORM pk given: Total time for 100000 records 7.32723999023 secs
+    SQLAlchemy ORM bulk_save_objects(): Total time for 100000 records 3.43464708328 secs
+    SQLAlchemy ORM bulk_save_mappings(): Total time for 100000 records 2.37040805817 secs
+    SQLAlchemy Core: Total time for 100000 records 0.495043992996 secs
+    sqlite3: Total time for 100000 records 0.508063077927 sec
 
 We can reduce the time by a factor of three using recent versions of `PyPy <http://pypy.org/>`_::
 
@@ -933,11 +935,13 @@ Script::
     DBSession = scoped_session(sessionmaker())
     engine = None
 
+
     class Customer(Base):
         __tablename__ = "customer"
         id = Column(Integer, primary_key=True)
         name = Column(String(255))
 
+
     def init_sqlalchemy(dbname='sqlite:///sqlalchemy.db'):
         global engine
         engine = create_engine(dbname, echo=False)
@@ -946,69 +950,114 @@ Script::
         Base.metadata.drop_all(engine)
         Base.metadata.create_all(engine)
 
+
     def test_sqlalchemy_orm(n=100000):
         init_sqlalchemy()
         t0 = time.time()
-        for i in range(n):
+        for i in xrange(n):
             customer = Customer()
             customer.name = 'NAME ' + str(i)
             DBSession.add(customer)
             if i % 1000 == 0:
                 DBSession.flush()
         DBSession.commit()
-        print("SQLAlchemy ORM: Total time for " + str(n) +
-                    " records " + str(time.time() - t0) + " secs")
+        print(
+            "SQLAlchemy ORM: Total time for " + str(n) +
+            " records " + str(time.time() - t0) + " secs")
+
 
     def test_sqlalchemy_orm_pk_given(n=100000):
         init_sqlalchemy()
         t0 = time.time()
-        for i in range(n):
+        for i in xrange(n):
             customer = Customer(id=i+1, name="NAME " + str(i))
             DBSession.add(customer)
             if i % 1000 == 0:
                 DBSession.flush()
         DBSession.commit()
-        print("SQLAlchemy ORM pk given: Total time for " + str(n) +
+        print(
+            "SQLAlchemy ORM pk given: Total time for " + str(n) +
+            " records " + str(time.time() - t0) + " secs")
+
+
+    def test_sqlalchemy_orm_bulk_save(n=100000):
+        init_sqlalchemy()
+        t0 = time.time()
+        n1 = n
+        while n1 > 0:
+            DBSession.bulk_save_objects(
+                [
+                    Customer(name="NAME " + str(i))
+                    for i in xrange(min(10000, n1))
+                ]
+            )
+            n1 = n1 - 10000
+        DBSession.commit()
+        print(
+            "SQLAlchemy ORM bulk_save_objects(): Total time for " + str(n) +
             " records " + str(time.time() - t0) + " secs")
 
+
+    def test_sqlalchemy_orm_bulk_save_mappings(n=100000):
+        init_sqlalchemy()
+        t0 = time.time()
+        DBSession.bulk_save_mappings(
+            Customer,
+            [
+                dict(name="NAME " + str(i))
+                for i in xrange(n)
+            ]
+        )
+        DBSession.commit()
+        print(
+            "SQLAlchemy ORM bulk_save_mappings(): Total time for " + str(n) +
+            " records " + str(time.time() - t0) + " secs")
+
+
     def test_sqlalchemy_core(n=100000):
         init_sqlalchemy()
         t0 = time.time()
         engine.execute(
             Customer.__table__.insert(),
-            [{"name": 'NAME ' + str(i)} for i in range(n)]
+            [{"name": 'NAME ' + str(i)} for i in xrange(n)]
         )
-        print("SQLAlchemy Core: Total time for " + str(n) +
+        print(
+            "SQLAlchemy Core: Total time for " + str(n) +
             " records " + str(time.time() - t0) + " secs")
 
+
     def init_sqlite3(dbname):
         conn = sqlite3.connect(dbname)
         c = conn.cursor()
         c.execute("DROP TABLE IF EXISTS customer")
-        c.execute("CREATE TABLE customer (id INTEGER NOT NULL, "
-                                    "name VARCHAR(255), PRIMARY KEY(id))")
+        c.execute(
+            "CREATE TABLE customer (id INTEGER NOT NULL, "
+            "name VARCHAR(255), PRIMARY KEY(id))")
         conn.commit()
         return conn
 
+
     def test_sqlite3(n=100000, dbname='sqlite3.db'):
         conn = init_sqlite3(dbname)
         c = conn.cursor()
         t0 = time.time()
-        for i in range(n):
+        for i in xrange(n):
             row = ('NAME ' + str(i),)
             c.execute("INSERT INTO customer (name) VALUES (?)", row)
         conn.commit()
-        print("sqlite3: Total time for " + str(n) +
+        print(
+            "sqlite3: Total time for " + str(n) +
             " records " + str(time.time() - t0) + " sec")
 
     if __name__ == '__main__':
         test_sqlalchemy_orm(100000)
         test_sqlalchemy_orm_pk_given(100000)
+        test_sqlalchemy_orm_bulk_save(100000)
+        test_sqlalchemy_orm_bulk_save_mappings(100000)
         test_sqlalchemy_core(100000)
         test_sqlite3(100000)
 
 
-
 Sessions / Queries
 ===================
 
index 511a324bee2e7540e25725877ba5ebbd4f12cc61..64c8440c4d6b60207e968efa997f22572fa47eef 100644 (file)
@@ -18,7 +18,7 @@ import operator
 from itertools import groupby
 from .. import sql, util, exc as sa_exc, schema
 from . import attributes, sync, exc as orm_exc, evaluator
-from .base import _state_mapper, state_str, _attr_as_key
+from .base import state_str, _attr_as_key
 from ..sql import expression
 from . import loading
 
@@ -65,7 +65,8 @@ def save_obj(
         if insert:
             _emit_insert_statements(base_mapper, uowtransaction,
                                     cached_connections,
-                                    mapper, table, insert)
+                                    mapper, table, insert,
+                                    bookkeeping)
 
     _finalize_insert_update_commands(base_mapper, uowtransaction,
                                      states_to_insert, states_to_update,
@@ -140,13 +141,16 @@ def _organize_states_for_save(
 
     states_to_insert = []
     states_to_update = []
+    instance_key = None
 
     for state, dict_, mapper, connection in _connections_for_states(
             base_mapper, uowtransaction,
             states):
 
         has_identity = bool(state.key)
-        instance_key = state.key or mapper._identity_key_from_state(state)
+
+        if bookkeeping:
+            instance_key = state.key or mapper._identity_key_from_state(state)
 
         row_switch = None
 
@@ -188,12 +192,12 @@ def _organize_states_for_save(
         if not has_identity and not row_switch:
             states_to_insert.append(
                 (state, dict_, mapper, connection,
-                 has_identity, instance_key, row_switch)
+                 has_identity, row_switch)
             )
         else:
             states_to_update.append(
                 (state, dict_, mapper, connection,
-                 has_identity, instance_key, row_switch)
+                 has_identity, row_switch)
             )
 
     return states_to_insert, states_to_update
@@ -242,7 +246,8 @@ def _collect_insert_commands(base_mapper, uowtransaction, table,
     """
     insert = []
     for state, state_dict, mapper, connection, has_identity, \
-            instance_key, row_switch in states_to_insert:
+            row_switch in states_to_insert:
+
         if table not in mapper._pks_by_table:
             continue
 
@@ -265,13 +270,13 @@ def _collect_insert_commands(base_mapper, uowtransaction, table,
                 prop = mapper._columntoproperty[col]
                 value = state_dict.get(prop.key, None)
 
-                if value is None:
-                    if bookkeeping and col in pks:
+                if bookkeeping and value is None:
+                    if col in pks:
                         has_all_pks = False
                     elif col.default is None and \
                             col.server_default is None:
                         params[col.key] = value
-                    elif bookkeeping and col.server_default is not None and \
+                    elif col.server_default is not None and \
                             mapper.base_mapper.eager_defaults:
                         has_all_defaults = False
 
@@ -301,7 +306,7 @@ def _collect_update_commands(base_mapper, uowtransaction,
 
     update = []
     for state, state_dict, mapper, connection, has_identity, \
-            instance_key, row_switch in states_to_update:
+            row_switch in states_to_update:
         if table not in mapper._pks_by_table:
             continue
 
@@ -567,7 +572,8 @@ def _emit_update_statements(base_mapper, uowtransaction,
 
 
 def _emit_insert_statements(base_mapper, uowtransaction,
-                            cached_connections, mapper, table, insert):
+                            cached_connections, mapper, table, insert,
+                            bookkeeping):
     """Emit INSERT statements corresponding to value lists collected
     by _collect_insert_commands()."""
 
@@ -593,19 +599,20 @@ def _emit_insert_statements(base_mapper, uowtransaction,
             c = cached_connections[connection].\
                 execute(statement, multiparams)
 
-            for (state, state_dict, params, mapper_rec,
-                    conn, value_params, has_all_pks, has_all_defaults), \
-                    last_inserted_params in \
-                    zip(records, c.context.compiled_parameters):
-                _postfetch(
-                    mapper_rec,
-                    uowtransaction,
-                    table,
-                    state,
-                    state_dict,
-                    c,
-                    last_inserted_params,
-                    value_params)
+            if bookkeeping:
+                for (state, state_dict, params, mapper_rec,
+                        conn, value_params, has_all_pks, has_all_defaults), \
+                        last_inserted_params in \
+                        zip(records, c.context.compiled_parameters):
+                    _postfetch(
+                        mapper_rec,
+                        uowtransaction,
+                        table,
+                        state,
+                        state_dict,
+                        c,
+                        last_inserted_params,
+                        value_params)
 
         else:
             if not has_all_defaults and base_mapper.eager_defaults:
@@ -768,7 +775,7 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction,
 
     """
     for state, state_dict, mapper, connection, has_identity, \
-            instance_key, row_switch in states_to_insert + \
+            row_switch in states_to_insert + \
             states_to_update:
 
         if bookkeeping:
@@ -871,7 +878,7 @@ def _connections_for_states(base_mapper, uowtransaction, states):
         if connection_callable:
             connection = connection_callable(base_mapper, state.obj())
 
-        mapper = _state_mapper(state)
+        mapper = state.manager.mapper
 
         yield state, state.dict, mapper, connection
 
index 2455c803a7ec5785563be43c3550b1f826cc1dbb..546355611a581bca526ae853f68f5cb9d99e2d78 100644 (file)
@@ -482,7 +482,7 @@ class Session(_SessionClassMethods):
         '__contains__', '__iter__', 'add', 'add_all', 'begin', 'begin_nested',
         'close', 'commit', 'connection', 'delete', 'execute', 'expire',
         'expire_all', 'expunge', 'expunge_all', 'flush', 'get_bind',
-        'is_modified',
+        'is_modified', 'bulk_save_objects', 'bulk_save_mappings',
         'merge', 'query', 'refresh', 'rollback',
         'scalar')
 
@@ -2033,31 +2033,42 @@ class Session(_SessionClassMethods):
             with util.safe_reraise():
                 transaction.rollback(_capture_exception=True)
 
-    def bulk_save(self, objects):
+    def bulk_save_objects(self, objects):
+        self._bulk_save((attributes.instance_state(obj) for obj in objects))
+
+    def bulk_save_mappings(self, mapper, mappings):
+        mapper = class_mapper(mapper)
+
+        self._bulk_save((
+            statelib.MappingState(mapper, mapping)
+            for mapping in mappings)
+        )
+
+    def _bulk_save(self, states):
         self._flushing = True
         flush_context = UOWTransaction(self)
 
         if self.dispatch.before_bulk_save:
             self.dispatch.before_bulk_save(
-                self, flush_context, objects)
+                self, flush_context, states)
 
         flush_context.transaction = transaction = self.begin(
             subtransactions=True)
         try:
             self._warn_on_events = True
             try:
-                flush_context.bulk_save(objects)
+                flush_context.bulk_save(states)
             finally:
                 self._warn_on_events = False
 
             self.dispatch.after_bulk_save(
-                self, flush_context, objects
+                self, flush_context, states
             )
 
             flush_context.finalize_flush_changes()
 
             self.dispatch.after_bulk_save_postexec(
-                self, flush_context, objects)
+                self, flush_context, states)
 
             transaction.commit()
 
index fe8ccd222b4eec4cc9b5a5d8698ccb8e78e9e4c9..e941bc1a4713ef074979e37dce43212cdd3be6a5 100644 (file)
@@ -580,6 +580,21 @@ class InstanceState(interfaces.InspectionAttr):
             state._strong_obj = None
 
 
+class MappingState(InstanceState):
+    committed_state = {}
+    callables = {}
+
+    def __init__(self, mapper, mapping):
+        self.class_ = mapper.class_
+        self.manager = mapper.class_manager
+        self.modified = True
+        self._dict = mapping
+
+    @property
+    def dict(self):
+        return self._dict
+
+
 class AttributeState(object):
     """Provide an inspection interface corresponding
     to a particular attribute on a particular mapped object.
index 8df24e95ac3540016e162b44c89d947bd70c287d..bc8a0f5565048bb195bcf601aa10f07d93ca6dce 100644 (file)
@@ -394,9 +394,9 @@ class UOWTransaction(object):
         if other:
             self.session._register_newly_persistent(other)
 
-    def bulk_save(self, objects):
-        for (base_mapper, in_session), states in itertools.groupby(
-                (attributes.instance_state(obj) for obj in objects),
+    def bulk_save(self, states):
+        for (base_mapper, in_session), states_ in itertools.groupby(
+                states,
                 lambda state:
                 (
                     state.mapper.base_mapper,
@@ -404,12 +404,12 @@ class UOWTransaction(object):
                 )):
 
             persistence.save_obj(
-                base_mapper, list(states), self, bookkeeping=in_session)
+                base_mapper, list(states_), self, bookkeeping=in_session)
 
             if in_session:
                 self.states.update(
                     (state, (False, False))
-                    for state in states
+                    for state in states_
                 )