]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
make sure inputs are unicode when binding to unicode
authorRandall Smith <randall@tnr.cc>
Thu, 12 Feb 2009 04:40:03 +0000 (04:40 +0000)
committerRandall Smith <randall@tnr.cc>
Thu, 12 Feb 2009 04:40:03 +0000 (04:40 +0000)
lib/sqlalchemy/dialects/postgres/base.py
test/reflection.py

index a27af65ddd5b9da2aeda072f164311cc0728959f..9277fb124969078ed65d5b6467ce2c7850956a30 100644 (file)
@@ -540,6 +540,11 @@ class PGDialect(default.DefaultDialect):
             WHERE (%s)
             AND c.relname = :table_name AND c.relkind in ('r','v')
         """ % schema_where_clause
+        # Since we're binding to unicode, tablename and schemaname must be
+        # unicode.
+        tablename = unicode(tablename)
+        if schemaname is not None:
+            schemaname = unicode(schemaname)
         s = sql.text(query, bindparams=[
             sql.bindparam('table_name', type_=sqltypes.Unicode),
             sql.bindparam('schema', type_=sqltypes.Unicode)
index 863dc526ba764ae16a6ea7dc9d9fbab90cfc03b0..0e95a6dbf454f4d6a3886b49ca061b8a12f87b08 100644 (file)
@@ -13,9 +13,9 @@ if 'set' not in dir(__builtins__):
 
 def getSchema():
     if testing.against('oracle'):
-        return u'scott'
+        return 'scott'
     else:
-        return u'test_schema'
+        return 'test_schema'
 
 def createTables(meta, schema=None):
     if schema:
@@ -45,7 +45,7 @@ def createTables(meta, schema=None):
         schema=schema,
         test_needs_fk=True,
     )
-    addresses = Table(u'email_addresses', meta,
+    addresses = Table('email_addresses', meta,
         Column('address_id', sa.Integer, primary_key = True),
         Column('remote_user_id', sa.Integer,
                sa.ForeignKey(users.c.user_id)),
@@ -63,7 +63,7 @@ def createIndexes(con, schema=None):
     con.execute(sa.sql.text(query))
 
 def createViews(con, schema=None):
-    for tablename in (u'users', u'email_addresses'):
+    for tablename in ('users', 'email_addresses'):
         fullname = tablename
         if schema:
             fullname = "%s.%s" % (schema, tablename)
@@ -73,7 +73,7 @@ def createViews(con, schema=None):
         con.execute(sa.sql.text(query))
 
 def dropViews(con, schema=None):
-    for tablename in (u'email_addresses', 'users'):
+    for tablename in ('email_addresses', 'users'):
         fullname = tablename
         if schema:
             fullname = "%s.%s" % (schema, tablename)
@@ -100,15 +100,15 @@ class ReflectionTest(TestBase):
             if table_type == 'view':
                 table_names = insp.get_view_names(schemaname)
                 table_names.sort()
-                answer = [u'email_addresses_v', u'users_v']
+                answer = ['email_addresses_v', 'users_v']
             else:
                 table_names = insp.get_table_names(schemaname,
                                                    order_by=order_by)
                 table_names.sort()
                 if order_by == 'foreign_key':
-                    answer = [u'users', 'email_addresses']
+                    answer = ['users', 'email_addresses']
                 else:
-                    answer = [u'email_addresses', 'users']
+                    answer = ['email_addresses', 'users']
             self.assertEqual(table_names, answer)
         finally:
             dropViews(meta.bind, schemaname)
@@ -133,11 +133,11 @@ class ReflectionTest(TestBase):
     def _test_get_columns(self, schemaname=None, table_type='table'):
         meta = MetaData(testing.db)
         (users, addresses) = createTables(meta, schemaname)
-        table_names = [u'users', u'email_addresses']
+        table_names = ['users', 'email_addresses']
         meta.create_all()
         if table_type == 'view':
             createViews(meta.bind, schemaname)
-            table_names = [u'users_v', u'email_addresses_v']
+            table_names = ['users_v', 'email_addresses_v']
         try:
             insp = Inspector(meta.bind)
             for (tablename, table) in zip(table_names, (users, addresses)):
@@ -184,10 +184,10 @@ class ReflectionTest(TestBase):
         meta.create_all()
         insp = Inspector(meta.bind)
         try:
-            users_pkeys = insp.get_primary_keys(unicode(users.name),
+            users_pkeys = insp.get_primary_keys(users.name,
                                                 schemaname=schemaname)
             self.assertEqual(users_pkeys,  ['user_id'])
-            addr_pkeys = insp.get_primary_keys(unicode(addresses.name),
+            addr_pkeys = insp.get_primary_keys(addresses.name,
                                                schemaname=schemaname)
             self.assertEqual(addr_pkeys,  ['address_id'])
 
@@ -212,7 +212,7 @@ class ReflectionTest(TestBase):
                 expected_schema = meta.bind.dialect.get_default_schema_name(
                                     meta.bind)
             # users
-            users_fkeys = insp.get_foreign_keys(unicode(users.name),
+            users_fkeys = insp.get_foreign_keys(users.name,
                                                 schemaname=schemaname)
             fkey1 = users_fkeys[0]
             self.assert_(fkey1['name'] is not None)
@@ -246,7 +246,7 @@ class ReflectionTest(TestBase):
         createIndexes(meta.bind, schemaname)
         try:
             insp = Inspector(meta.bind)
-            indexes = insp.get_indexes(u'users', schemaname=schemaname)
+            indexes = insp.get_indexes('users', schemaname=schemaname)
             indexes.sort()
             expected_indexes = [
                 {'unique': False,
@@ -268,8 +268,8 @@ class ReflectionTest(TestBase):
         (users, addresses) = createTables(meta, schemaname)
         meta.create_all()
         createViews(meta.bind, schemaname)
-        view_name1 = u'users_v'
-        view_name2 = u'email_addresses_v'
+        view_name1 = 'users_v'
+        view_name2 = 'email_addresses_v'
         try:
             insp = Inspector(meta.bind)
             v1 = insp.get_view_definition(view_name1, schemaname=schemaname)