]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- got unicode schemas to work with postgres
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 15 Apr 2007 22:04:53 +0000 (22:04 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 15 Apr 2007 22:04:53 +0000 (22:04 +0000)
- unicode schema with mysql slightly improved, still cant do has_table
- got reflection of unicode schemas working with sqlite, pg, mysql

CHANGES
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/databases/sqlite.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/sql.py
test/sql/unicode.py

diff --git a/CHANGES b/CHANGES
index 2700536a3c573fbeaeeb8f67c81e214863e11ed0..d8875df0859719b0b7af57352b065ce19cf703b6 100644 (file)
--- a/CHANGES
+++ b/CHANGES
       lost their underlying database - the error catching/invalidate
       step is totally moved to the connection pool. #516
 - sql:
+    - preliminary support for unicode table names, column names and 
+      SQL statements added, for databases which can support them.
+      Works with sqlite and postgres so far.  Mysql *mostly* works
+      except the has_table() function does not work.  Reflection
+      works too.
     - the Unicode type is now a direct subclass of String, which now
       contains all the "convert_unicode" logic.  This helps the variety
       of unicode situations that occur in db's such as MS-SQL to be
@@ -37,8 +42,6 @@
       that those named columns are selected from (part of [ticket:513])
     - MS-SQL better detects when a query is a subquery and knows not to
       generate ORDER BY phrases for those [ticket:513]
-    - preliminary support for unicode table names, column names and 
-      SQL statements added, for databases which can support them.
     - fix for fetchmany() "size" argument being positional in most
       dbapis [ticket:505]
     - sending None as an argument to func.<something> will produce
index 03297cd68675c6ba9484f2ae5b84ebfb3f887e56..d3a42ccdc47b545a9a43309ac852bcf52cafd50e 100644 (file)
@@ -303,7 +303,6 @@ class MySQLDialect(ansisql.ANSIDialect):
             except:
                 pass
         opts['client_flag'] = client_flag
-
         return [[], opts]
 
     def create_execution_context(self, *args, **kwargs):
@@ -331,7 +330,10 @@ class MySQLDialect(ansisql.ANSIDialect):
         rowcount = cursor.executemany(statement, parameters)
         if context is not None:
             context._rowcount = rowcount
-            
+    
+    def supports_unicode_statements(self):
+        return True
+                
     def do_execute(self, cursor, statement, parameters, **kwargs):
         cursor.execute(statement, parameters)
 
@@ -351,8 +353,11 @@ class MySQLDialect(ansisql.ANSIDialect):
         return self._default_schema_name
 
     def has_table(self, connection, table_name, schema=None):
+        # TODO: this does not work for table names that contain multibyte characters.
+        # i have tried dozens of approaches here with no luck.  statements like
+        # DESCRIBE and SHOW CREATE TABLE work better, but they raise an error when
+        # the table does not exist.
         cursor = connection.execute("show table status like %s", [table_name])
-        print "CURSOR", cursor, "ROWCOUNT", cursor.rowcount, "REAL RC", cursor.cursor.rowcount
         return bool( not not cursor.rowcount )
 
     def reflecttable(self, connection, table):
@@ -362,6 +367,8 @@ class MySQLDialect(ansisql.ANSIDialect):
             cs = cs.tostring()
         case_sensitive = int(cs) == 0
 
+        decode_from = connection.execute("show variables like 'character_Set_results'").fetchone()[1]
+
         if not case_sensitive:
             table.name = table.name.lower()
             table.metadata.tables[table.name]= table
@@ -379,7 +386,9 @@ class MySQLDialect(ansisql.ANSIDialect):
                 found_table = True
 
             # these can come back as unicode if use_unicode=1 in the mysql connection
-            (name, type, nullable, primary_key, default) = (str(row[0]), str(row[1]), row[2] == 'YES', row[3] == 'PRI', row[4])
+            (name, type, nullable, primary_key, default) = (row[0], str(row[1]), row[2] == 'YES', row[3] == 'PRI', row[4])
+            if not isinstance(name, unicode):
+                name = name.decode(decode_from)
 
             match = re.match(r'(\w+)(\(.*?\))?\s*(\w+)?\s*(\w+)?', type)
             col_type = match.group(1)
@@ -425,10 +434,7 @@ class MySQLDialect(ansisql.ANSIDialect):
         c = connection.execute("SHOW CREATE TABLE " + table.fullname, {})
         desc_fetched = c.fetchone()[1]
 
-        # this can come back as unicode if use_unicode=1 in the mysql connection
-        if type(desc_fetched) is unicode:
-            desc_fetched = str(desc_fetched)
-        elif type(desc_fetched) is not str:
+        if not isinstance(desc_fetched, basestring):
             # may get array.array object here, depending on version (such as mysql 4.1.14 vs. 4.1.11)
             desc_fetched = desc_fetched.tostring()
         desc = desc_fetched.strip()
index 6facde93671f55eb9036809ee36e5765140143b2..a93ba200cf289a0d96346c86beec19c942084518 100644 (file)
@@ -329,9 +329,9 @@ class PGDialect(ansisql.ANSIDialect):
     def has_table(self, connection, table_name, schema=None):
         # seems like case gets folded in pg_class...
         if schema is None:
-            cursor = connection.execute("""select relname from pg_class c join pg_namespace n on n.oid=c.relnamespace where n.nspname=current_schema() and lower(relname)=%(name)s""", {'name':table_name.lower()});
+            cursor = connection.execute("""select relname from pg_class c join pg_namespace n on n.oid=c.relnamespace where n.nspname=current_schema() and lower(relname)=%(name)s""", {'name':table_name.lower().encode(self.encoding)});
         else:
-            cursor = connection.execute("""select relname from pg_class c join pg_namespace n on n.oid=c.relnamespace where n.nspname=%(schema)s and lower(relname)=%(name)s""", {'name':table_name.lower(), 'schema':schema});
+            cursor = connection.execute("""select relname from pg_class c join pg_namespace n on n.oid=c.relnamespace where n.nspname=%(schema)s and lower(relname)=%(name)s""", {'name':table_name.lower().encode(self.encoding), 'schema':schema});
         return bool( not not cursor.rowcount )
 
     def has_sequence(self, connection, sequence_name):
@@ -385,7 +385,7 @@ class PGDialect(ansisql.ANSIDialect):
                 ORDER BY a.attnum
             """ % schema_where_clause
 
-            s = sql.text(SQL_COLS)
+            s = sql.text(SQL_COLS, bindparams=[sql.bindparam('table_name', type=sqltypes.Unicode), sql.bindparam('schema', type=sqltypes.Unicode)], typemap={'attname':sqltypes.Unicode})
             c = connection.execute(s, table_name=table.name,
                                       schema=table.schema)
             rows = c.fetchall()
@@ -454,7 +454,7 @@ class PGDialect(ansisql.ANSIDialect):
                  AND i.indisprimary = 't')
               ORDER BY attnum
             """
-            t = sql.text(PK_SQL)
+            t = sql.text(PK_SQL, typemap={'attname':sqltypes.Unicode})
             c = connection.execute(t, table=table_oid)
             for row in c.fetchall():
                 pk = row[0]
@@ -468,7 +468,7 @@ class PGDialect(ansisql.ANSIDialect):
               ORDER BY 1
             """
 
-            t = sql.text(FK_SQL)
+            t = sql.text(FK_SQL, typemap={'conname':sqltypes.Unicode, 'condef':sqltypes.Unicode})
             c = connection.execute(t, table=table_oid)
             for conname, condef in c.fetchall():
                 m = re.search('FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)', condef).groups()
index 5140d865ec7f7a870f9f0fe1847462c5a4f56b03..2b7e28dfb5c73a837553fe02d599198996fe121b 100644 (file)
@@ -185,6 +185,9 @@ class SQLiteDialect(ansisql.ANSIDialect):
     def create_execution_context(self, **kwargs):
         return SQLiteExecutionContext(self, **kwargs)
 
+    def supports_unicode_statements(self):
+        return True
+
     def last_inserted_ids(self):
         return self.context.last_inserted_ids
 
index 9431e13a0ecfbd67462fe5d1cd6d5af40ebf1506..969bde8d9dc44c60de7f8977fda94143becacb84 100644 (file)
@@ -51,7 +51,7 @@ class DefaultDialect(base.Dialect):
 
     def supports_unicode_statements(self):
         """indicate whether the DBAPI can receive SQL statements as Python unicode strings"""
-        return True
+        return False
 
     def max_identifier_length(self):
         # TODO: probably raise this and fill out
@@ -165,16 +165,30 @@ class DefaultExecutionContext(base.ExecutionContext):
             self.statement = unicode(compiled)
         else:
             self.typemap = self.column_labels = None
-            self.parameters = parameters
+            self.parameters = self._encode_param_keys(parameters)
             self.statement = statement
 
         if not dialect.supports_unicode_statements():
-            self.statement = self.statement.encode('ascii')
-        
+            self.statement = self.statement.encode(self.dialect.encoding)
+            
         self.cursor = self.create_cursor()
         
     engine = property(lambda s:s.connection.engine)
     
+    def _encode_param_keys(self, params):
+        """apply string encoding to the keys of dictionary-based bind parameters"""
+        if self.dialect.positional or self.dialect.supports_unicode_statements():
+            return params
+        else:
+            def proc(d):
+                if d is None:
+                    return None
+                return dict((k.encode(self.dialect.encoding), d[k]) for k in d)
+            if isinstance(params, list):
+                return [proc(d) for d in params]
+            else:
+                return proc(params)
+                
     def is_select(self):
         return re.match(r'SELECT', self.statement.lstrip(), re.I)
 
@@ -183,7 +197,7 @@ class DefaultExecutionContext(base.ExecutionContext):
         
     def pre_exec(self):
         self._process_defaults()
-        self.parameters = self.dialect.convert_compiled_params(self.compiled_parameters)
+        self.parameters = self._encode_param_keys(self.dialect.convert_compiled_params(self.compiled_parameters))
 
     def post_exec(self):
         pass
index d912ca17673ef44fbb68bf8d25589bf56d301a1a..94b618491c85255b75128776a5e13ac028c93c7f 100644 (file)
@@ -1521,7 +1521,7 @@ class ClauseList(ClauseElement):
 
     def append(self, clause):
         if _is_literal(clause):
-            clause = _TextClause(str(clause))
+            clause = _TextClause(unicode(clause))
         self.clauses.append(clause)
 
     def get_children(self, **kwargs):
@@ -2042,7 +2042,7 @@ class _ColumnClause(ColumnElement):
 
     def __init__(self, text, selectable=None, type=None, _is_oid=False, case_sensitive=True, is_literal=False):
         self.key = self.name = text
-        self.encodedname = self.name.encode('ascii', 'backslashreplace')
+        self.encodedname = isinstance(self.name, unicode) and self.name.encode('ascii', 'backslashreplace') or self.name
         self.table = selectable
         self.type = sqltypes.to_instance(type)
         self._is_oid = _is_oid
@@ -2346,6 +2346,10 @@ class Select(_SelectBaseMixin, FromClause):
             for c in columns:
                 self.append_column(c)
 
+        if order_by:
+            order_by = util.to_list(order_by)
+        if group_by:
+            group_by = util.to_list(group_by)
         self.order_by(*(order_by or [None]))
         self.group_by(*(group_by or [None]))
         for c in self.order_by_clause:
index 1e1b414eaaf1bf8681d2f40b9cfbbfbb7b0a2ab4..9dfc75059e06c56bcc1ec85830139b27a775e293 100644 (file)
@@ -6,7 +6,6 @@ from sqlalchemy import *
 """verrrrry basic unicode column name testing"""
 
 class UnicodeSchemaTest(testbase.PersistTest):
-    @testbase.unsupported('postgres')
     def setUpAll(self):
         global metadata, t1, t2
         metadata = MetaData(engine=testbase.db)
@@ -20,21 +19,39 @@ class UnicodeSchemaTest(testbase.PersistTest):
             Column(u'éXXm', Integer, ForeignKey(u'unitable1.méil'), key="b"),
 
             )
-
         metadata.create_all()
-    @testbase.unsupported('postgres')
+
+    def tearDown(self):
+        t2.delete().execute()
+        t1.delete().execute()
+        
     def tearDownAll(self):
         metadata.drop_all()
-
-    @testbase.unsupported('postgres')
+        
+        # has_table() doesnt handle the unicode names on mysql
+        if testbase.db.name == 'mysql':
+            t2.drop()
+        
     def test_insert(self):
         t1.insert().execute({u'méil':1, u'éXXm':5})
-        t2.insert().execute({'a':1, 'b':5})
+        t2.insert().execute({'a':1, 'b':1})
         
         assert t1.select().execute().fetchall() == [(1, 5)]
-        assert t2.select().execute().fetchall() == [(1, 5)]
+        assert t2.select().execute().fetchall() == [(1, 1)]
+    
+    def test_reflect(self):
+        t1.insert().execute({u'méil':2, u'éXXm':7})
+        t2.insert().execute({'a':2, 'b':2})
+
+        meta = BoundMetaData(testbase.db)
+        tt1 = Table(t1.name, meta, autoload=True)
+        tt2 = Table(t2.name, meta, autoload=True)
+        tt1.insert().execute({u'méil':1, u'éXXm':5})
+        tt2.insert().execute({u'méil':1, u'éXXm':1})
+
+        assert tt1.select(order_by=desc(u'méil')).execute().fetchall() == [(2, 7), (1, 5)]
+        assert tt2.select(order_by=desc(u'méil')).execute().fetchall() == [(2, 2), (1, 1)]
         
-    @testbase.unsupported('postgres')
     def test_mapping(self):
         # TODO: this test should be moved to the ORM tests, tests should be
         # added to this module testing SQL syntax and joins, etc.