]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
merge -r5673:5675 of trunk
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 14 Jan 2009 21:06:02 +0000 (21:06 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 14 Jan 2009 21:06:02 +0000 (21:06 +0000)
CHANGES
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/expression.py
test/engine/reflection.py
test/sql/query.py
test/sql/select.py

diff --git a/CHANGES b/CHANGES
index c7846b9c0d916fe52ec893eeeefa17f895275a7d..c45dd101b850f01436b6500d1c07a8713044196a 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -45,6 +45,12 @@ CHANGES
       combined with joined-table inheritance and an object
       which contained no defined values for the child table where
       an UPDATE with no SET clause would be rendered.
+
+- sql
+    - Improved the methodology to handling percent signs in column
+      names from [ticket:1256].  Added more tests.  MySQL and
+      Postgres dialects still do not issue correct CREATE TABLE
+      statements for identifiers with percent signs in them.
       
 - schema
     - Index now accepts column-oriented InstrumentedAttributes
@@ -55,6 +61,12 @@ CHANGES
       NoneType error when it's string output is requsted
       (such as in a stack trace).
       
+    - Fixed bug when overriding a Column with a ForeignKey
+      on a reflected table, where derived columns (i.e. the 
+      "virtual" columns of a select, etc.) would inadvertently
+      call upon schema-level cleanup logic intended only
+      for the original column. [ticket:1278]
+      
 - declarative
     - Can now specify Column objects on subclasses which have no
       table of their own (i.e. use single table inheritance).  
index 02055b7092fc7e041873c6df6bb999dbe2bfc428..c2fb5ee5a6b7bbc82e0850e4c9d314fcdabdc38a 100644 (file)
@@ -666,7 +666,9 @@ class Column(SchemaItem, expression.ColumnClause):
         if getattr(self, 'table', None) is not None:
             raise exc.ArgumentError("this Column already has a table!")
 
-        self._pre_existing_column = table._columns.get(self.key)
+        if self.key in table._columns:
+            # note the column being replaced, if any
+            self._pre_existing_column = table._columns.get(self.key)
         table._columns.replace(self)
 
         if self.primary_key:
@@ -734,11 +736,17 @@ class Column(SchemaItem, expression.ColumnClause):
         (such as an alias or select statement).
 
         """
-        fk = [ForeignKey(f._colspec) for f in self.foreign_keys]
-        c = Column(name or self.name, self.type, self.default, key = name or self.key, primary_key = self.primary_key, nullable = self.nullable, quote=self.quote, *fk)
+        fk = [ForeignKey(f.column) for f in self.foreign_keys]
+        c = Column(
+            name or self.name, 
+            self.type, 
+            self.default, 
+            key = name or self.key, 
+            primary_key = self.primary_key, 
+            nullable = self.nullable, 
+            quote=self.quote, *fk)
         c.table = selectable
         c.proxies = [self]
-        c._pre_existing_column = self._pre_existing_column
         selectable.columns.add(c)
         if self.primary_key:
             selectable.primary_key.add(c)
@@ -924,10 +932,10 @@ class ForeignKey(SchemaItem):
             raise exc.InvalidRequestError("This ForeignKey already has a parent !")
         self.parent = column
 
-        if self.parent._pre_existing_column is not None:
+        if hasattr(self.parent, '_pre_existing_column'):
             # remove existing FK which matches us
             for fk in self.parent._pre_existing_column.foreign_keys:
-                if fk._colspec == self._colspec:
+                if fk.target_fullname == self.target_fullname:
                     self.parent.table.foreign_keys.remove(fk)
                     self.parent.table.constraints.remove(fk.constraint)
 
index 3e61b459b4b21104221d8e2dcd87694ccb7596b8..d00a05436a9e5eeed378808c9cef2c63818d660b 100644 (file)
@@ -274,7 +274,10 @@ class SQLCompiler(engine.Compiled):
                 schema_prefix = self.preparer.quote(column.table.schema, column.table.quote_schema) + '.'
             else:
                 schema_prefix = ''
-            return schema_prefix + self.preparer.quote(column.table.name % self.anon_map, column.table.quote) + "." + name
+            tablename = column.table.name
+            if isinstance(tablename, sql._generated_label):
+                tablename = tablename % self.anon_map
+            return schema_prefix + self.preparer.quote(tablename, column.table.quote) + "." + name
 
     def escape_literal_column(self, text):
         """provide escaping for the literal_column() construct."""
@@ -407,12 +410,11 @@ class SQLCompiler(engine.Compiled):
 
         return bind_name
 
-    _trunc_re = re.compile(r'%\((-?\d+ \w+)\)s', re.U)
     def _truncated_identifier(self, ident_class, name):
         if (ident_class, name) in self.truncated_names:
             return self.truncated_names[(ident_class, name)]
 
-        anonname = self._trunc_re.sub(lambda m: self.anon_map[m.group(1)], name)
+        anonname = name % self.anon_map 
 
         if len(anonname) > self.label_length:
             counter = self.truncated_names.get(ident_class, 1)
@@ -424,7 +426,7 @@ class SQLCompiler(engine.Compiled):
         return truncname
     
     def _anonymize(self, name):
-        return self._trunc_re.sub(lambda m: self.anon_map[m.group(1)], name)
+        return name % self.anon_map
         
     def _process_anon(self, key):
         (ident, derived) = key.split(' ')
index 6be867dbf52b27ed60f841a07986997590cef99a..f8c7de8d0ff18208bf7dd902570cb23a40143416 100644 (file)
@@ -868,6 +868,12 @@ modifier = _FunctionGenerator(group=False)
 class _generated_label(unicode):
     """A unicode subclass used to identify dynamically generated names."""
 
+def _escape_for_generated(x):
+    if isinstance(x, _generated_label):
+        return x
+    else:
+        return x.replace('%', '%%')
+        
 def _clone(element):
     return element._clone()
 
@@ -1646,7 +1652,7 @@ class ColumnElement(ClauseElement, _CompareMixin):
         expressions and function calls.
 
         """
-        return _generated_label("%%(%d %s)s" % (id(self), getattr(self, 'name', 'anon')))
+        return _generated_label("%%(%d %s)s" % (id(self), _escape_for_generated(getattr(self, 'name', 'anon'))))
 
 class ColumnCollection(util.OrderedProperties):
     """An ordered dictionary that stores a list of ColumnElement
@@ -1978,7 +1984,7 @@ class _BindParamClause(ColumnElement):
 
         """
         if unique:
-            self.key = _generated_label("%%(%d %s)s" % (id(self), key or 'param'))
+            self.key = _generated_label("%%(%d %s)s" % (id(self), key and _escape_for_generated(key) or 'param'))
         else:
             self.key = key or _generated_label("%%(%d param)s" % id(self))
         self._orig_key = key or 'param'
@@ -1997,13 +2003,13 @@ class _BindParamClause(ColumnElement):
     def _clone(self):
         c = ClauseElement._clone(self)
         if self.unique:
-            c.key = _generated_label("%%(%d %s)s" % (id(c), c._orig_key or 'param'))
+            c.key = _generated_label("%%(%d %s)s" % (id(c), c._orig_key and _escape_for_generated(c._orig_key) or 'param'))
         return c
 
     def _convert_to_unique(self):
         if not self.unique:
             self.unique = True
-            self.key = _generated_label("%%(%d %s)s" % (id(self), self._orig_key or 'param'))
+            self.key = _generated_label("%%(%d %s)s" % (id(self), self._orig_key and _escape_for_generated(self._orig_key) or 'param'))
 
     def bind_processor(self, dialect):
         return self.type.dialect_impl(dialect).bind_processor(dialect)
@@ -2610,7 +2616,7 @@ class Alias(FromClause):
         if alias is None:
             if self.original.named_with_column:
                 alias = getattr(self.original, 'name', None)
-            alias = _generated_label('%%(%d %s)s' % (id(self), alias or 'anon'))
+            alias = _generated_label('%%(%d %s)s' % (id(self), alias and _escape_for_generated(alias) or 'anon'))
         self.name = alias
 
     @property
@@ -2731,7 +2737,7 @@ class _Label(ColumnElement):
     def __init__(self, name, element, type_=None):
         while isinstance(element, _Label):
             element = element.element
-        self.name = self.key = self._label = name or _generated_label("%%(%d %s)s" % (id(self), getattr(element, 'name', 'anon')))
+        self.name = self.key = self._label = name or _generated_label("%%(%d %s)s" % (id(self), _escape_for_generated(getattr(element, 'name', 'anon'))))
         self._element = element
         self._type = type_
         self.quote = element.quote
@@ -2820,9 +2826,9 @@ class ColumnClause(_Immutable, ColumnElement):
 
         elif self.table and self.table.named_with_column:
             if getattr(self.table, 'schema', None):
-                label = self.table.schema + "_" + self.table.name + "_" + self.name
+                label = self.table.schema + "_" + _escape_for_generated(self.table.name) + "_" + _escape_for_generated(self.name)
             else:
-                label = self.table.name + "_" + self.name
+                label = _escape_for_generated(self.table.name) + "_" + _escape_for_generated(self.name)
 
             if label in self.table.c:
                 # TODO: coverage does not seem to be present for this
index 64ae468e67e23cabba1a2fd4da724f88f2587f62..a448540825bcf2721a9f4889042c7d1991c8e8e4 100644 (file)
@@ -300,6 +300,8 @@ class ReflectionTest(TestBase, ComparesTables):
                 autoload=True)
             u2 = Table('users', meta2, autoload=True)
 
+            s = sa.select([a2])
+            assert s.c.user_id
             assert len(a2.foreign_keys) == 1
             assert len(a2.c.user_id.foreign_keys) == 1
             assert len(a2.constraints) == 2
@@ -317,6 +319,8 @@ class ReflectionTest(TestBase, ComparesTables):
                 Column('user_id', sa.Integer, sa.ForeignKey('users.id')),
                 autoload=True)
 
+            s = sa.select([a2])
+            assert s.c.user_id
             assert len(a2.foreign_keys) == 1
             assert len(a2.c.user_id.foreign_keys) == 1
             assert len(a2.constraints) == 2
index 275bbe78c66e33ed01cb09c79826f716410fac26..bf178ae8f541e02553279a60f3327dfc9d4e0faf 100644 (file)
@@ -664,7 +664,72 @@ class QueryTest(TestBase):
         r = s.execute().fetchall()
         assert len(r) == 1
 
-
+class PercentSchemaNamesTest(TestBase):
+    """tests using percent signs, spaces in table and column names.
+    
+    Doesn't pass for mysql, postgres, but this is really a 
+    SQLAlchemy bug - we should be escaping out %% signs for this
+    operation the same way we do for text() and column labels.
+    
+    """
+    @testing.crashes('mysql', 'mysqldb calls name % (params)')
+    @testing.crashes('postgres', 'postgres calls name % (params)')
+    def setUpAll(self):
+        global percent_table, metadata
+        metadata = MetaData(testing.db)
+        percent_table = Table('percent%table', metadata,
+            Column("percent%", Integer),
+            Column("%(oneofthese)s", Integer),
+            Column("spaces % more spaces", Integer),
+        )
+        metadata.create_all()
+        
+    def tearDownAll(self):
+        metadata.drop_all()
+    
+    @testing.crashes('mysql', 'mysqldb calls name % (params)')
+    @testing.crashes('postgres', 'postgres calls name % (params)')
+    def test_roundtrip(self):
+        percent_table.insert().execute(
+            {'percent%':5, '%(oneofthese)s':7, 'spaces % more spaces':12},
+        )
+        percent_table.insert().execute(
+            {'percent%':7, '%(oneofthese)s':8, 'spaces % more spaces':11},
+            {'percent%':9, '%(oneofthese)s':9, 'spaces % more spaces':10},
+            {'percent%':11, '%(oneofthese)s':10, 'spaces % more spaces':9},
+        )
+        eq_(
+            percent_table.select().order_by(percent_table.c['%(oneofthese)s']).execute().fetchall(),
+            [
+                (5, 7, 12),
+                (7, 8, 11),
+                (9, 9, 10),
+                (11, 10, 9)
+            ]
+        )
+        result = percent_table.select().order_by(percent_table.c['%(oneofthese)s']).execute()
+        row = result.fetchone()
+        eq_(row[percent_table.c['percent%']], 5)
+        eq_(row[percent_table.c['%(oneofthese)s']], 7)
+        eq_(row[percent_table.c['spaces % more spaces']], 12)
+        row = result.fetchone()
+        eq_(row['percent%'], 7)
+        eq_(row['%(oneofthese)s'], 8)
+        eq_(row['spaces % more spaces'], 11)
+        result.close()
+        percent_table.update().values({percent_table.c['%(oneofthese)s']:9, percent_table.c['spaces % more spaces']:15}).execute()
+        eq_(
+            percent_table.select().order_by(percent_table.c['%(oneofthese)s']).execute().fetchall(),
+            [
+                (5, 9, 15),
+                (7, 9, 15),
+                (9, 9, 15),
+                (11, 9, 15)
+            ]
+        )
+        
+        
+        
 class LimitTest(TestBase):
 
     def setUpAll(self):
index 77112e649a08a1cdc83f5545a8a5159b20cc383b..72c552fffd4daab77bf78a84db4b7ffbcc0cfe0a 100644 (file)
@@ -836,16 +836,16 @@ FROM mytable, myothertable WHERE foo.id = foofoo(lala) AND datetime(foo) = Today
                             "COLLATE somecol AS x")
 
     def test_percent_chars(self):
-        t = table("table",
+        t = table("table%name",
             column("percent%"),
             column("%(oneofthese)s"),
             column("spaces % more spaces"),
         )
         self.assert_compile(
             t.select(use_labels=True),
-            '''SELECT "table"."percent%" AS "table_percent%", '''\
-            '''"table"."%(oneofthese)s" AS "table_%(oneofthese)s", '''\
-            '''"table"."spaces % more spaces" AS "table_spaces % more spaces" FROM "table"'''
+            '''SELECT "table%name"."percent%" AS "table%name_percent%", '''\
+            '''"table%name"."%(oneofthese)s" AS "table%name_%(oneofthese)s", '''\
+            '''"table%name"."spaces % more spaces" AS "table%name_spaces % more spaces" FROM "table%name"'''
         )