]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
add a bulk insert feature. probably needs some work
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 22 Apr 2011 16:11:20 +0000 (12:11 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 22 Apr 2011 16:11:20 +0000 (12:11 -0400)
alembic/context.py
alembic/ddl/__init__.py
alembic/ddl/mssql.py [new file with mode: 0644]
alembic/op.py
tests/__init__.py
tests/test_bulk_insert.py [new file with mode: 0644]
tests/test_op.py

index f4bb9cb617d70e938c704ed27735cefb665793f0..9bac6063c16c71735833ea06d254cc3f5dfff3f1 100644 (file)
@@ -3,6 +3,8 @@ from sqlalchemy import MetaData, Table, Column, String, literal_column, \
     text
 from sqlalchemy import schema, create_engine
 from sqlalchemy.util import importlater
+from sqlalchemy.ext.compiler import compiles
+from sqlalchemy.sql.expression import _BindParamClause
 
 import logging
 base = importlater("alembic.ddl", "base")
@@ -82,15 +84,21 @@ class DefaultContext(object):
         if self.as_sql and self.transactional_ddl:
             print "COMMIT;\n"
 
-    def _exec(self, construct):
+    def _exec(self, construct, *args, **kw):
         if isinstance(construct, basestring):
             construct = text(construct)
         if self.as_sql:
+            if args or kw:
+                raise Exception("Execution arguments not allowed with as_sql")
             print unicode(
-                    construct.compile(dialect=self.connection.dialect)
+                    construct.compile(dialect=self.dialect)
                     ).replace("\t", "    ") + ";"
         else:
-            self.connection.execute(construct)
+            self.connection.execute(construct, *args, **kw)
+
+    @property
+    def dialect(self):
+        return self.connection.dialect
 
     def execute(self, sql):
         self._exec(sql)
@@ -156,6 +164,23 @@ class DefaultContext(object):
     def drop_table(self, table):
         self._exec(schema.DropTable(table))
 
+    def bulk_insert(self, table, rows):
+        if self.as_sql:
+            for row in rows:
+                self._exec(table.insert().values(**dict(
+                    (k, _literal_bindparam(k, v, type_=table.c[k].type))
+                    for k, v in row.items()
+                )))
+        else:
+            self._exec(table.insert(), *rows)
+
+class _literal_bindparam(_BindParamClause):
+    pass
+
+@compiles(_literal_bindparam)
+def _render_literal_bindparam(element, compiler, **kw):
+    return compiler.render_literal_bindparam(element, **kw)
+
 def opts(cfg, **kw):
     global _context_opts, config
     _context_opts = kw
index 07b063dd21d20bb94afae809ad141bcdc5832763..7efc90cb5259702f32439dc5dca53fb66fade080 100644 (file)
@@ -1 +1 @@
-import postgresql, mysql, sqlite
\ No newline at end of file
+import postgresql, mysql, sqlite, mssql
\ No newline at end of file
diff --git a/alembic/ddl/mssql.py b/alembic/ddl/mssql.py
new file mode 100644 (file)
index 0000000..d79e619
--- /dev/null
@@ -0,0 +1,19 @@
+from alembic.context import DefaultContext
+
+class MSSQLContext(DefaultContext):
+    __dialect__ = 'mssql'
+    transactional_ddl = True
+
+    def bulk_insert(self, table, rows):
+        if self.as_sql:
+            self._exec(
+                "SET IDENTITY_INSERT %s ON" % 
+                    self.dialect.identifier_preparer.format_table(table)
+            )
+            super(MSSQLContext, self).bulk_insert(table, rows)
+            self._exec(
+                "SET IDENTITY_INSERT %s OFF" % 
+                    self.dialect.identifier_preparer.format_table(table)
+            )
+        else:
+            super(MSSQLContext, self).bulk_insert(table, rows)
\ No newline at end of file
index a12aa6449adfc0fd2e338734f0d79253f37f65c2..aaefb09aece848070269edd8b8a8d6af44900d21 100644 (file)
@@ -137,6 +137,9 @@ def drop_table(name, *columns, **kw):
         _table(name, *columns, **kw)
     )
 
+def bulk_insert(table, rows):
+    get_context().bulk_insert(table, rows)
+
 def execute(sql):
     get_context().execute(sql)
 
index 3cd79ddf47a990de022ed6dd2ca77aea013cddbb..17788f66122b5f2a3b26ccb13757c97c7d8e8228 100644 (file)
@@ -1,21 +1,26 @@
-from sqlalchemy.util import defaultdict
 from sqlalchemy.engine import url, default
 import shutil
 import os
 import itertools
-from sqlalchemy import create_engine
+from sqlalchemy import create_engine, text
 from alembic import context
 import re
+from alembic.context import _context_impls
+from alembic import ddl
 
 staging_directory = os.path.join(os.path.dirname(__file__), 'scratch')
 
-_dialects = defaultdict(lambda name:url.URL(drivername).get_dialect()())
+_dialects = {}
 def _get_dialect(name):
-    if name is None:
+    if name is None or name == 'default':
         return default.DefaultDialect()
     else:
-        return _dialects[name]
-
+        try:
+            return _dialects[name]
+        except KeyError:
+            dialect_mod = getattr(__import__('sqlalchemy.dialects.%s' % name).dialects, name)
+            _dialects[name] = d = dialect_mod.dialect()
+            return d
 
 def assert_compiled(element, assert_string, dialect=None):
     dialect = _get_dialect(dialect)
@@ -39,23 +44,34 @@ def _testing_config():
         os.mkdir(staging_directory)
     return Config(os.path.join(staging_directory, 'test_alembic.ini'))
 
-class _op_fixture(context.DefaultContext):
-    def __init__(self):
-        # TODO: accept dialect here.
-        context._context = self
-        self.assertion = []
-
-    def _exec(self, construct):
-        sql = unicode(construct.compile())
-        sql = re.sub(r'[\n\t]', '', sql)
-        self.assertion.append(
-            sql
-        )
-
-    def assert_(self, *sql):
-        # TODO: make this more flexible about 
-        # whitespace and such
-        eq_(self.assertion, list(sql))
+def _op_fixture(dialect='default', as_sql=False):
+    _base = _context_impls[dialect]
+    class ctx(_base):
+        def __init__(self, dialect='default', as_sql=False):
+            self._dialect = _get_dialect(dialect)
+
+            context._context = self
+            self.as_sql = as_sql
+            self.assertion = []
+
+        @property
+        def dialect(self):
+            return self._dialect
+
+        def _exec(self, construct, *args, **kw):
+            if isinstance(construct, basestring):
+                construct = text(construct)
+            sql = unicode(construct.compile(dialect=self.dialect))
+            sql = re.sub(r'[\n\t]', '', sql)
+            self.assertion.append(
+                sql
+            )
+
+        def assert_(self, *sql):
+            # TODO: make this more flexible about 
+            # whitespace and such
+            eq_(self.assertion, list(sql))
+    return ctx(dialect, as_sql)
 
 def _sqlite_testing_config():
     cfg = _testing_config()
diff --git a/tests/test_bulk_insert.py b/tests/test_bulk_insert.py
new file mode 100644 (file)
index 0000000..be13602
--- /dev/null
@@ -0,0 +1,88 @@
+from tests import _op_fixture
+from alembic import op
+from sqlalchemy import Integer, Column, ForeignKey, \
+            UniqueConstraint, Table, MetaData, String
+from sqlalchemy.sql import table
+
+def _test_bulk_insert(dialect, as_sql):
+    context = _op_fixture(dialect, as_sql)
+    t1 = table("ins_table",
+                Column('id', Integer, primary_key=True),
+                Column('v1', String()),
+                Column('v2', String()),
+    )
+    op.bulk_insert(t1, [
+        {'id':1, 'v1':'row v1', 'v2':'row v5'},
+        {'id':2, 'v1':'row v2', 'v2':'row v6'},
+        {'id':3, 'v1':'row v3', 'v2':'row v7'},
+        {'id':4, 'v1':'row v4', 'v2':'row v8'},
+    ])
+    return context
+
+def test_bulk_insert():
+    context = _test_bulk_insert('default', False)
+    context.assert_(
+        'INSERT INTO ins_table (id, v1, v2) VALUES (:id, :v1, :v2)'
+    )
+
+def test_bulk_insert_wrong_cols():
+    context = _op_fixture('postgresql')
+    t1 = Table("ins_table", MetaData(),
+                Column('id', Integer, primary_key=True),
+                Column('v1', String()),
+                Column('v2', String()),
+    )
+    op.bulk_insert(t1, [
+        {'v1':'row v1', },
+    ])
+    # TODO: this is wrong because the test fixture isn't actually 
+    # doing what the real context would do.   Sending this to 
+    # PG is going to produce a RETURNING clause.  fixture would
+    # need to be beefed up
+    context.assert_(
+        'INSERT INTO ins_table (id, v1, v2) VALUES (%(id)s, %(v1)s, %(v2)s)'
+    )
+
+def test_bulk_insert_pg():
+    context = _test_bulk_insert('postgresql', False)
+    context.assert_(
+        'INSERT INTO ins_table (id, v1, v2) VALUES (%(id)s, %(v1)s, %(v2)s)'
+    )
+
+def test_bulk_insert_mssql():
+    context = _test_bulk_insert('mssql', False)
+    context.assert_(
+        'INSERT INTO ins_table (id, v1, v2) VALUES (:id, :v1, :v2)'
+    )
+
+def test_bulk_insert_as_sql():
+    context = _test_bulk_insert('default', True)
+    context.assert_(
+        "INSERT INTO ins_table (id, v1, v2) VALUES (1, 'row v1', 'row v5')", 
+        "INSERT INTO ins_table (id, v1, v2) VALUES (2, 'row v2', 'row v6')", 
+        "INSERT INTO ins_table (id, v1, v2) VALUES (3, 'row v3', 'row v7')",
+        "INSERT INTO ins_table (id, v1, v2) VALUES (4, 'row v4', 'row v8')"
+    )
+
+def test_bulk_insert_as_sql_pg():
+    context = _test_bulk_insert('postgresql', True)
+    context.assert_(
+        "INSERT INTO ins_table (id, v1, v2) VALUES (1, 'row v1', 'row v5')", 
+        "INSERT INTO ins_table (id, v1, v2) VALUES (2, 'row v2', 'row v6')", 
+        "INSERT INTO ins_table (id, v1, v2) VALUES (3, 'row v3', 'row v7')",
+        "INSERT INTO ins_table (id, v1, v2) VALUES (4, 'row v4', 'row v8')"
+    )
+
+def test_bulk_insert_as_sql_mssql():
+    context = _test_bulk_insert('mssql', True)
+    # SQL server requires IDENTITY_INSERT
+    # TODO: figure out if this is safe to enable for a table that 
+    # doesn't have an IDENTITY column
+    context.assert_(
+        'SET IDENTITY_INSERT ins_table ON', 
+        "INSERT INTO ins_table (id, v1, v2) VALUES (1, 'row v1', 'row v5')", 
+        "INSERT INTO ins_table (id, v1, v2) VALUES (2, 'row v2', 'row v6')", 
+        "INSERT INTO ins_table (id, v1, v2) VALUES (3, 'row v3', 'row v7')", 
+        "INSERT INTO ins_table (id, v1, v2) VALUES (4, 'row v4', 'row v8')", 
+        'SET IDENTITY_INSERT ins_table OFF'
+    )
index 2ffd8ee6d758ac6b27debb9d1367d8bac50f4036..987b25b331be7af531e4d4fa9bb7a1b8f20d596f 100644 (file)
@@ -3,7 +3,8 @@
 from tests import _op_fixture
 from alembic import op
 from sqlalchemy import Integer, Column, ForeignKey, \
-            UniqueConstraint, Table, MetaData
+            UniqueConstraint, Table, MetaData, String
+from sqlalchemy.sql import table
 
 def test_add_column():
     context = _op_fixture()
@@ -116,3 +117,4 @@ def test_create_table_two_fk():
             "FOREIGN KEY(foo_bar) REFERENCES foo (bar))"
     )
 
+