]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- preliminary support for unicode table and column names added.
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 27 Mar 2007 16:04:34 +0000 (16:04 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 27 Mar 2007 16:04:34 +0000 (16:04 +0000)
CHANGES
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql.py
test/sql/alltests.py
test/sql/select.py
test/sql/unicode.py [new file with mode: 0644]
test/testbase.py

diff --git a/CHANGES b/CHANGES
index bc218cbcea2a262a5d60f4aee37da817ba493605..6830ee00eaabb1ec13f5ff8373a85014dbe563db 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -6,6 +6,7 @@
       on postgres.  Also, the true labelname is always attached as the
       accessor on the parent Selectable so theres no need to be aware
       of the genrerated label names [ticket:512].
+    - preliminary support for unicode table and column names added.
 - orm:
     - improved/fixed custom collection classes when giving it "set"/
       "sets.Set" classes or subclasses (was still looking for append()
index 0d4fba4e8a1716a1415e3e7ec92552c6251d30c9..37b6366a9f24ad3fa400b930c89543e62043025a 100644 (file)
@@ -161,23 +161,23 @@ class ANSICompiler(sql.Compiled):
         # this re will search for params like :param
         # it has a negative lookbehind for an extra ':' so that it doesnt match
         # postgres '::text' tokens
-        match = r'(?<!:):([\w_]+)'
+        match = re.compile(r'(?<!:):([\w_]+)', re.UNICODE)
         if self.paramstyle=='pyformat':
-            self.strings[self.statement] = re.sub(match, lambda m:'%(' + m.group(1) +')s', self.strings[self.statement])
+            self.strings[self.statement] = match.sub(lambda m:'%(' + m.group(1) +')s', self.strings[self.statement])
         elif self.positional:
-            params = re.finditer(match, self.strings[self.statement])
+            params = match.finditer(self.strings[self.statement])
             for p in params:
                 self.positiontup.append(p.group(1))
             if self.paramstyle=='qmark':
-                self.strings[self.statement] = re.sub(match, '?', self.strings[self.statement])
+                self.strings[self.statement] = match.sub('?', self.strings[self.statement])
             elif self.paramstyle=='format':
-                self.strings[self.statement] = re.sub(match, '%s', self.strings[self.statement])
+                self.strings[self.statement] = match.sub('%s', self.strings[self.statement])
             elif self.paramstyle=='numeric':
                 i = [0]
                 def getnum(x):
                     i[0] += 1
                     return str(i[0])
-                self.strings[self.statement] = re.sub(match, getnum, self.strings[self.statement])
+                self.strings[self.statement] = match.sub(getnum, self.strings[self.statement])
 
     def get_from_text(self, obj):
         return self.froms.get(obj, None)
@@ -188,7 +188,7 @@ class ANSICompiler(sql.Compiled):
     def get_whereclause(self, obj):
         return self.wheres.get(obj, None)
 
-    def get_params(self, **params):
+    def construct_params(self, params):
         """Return a structure of bind parameters for this compiled object.
 
         This includes bind parameters that might be compiled in via
@@ -214,7 +214,6 @@ class ANSICompiler(sql.Compiled):
         else:
             bindparams = {}
         bindparams.update(params)
-
         d = sql.ClauseParameters(self.dialect, self.positiontup)
         for b in self.binds.values():
             d.set_parameter(b, b.value)
@@ -693,7 +692,7 @@ class ANSICompiler(sql.Compiled):
 
         def to_col(key):
             if not isinstance(key, sql._ColumnClause):
-                return stmt.table.columns.get(str(key), key)
+                return stmt.table.columns.get(unicode(key), key)
             else:
                 return key
 
@@ -986,11 +985,10 @@ class ANSIIdentifierPreparer(object):
 
     def _requires_quotes(self, value, case_sensitive):
         """Return True if the given identifier requires quoting."""
-
         return \
             value in self._reserved_words() \
             or (value[0] in self._illegal_initial_characters()) \
-            or bool(len([x for x in str(value) if x not in self._legal_characters()])) \
+            or bool(len([x for x in unicode(value) if x not in self._legal_characters()])) \
             or (case_sensitive and value.lower() != value)
 
     def __generic_obj_format(self, obj, ident):
index 7f7bde81bb2025a55cda04157f4958004472c210..eb397d774388c2f4e32680493c018b41157e906c 100644 (file)
@@ -494,7 +494,7 @@ class Connection(Connectable):
         if not compiled.can_execute:
             raise exceptions.ArgumentError("Not an executeable clause: %s" % (str(compiled)))
         cursor = self.__engine.dialect.create_cursor(self.connection)
-        parameters = [compiled.get_params(**m) for m in self._params_to_listofdicts(*multiparams, **params)]
+        parameters = [compiled.construct_params(m) for m in self._params_to_listofdicts(*multiparams, **params)]
         if len(parameters) == 1:
             parameters = parameters[0]
         def proxy(statement=None, parameters=None):
@@ -506,7 +506,7 @@ class Connection(Connectable):
             return cursor
         context = self.__engine.dialect.create_execution_context()
         context.pre_exec(self.__engine, proxy, compiled, parameters)
-        proxy(str(compiled), parameters)
+        proxy(unicode(compiled), parameters)
         context.post_exec(self.__engine, proxy, compiled, parameters)
         rpargs = self.__engine.dialect.create_result_proxy_args(self, cursor)
         return ResultProxy(self.__engine, self, cursor, context, typemap=compiled.typemap, column_labels=compiled.column_labels, **rpargs)
@@ -544,16 +544,13 @@ class Connection(Connectable):
     def _execute_raw(self, statement, parameters=None, cursor=None, context=None, **kwargs):
         if cursor is None:
             cursor = self.__engine.dialect.create_cursor(self.connection)
-        try:
-            self.__engine.logger.info(statement)
-            self.__engine.logger.info(repr(parameters))
-            if parameters is not None and isinstance(parameters, list) and len(parameters) > 0 and (isinstance(parameters[0], list) or isinstance(parameters[0], dict)):
-                self._executemany(cursor, statement, parameters, context=context)
-            else:
-                self._execute(cursor, statement, parameters, context=context)
-            self._autocommit(statement)
-        except:
-            raise
+        self.__engine.logger.info(statement)
+        self.__engine.logger.info(repr(parameters))
+        if parameters is not None and isinstance(parameters, list) and len(parameters) > 0 and (isinstance(parameters[0], list) or isinstance(parameters[0], dict)):
+            self._executemany(cursor, statement, parameters, context=context)
+        else:
+            self._execute(cursor, statement, parameters, context=context)
+        self._autocommit(statement)
         return cursor
 
     def _execute(self, c, statement, parameters, context=None):
index ae8333843ea1b9484c7c0bb49e62c362eada4ab3..3d7ddb5d69a9277be67d71889ce748caf853a0f2 100644 (file)
@@ -862,7 +862,7 @@ class Mapper(object):
             mapper._adapt_inherited_property(key, prop)
 
     def __str__(self):
-        return "Mapper|" + self.class_.__name__ + "|" + (self.entity_name is not None and "/%s" % self.entity_name or "") + (self.local_table and self.local_table.name or str(self.local_table)) + (not self._is_primary_mapper() and "|non-primary" or "")
+        return "Mapper|" + self.class_.__name__ + "|" + (self.entity_name is not None and "/%s" % self.entity_name or "") + (self.local_table and self.local_table.encodedname or str(self.local_table)) + (not self._is_primary_mapper() and "|non-primary" or "")
 
     def _is_primary_mapper(self):
         """Return True if this mapper is the primary mapper for its class key (class + entity_name)."""
index aa993a270d266eed9a6f896a8b214b20dee73ba9..bd601ed80013ac16700768a899c109dd62be16d7 100644 (file)
@@ -138,7 +138,6 @@ class _TableSingleton(type):
         if metadata is None:
             metadata = default_metadata
 
-        name = str(name)    # in case of incoming unicode
         schema = kwargs.get('schema', None)
         autoload = kwargs.pop('autoload', False)
         autoload_with = kwargs.pop('autoload_with', False)
@@ -318,7 +317,7 @@ class Table(SchemaItem, sql.TableClause):
             , ',')
 
     def __str__(self):
-        return _get_table_key(self.name, self.schema)
+        return _get_table_key(self.encodedname, self.schema)
 
     def append_column(self, column):
         """Append a ``Column`` to this ``Table``."""
@@ -494,7 +493,6 @@ class Column(SchemaItem, sql._ColumnClause):
             identifier contains mixed case.
         """
 
-        name = str(name) # in case of incoming unicode
         super(Column, self).__init__(name, None, type)
         self.args = args
         self.key = kwargs.pop('key', name)
@@ -521,11 +519,11 @@ class Column(SchemaItem, sql._ColumnClause):
     def __str__(self):
         if self.table is not None:
             if self.table.named_with_column():
-                return self.table.name + "." + self.name
+                return (self.table.encodedname + "." + self.encodedname)
             else:
-                return self.name
+                return self.encodedname
         else:
-            return self.name
+            return self.encodedname
 
     def _derived_metadata(self):
         return self.table.metadata
@@ -572,11 +570,11 @@ class Column(SchemaItem, sql._ColumnClause):
         self.table = table
 
         if self.index:
-            if isinstance(self.index, str):
+            if isinstance(self.index, basestring):
                 raise exceptions.ArgumentError("The 'index' keyword argument on Column is boolean only.  To create indexes with a specific name, append an explicit Index object to the Table's list of elements.")
             Index('ix_%s' % self._label, self, unique=self.unique)
         elif self.unique:
-            if isinstance(self.unique, str):
+            if isinstance(self.unique, basestring):
                 raise exceptions.ArgumentError("The 'unique' keyword argument on Column is boolean only.  To create unique constraints or indexes with a specific name, append an explicit UniqueConstraint or Index object to the Table's list of elements.")
             table.append_constraint(UniqueConstraint(self.key))
 
@@ -654,8 +652,6 @@ class ForeignKey(SchemaItem):
           created and added to the parent table.
         """
 
-        if isinstance(column, unicode):
-            column = str(column)
         self._colspec = column
         self._column = None
         self.constraint = constraint
@@ -673,7 +669,7 @@ class ForeignKey(SchemaItem):
         return ForeignKey(self._get_colspec())
 
     def _get_colspec(self):
-        if isinstance(self._colspec, str):
+        if isinstance(self._colspec, basestring):
             return self._colspec
         elif self._colspec.table.schema is not None:
             return "%s.%s.%s" % (self._colspec.table.schema, self._colspec.table.name, self._colspec.key)
@@ -689,7 +685,7 @@ class ForeignKey(SchemaItem):
         # ForeignKey inits its remote column as late as possible, so tables can
         # be defined without dependencies
         if self._column is None:
-            if isinstance(self._colspec, str):
+            if isinstance(self._colspec, basestring):
                 # locate the parent table this foreign key is attached to.
                 # we use the "original" column which our parent column represents
                 # (its a list of columns/other ColumnElements if the parent table is a UNION)
@@ -699,7 +695,7 @@ class ForeignKey(SchemaItem):
                         break
                 else:
                     raise exceptions.ArgumentError("Parent column '%s' does not descend from a table-attached Column" % str(self.parent))
-                m = re.match(r"^([\w_-]+)(?:\.([\w_-]+))?(?:\.([\w_-]+))?$", self._colspec)
+                m = re.match(r"^([\w_-]+)(?:\.([\w_-]+))?(?:\.([\w_-]+))?$", self._colspec, re.UNICODE)
                 if m is None:
                     raise exceptions.ArgumentError("Invalid foreign key column specification: " + self._colspec)
                 if m.group(3) is None:
index bd018e89cdfc1607d4cb677ea122327a87615236..edcf0f04298a613c9397169c2c9ff0f068fe7722 100644 (file)
@@ -648,6 +648,12 @@ class Compiled(ClauseVisitor):
         raise NotImplementedError()
 
     def get_params(self, **params):
+        """Deprecated.  use construct_params().  (supports unicode names)
+        """
+
+        return self.construct_params(params)
+
+    def construct_params(self, params):
         """Return the bind params for this compiled object.
 
         Will start with the default parameters specified when this
@@ -657,9 +663,8 @@ class Compiled(ClauseVisitor):
         ``_BindParamClause`` objects compiled into this object; either
         the `key` or `shortname` property of the ``_BindParamClause``.
         """
-
         raise NotImplementedError()
-
+        
     def execute(self, *multiparams, **params):
         """Execute this compiled object."""
 
@@ -823,7 +828,7 @@ class ClauseElement(object):
         return compiler
 
     def __str__(self):
-        return str(self.compile())
+        return unicode(self.compile()).encode('ascii', 'backslashreplace')
 
     def __and__(self, other):
         return and_(self, other)
@@ -1858,6 +1863,7 @@ class _ColumnClause(ColumnElement):
 
     def __init__(self, text, selectable=None, type=None, _is_oid=False, case_sensitive=True, is_literal=False):
         self.key = self.name = text
+        self.encodedname = self.name.encode('ascii', 'backslashreplace')
         self.table = selectable
         self.type = sqltypes.to_instance(type)
         self._is_oid = _is_oid
@@ -1941,6 +1947,7 @@ class TableClause(FromClause):
     def __init__(self, name, *columns):
         super(TableClause, self).__init__(name)
         self.name = self.fullname = name
+        self.encodedname = self.name.encode('ascii', 'backslashreplace')
         self._columns = ColumnCollection()
         self._foreign_keys = util.OrderedSet()
         self._primary_key = ColumnCollection()
index 9f1c0d36eb81637a5987f6a3ec5af8d9eb70a8e5..7be1a3ffb6fda3b802c1c3c5fbf77ac32784897c 100644 (file)
@@ -12,6 +12,7 @@ def suite():
         'sql.selectable',
         'sql.case_statement', 
         'sql.labels',
+        'sql.unicode',
         
         # assorted round-trip tests
         'sql.query',
index 1ca8224a8a179d8e2711eaa1242a6277bcdda852..5fcf88fd1ee92cc71e31630607c02ee6066b38fa 100644 (file)
@@ -4,6 +4,7 @@ from sqlalchemy import *
 from sqlalchemy.databases import sqlite, postgres, mysql, oracle, firebird
 import unittest, re
 
+
 # the select test now tests almost completely with TableClause/ColumnClause objects,
 # which are free-roaming table/column objects not attached to any database.  
 # so SQLAlchemy's SQL construction engine can be used with no database dependencies at all.
diff --git a/test/sql/unicode.py b/test/sql/unicode.py
new file mode 100644 (file)
index 0000000..1e1b414
--- /dev/null
@@ -0,0 +1,66 @@
+# coding: utf-8
+import testbase
+
+from sqlalchemy import *
+
+"""verrrrry basic unicode column name testing"""
+
+class UnicodeSchemaTest(testbase.PersistTest):
+    @testbase.unsupported('postgres')
+    def setUpAll(self):
+        global metadata, t1, t2
+        metadata = MetaData(engine=testbase.db)
+        t1 = Table('unitable1', metadata,
+            Column(u'méil', Integer, primary_key=True),
+            Column(u'éXXm', Integer),
+
+            )
+        t2 = Table(u'unitéble2', metadata,
+            Column(u'méil', Integer, primary_key=True, key="a"),
+            Column(u'éXXm', Integer, ForeignKey(u'unitable1.méil'), key="b"),
+
+            )
+
+        metadata.create_all()
+    @testbase.unsupported('postgres')
+    def tearDownAll(self):
+        metadata.drop_all()
+
+    @testbase.unsupported('postgres')
+    def test_insert(self):
+        t1.insert().execute({u'méil':1, u'éXXm':5})
+        t2.insert().execute({'a':1, 'b':5})
+        
+        assert t1.select().execute().fetchall() == [(1, 5)]
+        assert t2.select().execute().fetchall() == [(1, 5)]
+        
+    @testbase.unsupported('postgres')
+    def test_mapping(self):
+        # TODO: this test should be moved to the ORM tests, tests should be
+        # added to this module testing SQL syntax and joins, etc.
+        class A(object):pass
+        class B(object):pass
+        
+        mapper(A, t1, properties={
+            't2s':relation(B),
+            'a':t1.c[u'méil'],
+            'b':t1.c[u'éXXm']
+        })
+        mapper(B, t2)
+        sess = create_session()
+        a1 = A()
+        b1 = B()
+        a1.t2s.append(b1)
+        sess.save(a1)
+        sess.flush()
+        sess.clear()
+        new_a1 = sess.query(A).selectone(t1.c[u'méil'] == a1.a)
+        assert new_a1.a == a1.a
+        assert new_a1.t2s[0].a == b1.a
+        sess.clear()
+        new_a1 = sess.query(A).options(eagerload('t2s')).selectone(t1.c[u'méil'] == a1.a)
+        assert new_a1.a == a1.a
+        assert new_a1.t2s[0].a == b1.a
+        
+if __name__ == '__main__':
+    testbase.main()
\ No newline at end of file
index 34c05b8e5002277cff49a71f2b328a025a23d930..88519ef41e5e0730bbb25fc0ad3773797295e905 100644 (file)
@@ -263,7 +263,7 @@ class EngineAssert(proxy.BaseProxyEngine):
         def post_exec(engine, proxy, compiled, parameters, **kwargs):
             ctx = e
             self.engine.logger = self.logger
-            statement = str(compiled)
+            statement = unicode(compiled)
             statement = re.sub(r'\n', '', statement)
 
             if self.assert_list is not None: