git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
- Ensured that strings going to stdout go through an encode/decode phase,
author Mike Bayer <mike_mp@zzzcomputing.com>
Fri, 23 Aug 2013 17:25:31 +0000 (13:25 -0400)
committer Mike Bayer <mike_mp@zzzcomputing.com>
Fri, 23 Aug 2013 17:25:31 +0000 (13:25 -0400)
so that any non-ASCII characters get to the output stream correctly
in both Py2k and Py3k.   Also added source encoding detection using
Mako's parse_encoding() routine in Py2k so that the __doc__ of a
non-ascii revision file can be treated as unicode in Py2k.

alembic/__init__.py
alembic/compat.py
alembic/config.py
alembic/script.py
alembic/util.py
docs/build/changelog.rst
tests/__init__.py
tests/test_command.py
tests/test_config.py

index 7467b27332ce62e822dcaca0e6c74ae9f1f97310..b50a0d56d65f0e6824903c1ab5d0a9fb06216001 100644 (file)
@@ -1,6 +1,6 @@
 from os import path
 
-__version__ = '0.6.0'
+__version__ = '0.6.1'
 
 package_dir = path.abspath(path.dirname(__file__))
 
index ffc2d9900167efcedb440f1f9c1a08d3702d471d..11dde867c18aab55f88157cf0250e65f9f2c2a9b 100644 (file)
@@ -3,6 +3,7 @@ import sys
 if sys.version_info < (2, 6):
     raise NotImplementedError("Python 2.6 or greater is required.")
 
+py2k = sys.version_info < (3, 0)
 py3k = sys.version_info >= (3, 0)
 py33 = sys.version_info >= (3, 3)
 
@@ -13,6 +14,10 @@ if py3k:
     text_type = str
     def callable(fn):
         return hasattr(fn, '__call__')
+
+    def u(s):
+        return s
+
 else:
     import __builtin__ as compat_builtins
     string_types = basestring,
@@ -20,6 +25,9 @@ else:
     text_type = unicode
     callable = callable
 
+    def u(s):
+        return unicode(s, "utf-8")
+
 if py3k:
     from configparser import ConfigParser as SafeConfigParser
     import configparser
@@ -27,6 +35,9 @@ else:
     from ConfigParser import SafeConfigParser
     import ConfigParser as configparser
 
+if py2k:
+    from mako.util import parse_encoding
+
 if py33:
     from importlib import machinery
     def load_module(module_id, path):
@@ -36,11 +47,17 @@ else:
     def load_module(module_id, path):
         fp = open(path, 'rb')
         try:
-            return imp.load_source(module_id, path, fp)
+            mod = imp.load_source(module_id, path, fp)
+            if py2k:
+                source_encoding = parse_encoding(fp)
+                if source_encoding:
+                    mod._alembic_source_encoding = source_encoding
+            return mod
         finally:
             fp.close()
 
 
+
 try:
     exec_ = getattr(compat_builtins, 'exec')
 except AttributeError:
index da5eee2dc906b040c57b0b3b65bbd90e673fecb5..86ff1df77aeb5a769df4189499dd1574302b566d 100644 (file)
@@ -89,7 +89,11 @@ class Config(object):
     def print_stdout(self, text, *arg):
         """Render a message to standard out."""
 
-        self.stdout.write((compat.text_type(text) % arg) + "\n")
+        util.write_outstream(
+                self.stdout,
+                (compat.text_type(text) % arg),
+                "\n"
+        )
 
     @util.memoized_property
     def file_config(self):
index f0903266f0534edb767e8373e041a50027f83232..77366e30650f47043436c82397d1d0edec131811 100644 (file)
@@ -385,6 +385,8 @@ class Script(object):
 
         doc = self.module.__doc__
         if doc:
+            if hasattr(self.module, "_alembic_source_encoding"):
+                doc = doc.decode(self.module._alembic_source_encoding)
             return doc.strip()
         else:
             return ""
index 3c227b74e406eb5eb9b619f52e06e2591d554e42..bd2f03586d9f73f677eeef38f4dabfe51c4c94a1 100644 (file)
@@ -10,7 +10,7 @@ from mako.template import Template
 from sqlalchemy.engine import url
 from sqlalchemy import __version__
 
-from .compat import callable, exec_, load_module
+from .compat import callable, exec_, load_module, binary_type
 
 class CommandError(Exception):
     pass
@@ -123,6 +123,14 @@ def create_module_class_proxy(cls, globals_, locals_):
             else:
                 attr_names.add(methname)
 
+def write_outstream(stream, *text):
+    encoding = getattr(stream, 'encoding', 'ascii') or 'ascii'
+    for t in text:
+        if not isinstance(t, binary_type):
+            t = t.encode(encoding, errors='replace')
+        t = t.decode(encoding)
+        stream.write(t)
+
 def coerce_resource_to_filename(fname):
     """Interpret a filename as either a filesystem location or as a package resource.
 
@@ -139,10 +147,10 @@ def status(_statmsg, fn, *arg, **kw):
     msg(_statmsg + "...", False)
     try:
         ret = fn(*arg, **kw)
-        sys.stdout.write("done\n")
+        write_outstream(sys.stdout, "done\n")
         return ret
     except:
-        sys.stdout.write("FAILED\n")
+        write_outstream(sys.stdout, "FAILED\n")
         raise
 
 def err(message):
@@ -166,8 +174,8 @@ def msg(msg, newline=True):
     lines = textwrap.wrap(msg, width)
     if len(lines) > 1:
         for line in lines[0:-1]:
-            sys.stdout.write("  " + line + "\n")
-    sys.stdout.write("  " + lines[-1] + ("\n" if newline else ""))
+            write_outstream(sys.stdout, "  ", line, "\n")
+    write_outstream(sys.stdout, "  ", lines[-1], ("\n" if newline else ""))
 
 def load_python_file(dir_, filename):
     """Load a file from the given path as a Python module."""
index 5313567fde039a1928a3062ee2bc00d415f4f1d7..02b2b607848ec4e9058354a1417ca9041cb55f8a 100644 (file)
@@ -4,6 +4,20 @@ Changelog
 ==========
 
 .. changelog::
+    :version: 0.6.1
+    :released: no release date
+
+    .. change::
+      :tags: bug
+      :tickets: 137
+
+      Ensured that strings going to stdout go through an encode/decode phase,
+      so that any non-ASCII characters get to the output stream correctly
+      in both Py2k and Py3k.   Also added source encoding detection using
+      Mako's parse_encoding() routine in Py2k so that the __doc__ of a
+      non-ascii revision file can be treated as unicode in Py2k.
+
+.. changelog::
     :version: 0.6.0
     :released: Fri July 19 2013
 
index a07c0ae542869e5dbfd64014bf854e21cc24d5a0..2fb34da3cde5edcc202e3c9e526c11449b85d27b 100644 (file)
@@ -1,4 +1,4 @@
-
+# coding: utf-8
 import io
 import os
 import re
@@ -14,7 +14,7 @@ from sqlalchemy.util import decorator
 
 import alembic
 from alembic import util
-from alembic.compat import string_types, text_type
+from alembic.compat import string_types, text_type, u
 from alembic.migration import MigrationContext
 from alembic.environment import EnvironmentContext
 from alembic.operations import Operations
@@ -350,8 +350,8 @@ def downgrade():
 """ % a)
 
     script.generate_revision(b, "revision b", refresh=True)
-    write_script(script, b, """\
-"Rev B"
+    write_script(script, b, u("""# coding: utf-8
+"Rev B, méil"
 revision = '%s'
 down_revision = '%s'
 
@@ -363,7 +363,7 @@ def upgrade():
 def downgrade():
     op.execute("DROP STEP 2")
 
-""" % (b, a))
+""") % (b, a), encoding="utf-8")
 
     script.generate_revision(c, "revision c", refresh=True)
     write_script(script, c, """\
index 34f57856e19425b121173e85f57a073cb3635c6e..1179b6e087dd661dcc9eaad6f1bf25479508408f 100644 (file)
@@ -3,7 +3,7 @@ from . import clear_staging_env, staging_env, \
     _sqlite_testing_config, \
     three_rev_fixture, eq_
 from alembic import command
-from io import StringIO
+from io import TextIOWrapper, BytesIO
 from alembic.script import ScriptDirectory
 
 
@@ -23,50 +23,62 @@ class StdoutCommandTest(unittest.TestCase):
     def _eq_cmd_output(self, buf, expected):
         script = ScriptDirectory.from_config(self.cfg)
 
-        revs = {"reva": self.a, "revb": self.b, "revc": self.c}
+        # test default encode/decode behavior as well,
+        # rev B has a non-ascii char in it + a coding header.
         eq_(
-            buf.getvalue().strip(),
-            "\n".join([script.get_revision(rev).log_entry for rev in expected]).strip()
+            buf.getvalue().decode("ascii", errors='replace').strip(),
+            "\n".join([
+                script.get_revision(rev).log_entry
+                for rev in expected
+            ]).encode("ascii", errors="replace").decode("ascii").strip()
         )
 
+    def _buf_fixture(self):
+        # try to simulate how sys.stdout looks - we send it u''
+        # but then it's trying to encode to something.
+        buf = BytesIO()
+        wrapper = TextIOWrapper(buf, encoding='ascii', line_buffering=True)
+        wrapper.getvalue = buf.getvalue
+        return wrapper
+
     def test_history_full(self):
-        self.cfg.stdout = buf = StringIO()
+        self.cfg.stdout = buf = self._buf_fixture()
         command.history(self.cfg)
         self._eq_cmd_output(buf, [self.c, self.b, self.a])
 
     def test_history_num_range(self):
-        self.cfg.stdout = buf = StringIO()
+        self.cfg.stdout = buf = self._buf_fixture()
         command.history(self.cfg, "%s:%s" % (self.a, self.b))
         self._eq_cmd_output(buf, [self.b])
 
     def test_history_base_to_num(self):
-        self.cfg.stdout = buf = StringIO()
+        self.cfg.stdout = buf = self._buf_fixture()
         command.history(self.cfg, ":%s" % (self.b))
         self._eq_cmd_output(buf, [self.b, self.a])
 
     def test_history_num_to_head(self):
-        self.cfg.stdout = buf = StringIO()
+        self.cfg.stdout = buf = self._buf_fixture()
         command.history(self.cfg, "%s:" % (self.a))
         self._eq_cmd_output(buf, [self.c, self.b])
 
     def test_history_num_plus_relative(self):
-        self.cfg.stdout = buf = StringIO()
+        self.cfg.stdout = buf = self._buf_fixture()
         command.history(self.cfg, "%s:+2" % (self.a))
         self._eq_cmd_output(buf, [self.c, self.b])
 
     def test_history_relative_to_num(self):
-        self.cfg.stdout = buf = StringIO()
+        self.cfg.stdout = buf = self._buf_fixture()
         command.history(self.cfg, "-2:%s" % (self.c))
         self._eq_cmd_output(buf, [self.c, self.b])
 
     def test_history_current_to_head_as_b(self):
         command.stamp(self.cfg, self.b)
-        self.cfg.stdout = buf = StringIO()
+        self.cfg.stdout = buf = self._buf_fixture()
         command.history(self.cfg, "current:")
         self._eq_cmd_output(buf, [self.c])
 
     def test_history_current_to_head_as_base(self):
         command.stamp(self.cfg, "base")
-        self.cfg.stdout = buf = StringIO()
+        self.cfg.stdout = buf = self._buf_fixture()
         command.history(self.cfg, "current:")
         self._eq_cmd_output(buf, [self.c, self.b, self.a])
index 96106000b6b67ea90b01d590f3f098886007a8a1..3f7862ce358f827f4f47725def611a4c0a9d7178 100644 (file)
@@ -1,7 +1,11 @@
-from alembic import config, util
+#!coding: utf-8
+
+from alembic import config, util, compat
 from alembic.migration import MigrationContext
 from alembic.operations import Operations
 from alembic.script import ScriptDirectory
+import unittest
+from mock import Mock, call
 
 from . import eq_, capture_db, assert_raises_message
 
@@ -38,3 +42,34 @@ def test_no_script_error():
         "No 'script_location' key found in configuration.",
         ScriptDirectory.from_config, cfg
     )
+
+
+class OutputEncodingTest(unittest.TestCase):
+
+    def test_plain(self):
+        stdout = Mock(encoding='latin-1')
+        cfg = config.Config(stdout=stdout)
+        cfg.print_stdout("test %s %s", "x", "y")
+        eq_(
+            stdout.mock_calls,
+            [call.write('test x y'), call.write('\n')]
+        )
+
+    def test_utf8_unicode(self):
+        stdout = Mock(encoding='latin-1')
+        cfg = config.Config(stdout=stdout)
+        cfg.print_stdout(compat.u("méil %s %s"), "x", "y")
+        eq_(
+            stdout.mock_calls,
+            [call.write(compat.u('méil x y')), call.write('\n')]
+        )
+
+    def test_ascii_unicode(self):
+        stdout = Mock(encoding=None)
+        cfg = config.Config(stdout=stdout)
+        cfg.print_stdout(compat.u("méil %s %s"), "x", "y")
+        eq_(
+            stdout.mock_calls,
+            [call.write('m?il x y'), call.write('\n')]
+        )
+