]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
Use io instead of codecs and StringIO
authorHong Minhee <minhee@dahlia.kr>
Fri, 12 Apr 2013 19:40:54 +0000 (04:40 +0900)
committerHong Minhee <minhee@dahlia.kr>
Fri, 12 Apr 2013 19:40:54 +0000 (04:40 +0900)
alembic/ddl/impl.py
alembic/migration.py
tests/__init__.py
tests/test_sql_script.py

index 1c44534f6a03bb35d64797099fc1b6a13192d60a..af93349c4863c3ad9a04cecb4ff5f94b856f04f1 100644 (file)
@@ -54,7 +54,10 @@ class DefaultImpl(ImplMeta('_ImplBase', (object,), {})):
         return _impls[dialect.name]
 
     def static_output(self, text):
-        self.output_buffer.write(text + "\n\n")
+        text_ = getattr(builtins, 'unicode', str)(text + '\n\n')
+        self.output_buffer.write(text_)
+        if callable(getattr(self.output_buffer, 'flush', None)):
+            self.output_buffer.flush()
 
     @property
     def bind(self):
index 82fee9cad4b2d2d940d6c864f22d115f330c15a3..bf8e9380826e25da2db08fdd615958e40eeaf1f6 100644 (file)
@@ -1,4 +1,4 @@
-import codecs
+import io
 import logging
 import sys
 
@@ -71,10 +71,12 @@ class MigrationContext(object):
         self._migrations_fn = opts.get('fn')
         self.as_sql = as_sql
         self.output_buffer = opts.get("output_buffer", sys.stdout)
-        if opts.get('output_encoding'):
-            self.output_buffer = codecs.getwriter(
+        if (opts.get('output_encoding') and
+            not isinstance(self.output_buffer, io.TextIOBase)):
+            self.output_buffer = io.TextIOWrapper(
+                                    self.output_buffer,
                                     opts['output_encoding']
-                                )(self.output_buffer)
+                                )
 
         self._user_compare_type = opts.get('compare_type', False)
         self._user_compare_server_default = opts.get(
index b0edf75ff949534a55786d674bc54e8f5e6f205a..37f24d33538065de6d8cc0a4f395766f86b08d5f 100644 (file)
@@ -12,7 +12,6 @@ import io
 import os
 import re
 import shutil
-import StringIO
 import textwrap
 
 from nose import SkipTest
@@ -105,7 +104,12 @@ def assert_compiled(element, assert_string, dialect=None):
     )
 
 def capture_context_buffer(**kw):
-    buf = StringIO.StringIO()
+    if kw.pop('bytes_io', False):
+        raw = io.BytesIO()
+        encoding = kw.get('output_encoding', 'utf-8')
+        buf = io.TextIOWrapper(raw, encoding)
+    else:
+        raw = buf = io.StringIO()
 
     class capture(object):
         def __enter__(self):
@@ -114,7 +118,7 @@ def capture_context_buffer(**kw):
                 'output_buffer':buf
             }
             EnvironmentContext._default_opts.update(kw)
-            return buf
+            return raw
 
         def __exit__(self, *arg, **kwarg):
             #print(buf.getvalue())
index 6cea44ce682c5d7c0b716305c4c3e0cc4aaf4c5d..d351151376b441a862e3611770adbe6c639e72f7 100644 (file)
@@ -111,6 +111,9 @@ def downgrade():
         clear_staging_env()
 
     def test_encode(self):
-        with capture_context_buffer(output_encoding='utf-8') as buf:
+        with capture_context_buffer(
+                    bytes_io=True,
+                    output_encoding='utf-8'
+                ) as buf:
             command.upgrade(cfg, a, sql=True)
         assert "« S’il vous plaît…".encode("utf-8") in buf.getvalue()