]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- [bug] Added support for using the .key
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 5 Feb 2012 21:58:32 +0000 (16:58 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 5 Feb 2012 21:58:32 +0000 (16:58 -0500)
of a Column as a string identifier in a
result set row.   The .key is currently
listed as an "alternate" name for a column,
and is superseded by the name of a column
which has that key value as its regular name.
For the next major release
of SQLAlchemy we may reverse this precedence
so that .key takes precedence, but this
is not decided on yet.  [ticket:2392]

CHANGES
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/expression.py
test/sql/test_compiler.py
test/sql/test_query.py
test/sql/test_quote.py
test/sql/test_selectable.py

diff --git a/CHANGES b/CHANGES
index 336627a27814d9385ffaebc35546ad200de865c8..a4638f2b6dc77b9e44a23624b8f2985077535c6f 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -20,6 +20,17 @@ CHANGES
     [ticket:2390]
 
 - sql
+  - [bug] Added support for using the .key
+    of a Column as a string identifier in a 
+    result set row.   The .key is currently
+    listed as an "alternate" name for a column,
+    and is superseded by the name of a column 
+    which has that key value as its regular name.
+    For the next major release
+    of SQLAlchemy we may reverse this precedence
+    so that .key takes precedence, but this
+    is not decided on yet.  [ticket:2392]
+
   - [bug] A significant change to how labeling
     is applied to columns in SELECT statements
     allows "truncated" labels, that is label names
index 598e879326a06b8e19cffd6775ee1c1494cff0d5..de18c48f91b4978b78075ed7610c69a69a6ca4d1 100644 (file)
@@ -154,9 +154,10 @@ class _CompileLabel(visitors.Visitable):
     __visit_name__ = 'label'
     __slots__ = 'element', 'name'
 
-    def __init__(self, col, name):
+    def __init__(self, col, name, alt_names=()):
         self.element = col
         self.name = name
+        self._alt_names = alt_names
 
     @property
     def proxy_set(self):
@@ -360,8 +361,10 @@ class SQLCompiler(engine.Compiled):
                 labelname = label.name
 
             if result_map is not None:
-                result_map[labelname.lower()] = \
-                        (label.name, (label, label.element, labelname),\
+                result_map[labelname.lower()] = (
+                        label.name, 
+                        (label, label.element, labelname, ) + 
+                            label._alt_names,
                         label.type)
 
             return label.element._compiler_dispatch(self, 
@@ -386,7 +389,9 @@ class SQLCompiler(engine.Compiled):
             name = self._truncated_identifier("colident", name)
 
         if result_map is not None:
-            result_map[name.lower()] = (orig_name, (column, name), column.type)
+            result_map[name.lower()] = (orig_name, 
+                                        (column, name, column.key), 
+                                        column.type)
 
         if is_literal:
             name = self.escape_literal_column(name)
@@ -775,8 +780,14 @@ class SQLCompiler(engine.Compiled):
         if isinstance(column, sql._Label):
             return column
 
-        elif select is not None and select.use_labels and column._label:
-            return _CompileLabel(column, column._label)
+        elif select is not None and \
+                select.use_labels and \
+                column._label:
+            return _CompileLabel(
+                    column, 
+                    column._label, 
+                    alt_names=(column._key_label, )
+                )
 
         elif \
             asfrom and \
@@ -784,7 +795,8 @@ class SQLCompiler(engine.Compiled):
             not column.is_literal and \
             column.table is not None and \
             not isinstance(column.table, sql.Select):
-            return _CompileLabel(column, sql._as_truncated(column.name))
+            return _CompileLabel(column, sql._as_truncated(column.name), 
+                                        alt_names=(column.key,))
         elif not isinstance(column, 
                     (sql._UnaryExpression, sql._TextClause)) \
                 and (not hasattr(column, 'name') or \
index 939456b9a611dc8858d635d40a9f8c8ec3983ca4..b11e5ad429845bf22df7d2d2d104054ecdfa9e0d 100644 (file)
@@ -2105,6 +2105,8 @@ class ColumnElement(ClauseElement, _CompareMixin):
     foreign_keys = []
     quote = None
     _label = None
+    _key_label = None
+    _alt_names = ()
 
     @property
     def _select_iterable(self):
@@ -3851,9 +3853,12 @@ class _Label(ColumnElement):
     def __init__(self, name, element, type_=None):
         while isinstance(element, _Label):
             element = element.element
-        self.name = self.key = self._label = name \
-            or _anonymous_label('%%(%d %s)s' % (id(self),
+        if name:
+            self.name = name
+        else:
+            self.name = _anonymous_label('%%(%d %s)s' % (id(self),
                                 getattr(element, 'name', 'anon')))
+        self.key = self._label = self._key_label = self.name
         self._element = element
         self._type = type_
         self.quote = element.quote
@@ -4000,8 +4005,18 @@ class ColumnClause(_Immutable, ColumnElement):
         return self.name.encode('ascii', 'backslashreplace')
         # end Py2K
 
+    @_memoized_property
+    def _key_label(self):
+        if self.key != self.name:
+            return self._gen_label(self.key)
+        else:
+            return self._label
+
     @_memoized_property
     def _label(self):
+        return self._gen_label(self.name)
+
+    def _gen_label(self, name):
         t = self.table
         if self.is_literal:
             return None
@@ -4009,9 +4024,9 @@ class ColumnClause(_Immutable, ColumnElement):
         elif t is not None and t.named_with_column:
             if getattr(t, 'schema', None):
                 label = t.schema.replace('.', '_') + "_" + \
-                            t.name + "_" + self.name
+                            t.name + "_" + name
             else:
-                label = t.name + "_" + self.name
+                label = t.name + "_" + name
 
             # ensure the label name doesn't conflict with that
             # of an existing column
@@ -4026,7 +4041,7 @@ class ColumnClause(_Immutable, ColumnElement):
             return _as_truncated(label)
 
         else:
-            return self.name
+            return name
 
     def label(self, name):
         # currently, anonymous labels don't occur for 
@@ -5041,7 +5056,9 @@ class Select(_SelectBase):
     def _populate_column_collection(self):
         for c in self.inner_columns:
             if hasattr(c, '_make_proxy'):
-                c._make_proxy(self, name=self.use_labels and c._label or None)
+                c._make_proxy(self, 
+                        name=self.use_labels 
+                            and c._label or None)
 
     def self_group(self, against=None):
         """return a 'grouping' construct as per the ClauseElement
index d9fad94934c4b9b2cc7a3c34711b700f54bbc5f8..6330ee34e9ae2c7ba86d1684b1b016278cb87618 100644 (file)
@@ -64,6 +64,12 @@ addresses = table('addresses',
     column('zip')
 )
 
+keyed = Table('keyed', metadata,
+    Column('x', Integer, key='colx'),
+    Column('y', Integer, key='coly'),
+    Column('z', Integer),
+)
+
 class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
     __dialect__ = 'default'
 
@@ -242,6 +248,20 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
             "SELECT sum(lala(mytable.myid)) AS bar FROM mytable"
         )
 
+        # changes with #2397
+        self.assert_compile(
+            select([keyed]),
+            "SELECT keyed.x, keyed.y"
+            ", keyed.z FROM keyed"
+        )
+
+        # changes with #2397
+        self.assert_compile(
+            select([keyed]).apply_labels(),
+            "SELECT keyed.x AS keyed_x, keyed.y AS "
+            "keyed_y, keyed.z AS keyed_z FROM keyed"
+        )
+
     def test_paramstyles(self):
         stmt = text("select :foo, :bar, :bat from sometable")
 
@@ -272,7 +292,8 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
         )
 
     def test_dupe_columns(self):
-        """test that deduping is performed against clause element identity, not rendered result."""
+        """test that deduping is performed against clause 
+        element identity, not rendered result."""
 
         self.assert_compile(
             select([column('a'), column('a'), column('a')]),
@@ -294,6 +315,17 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
             , dialect=default.DefaultDialect()
         )
 
+        # using alternate keys.  
+        # this will change with #2397
+        a, b, c = Column('a', Integer, key='b'), \
+                    Column('b', Integer), \
+                    Column('c', Integer, key='a')
+        self.assert_compile(
+            select([a, b, c, a, b, c]),
+            "SELECT a, b, c"
+            , dialect=default.DefaultDialect()
+        )
+
         self.assert_compile(
             select([bindparam('a'), bindparam('b'), bindparam('c')]),
             "SELECT :a AS anon_1, :b AS anon_2, :c AS anon_3"
@@ -315,12 +347,10 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
         s = s.compile(dialect=default.DefaultDialect(paramstyle='qmark'))
         eq_(s.positiontup, ['a', 'b', 'c'])
 
-    def test_nested_uselabels(self):
-        """test nested anonymous label generation.  this
-        essentially tests the ANONYMOUS_LABEL regex.
+    def test_nested_label_targeting(self):
+        """test nested anonymous label generation.  
 
         """
-
         s1 = table1.select()
         s2 = s1.alias()
         s3 = select([s2], use_labels=True)
@@ -339,6 +369,30 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
                             'AS description FROM mytable) AS anon_2) '
                             'AS anon_1')
 
+    def test_nested_label_targeting_keyed(self):
+        # this behavior chagnes with #2397
+        s1 = keyed.select()
+        s2 = s1.alias()
+        s3 = select([s2], use_labels=True)
+        self.assert_compile(s3,
+                    "SELECT anon_1.x AS anon_1_x, "
+                    "anon_1.y AS anon_1_y, "
+                    "anon_1.z AS anon_1_z FROM "
+                    "(SELECT keyed.x AS x, keyed.y "
+                    "AS y, keyed.z AS z FROM keyed) AS anon_1")
+
+        s4 = s3.alias()
+        s5 = select([s4], use_labels=True)
+        self.assert_compile(s5,
+                    "SELECT anon_1.anon_2_x AS anon_1_anon_2_x, "
+                    "anon_1.anon_2_y AS anon_1_anon_2_y, "
+                    "anon_1.anon_2_z AS anon_1_anon_2_z "
+                    "FROM (SELECT anon_2.x AS anon_2_x, anon_2.y AS anon_2_y, "
+                    "anon_2.z AS anon_2_z FROM "
+                    "(SELECT keyed.x AS x, keyed.y AS y, keyed.z "
+                    "AS z FROM keyed) AS anon_2) AS anon_1"
+                    )
+
     def test_dont_overcorrelate(self):
         self.assert_compile(select([table1], from_obj=[table1,
                             table1.select()]),
index f9ec82a6aff81e7634b186e598b85a7bc4ea4e72..6b1e516ec9439a7aea379c77e5c5dd0b5813ca98 100644 (file)
@@ -29,6 +29,7 @@ class QueryTest(fixtures.TestBase):
             Column('user_name', VARCHAR(20)),
             test_needs_acid=True
         )
+
         metadata.create_all()
 
     @engines.close_first
@@ -264,7 +265,6 @@ class QueryTest(fixtures.TestBase):
         )
 
         concat = ("test: " + users.c.user_name).label('thedata')
-        print select([concat]).order_by("thedata")
         eq_(
             select([concat]).order_by("thedata").execute().fetchall(),
             [("test: ed",), ("test: fred",), ("test: jack",)]
@@ -1207,6 +1207,168 @@ class PercentSchemaNamesTest(fixtures.TestBase):
             ]
         )
 
+class KeyTargetingTest(fixtures.TablesTest):
+    run_inserts = 'once'
+    run_deletes = None
+
+    @classmethod
+    def define_tables(cls, metadata):
+        keyed1 = Table('keyed1', metadata,
+                Column("a", CHAR(2), key="b"),
+                Column("c", CHAR(2), key="q")
+        )
+        keyed2 = Table('keyed2', metadata,
+                Column("a", CHAR(2)),
+                Column("b", CHAR(2)),
+        )
+        keyed3 = Table('keyed3', metadata,
+                Column("a", CHAR(2)),
+                Column("d", CHAR(2)),
+        )
+        keyed4 = Table('keyed4', metadata,
+                Column("b", CHAR(2)),
+                Column("q", CHAR(2)),
+        )
+
+        content = Table('content', metadata,
+            Column('t', String(30), key="type"),
+        )
+        bar = Table('bar', metadata, 
+            Column('ctype', String(30), key="content_type")
+        )
+
+    @classmethod
+    def insert_data(cls):
+        cls.tables.keyed1.insert().execute(dict(b="a1", q="c1"))
+        cls.tables.keyed2.insert().execute(dict(a="a2", b="b2"))
+        cls.tables.keyed3.insert().execute(dict(a="a3", d="d3"))
+        cls.tables.keyed4.insert().execute(dict(b="b4", q="q4"))
+        cls.tables.content.insert().execute(type="t1")
+
+    def test_keyed_accessor_single(self):
+        keyed1 = self.tables.keyed1
+        row = testing.db.execute(keyed1.select()).first()
+
+        eq_(row.b, "a1")
+        eq_(row.q, "c1")
+        eq_(row.a, "a1")
+        eq_(row.c, "c1")
+
+    def test_keyed_accessor_single_labeled(self):
+        keyed1 = self.tables.keyed1
+        row = testing.db.execute(keyed1.select().apply_labels()).first()
+
+        eq_(row.keyed1_b, "a1")
+        eq_(row.keyed1_q, "c1")
+        eq_(row.keyed1_a, "a1")
+        eq_(row.keyed1_c, "c1")
+
+    def test_keyed_accessor_composite_conflict_2(self):
+        keyed1 = self.tables.keyed1
+        keyed2 = self.tables.keyed2
+
+        row = testing.db.execute(select([keyed1, keyed2])).first()
+        # without #2397, row.b is unambiguous
+        eq_(row.b, "b2")
+        # row.a is ambiguous
+        assert_raises_message(
+            exc.InvalidRequestError,
+            "Ambig",
+            getattr, row, "a"
+        )
+
+    @testing.fails_if(lambda: True, "Possible future behavior")
+    def test_keyed_accessor_composite_conflict_2397(self):
+        keyed1 = self.tables.keyed1
+        keyed2 = self.tables.keyed2
+
+        row = testing.db.execute(select([keyed1, keyed2])).first()
+        # with #2397, row.a is unambiguous
+        eq_(row.a, "a2")
+        # row.b is ambiguous
+        assert_raises_message(
+            exc.InvalidRequestError,
+            "Ambiguous column name 'b'",
+            getattr, row, 'b'
+        )
+
+    def test_keyed_accessor_composite_names_precedent(self):
+        keyed1 = self.tables.keyed1
+        keyed4 = self.tables.keyed4
+
+        row = testing.db.execute(select([keyed1, keyed4])).first()
+        eq_(row.b, "b4")
+        eq_(row.q, "q4")
+        eq_(row.a, "a1")
+        eq_(row.c, "c1")
+
+    def test_keyed_accessor_composite_keys_precedent(self):
+        keyed1 = self.tables.keyed1
+        keyed3 = self.tables.keyed3
+
+        row = testing.db.execute(select([keyed1, keyed3])).first()
+        assert 'b' not in row
+        eq_(row.q, "c1")
+        assert_raises_message(
+            exc.InvalidRequestError,
+            "Ambiguous column name 'a'",
+            getattr, row, "a"
+        )
+        eq_(row.d, "d3")
+
+    @testing.fails_if(lambda: True, "Possible future behavior")
+    def test_keyed_accessor_composite_2397(self):
+        keyed1 = self.tables.keyed1
+        keyed3 = self.tables.keyed3
+
+        row = testing.db.execute(select([keyed1, keyed3])).first()
+        eq_(row.b, "a1")
+        eq_(row.q, "c1")
+        eq_(row.a, "a3")
+        eq_(row.d, "d3")
+
+    def test_keyed_accessor_composite_labeled(self):
+        keyed1 = self.tables.keyed1
+        keyed2 = self.tables.keyed2
+
+        row = testing.db.execute(select([keyed1, keyed2]).apply_labels()).first()
+        eq_(row.keyed1_b, "a1")
+        eq_(row.keyed1_a, "a1")
+        eq_(row.keyed1_q, "c1")
+        eq_(row.keyed1_c, "c1")
+        eq_(row.keyed2_a, "a2")
+        eq_(row.keyed2_b, "b2")
+        assert_raises(KeyError, lambda: row['keyed2_c'])
+        assert_raises(KeyError, lambda: row['keyed2_q'])
+
+    def test_column_label_overlap_fallback(self):
+        content, bar = self.tables.content, self.tables.bar
+        row = testing.db.execute(select([content.c.type.label("content_type")])).first()
+        assert content.c.type in row
+        assert bar.c.content_type not in row
+        assert sql.column('content_type') in row
+
+        row = testing.db.execute(select([func.now().label("content_type")])).first()
+        assert content.c.type not in row
+        assert bar.c.content_type not in row
+        assert sql.column('content_type') in row
+
+    def test_column_label_overlap_fallback_2(self):
+        # this fails with #2397
+        content, bar = self.tables.content, self.tables.bar
+        row = testing.db.execute(content.select(use_labels=True)).first()
+        assert content.c.type in row
+        assert bar.c.content_type not in row
+        assert sql.column('content_type') not in row
+
+    @testing.fails_if(lambda: True, "Possible future behavior")
+    def test_column_label_overlap_fallback_3(self):
+        # this passes with #2397
+        content, bar = self.tables.content, self.tables.bar
+        row = testing.db.execute(content.select(use_labels=True)).first()
+        assert content.c.type in row
+        assert bar.c.content_type not in row
+        assert sql.column('content_type') in row
 
 
 class LimitTest(fixtures.TestBase):
index c421a521f9180e55f17b6d1d3b805b7c427d3f56..952b14763868e75742f92b49af20d67712913114 100644 (file)
@@ -34,7 +34,7 @@ class QuoteTest(fixtures.TestBase, AssertsCompiledSQL):
         table1.drop()
         table2.drop()
 
-    def testbasic(self):
+    def test_basic(self):
         table1.insert().execute({'lowercase':1,'UPPERCASE':2,'MixedCase':3,'a123':4},
                 {'lowercase':2,'UPPERCASE':2,'MixedCase':3,'a123':4},
                 {'lowercase':4,'UPPERCASE':3,'MixedCase':2,'a123':1})
@@ -59,12 +59,12 @@ class QuoteTest(fixtures.TestBase, AssertsCompiledSQL):
             ')'
         )
 
-    def testreflect(self):
+    def test_reflect(self):
         meta2 = MetaData(testing.db)
         t2 = Table('WorstCase2', meta2, autoload=True, quote=True)
         assert 'MixedCase' in t2.c
 
-    def testlabels(self):
+    def test_labels(self):
         table1.insert().execute({'lowercase':1,'UPPERCASE':2,'MixedCase':3,'a123':4},
                 {'lowercase':2,'UPPERCASE':2,'MixedCase':3,'a123':4},
                 {'lowercase':4,'UPPERCASE':3,'MixedCase':2,'a123':1})
@@ -136,7 +136,7 @@ class QuoteTest(fixtures.TestBase, AssertsCompiledSQL):
 
     @testing.crashes('oracle', 'FIXME: unknown, verify not fails_on')
     @testing.requires.subqueries
-    def testlabels(self):
+    def test_labels(self):
         """test the quoting of labels.
 
         if labels arent quoted, a query in postgresql in particular will fail since it produces:
@@ -151,7 +151,7 @@ class QuoteTest(fixtures.TestBase, AssertsCompiledSQL):
 
         x = table1.select(distinct=True).alias("LaLa").select().scalar()
 
-    def testlabels2(self):
+    def test_labels2(self):
         metadata = MetaData()
         table = Table("ImATable", metadata,
             Column("col1", Integer))
@@ -174,14 +174,14 @@ class QuoteTest(fixtures.TestBase, AssertsCompiledSQL):
         metadata = MetaData()
         table = Table("ImATable", metadata,
             Column("col1", Integer),
-            Column("from", Integer, key="morf"),
+            Column("from", Integer),
             Column("louisville", Integer),
             Column("order", Integer))
-        x = select([table.c.col1, table.c.morf, table.c.louisville, table.c.order])
+        x = select([table.c.col1, table.c['from'], table.c.louisville, table.c.order])
 
         self.assert_compile(x, 
             '''SELECT "ImATable".col1, "ImATable"."from", "ImATable".louisville, "ImATable"."order" FROM "ImATable"''')
-        
+
 
 class PreparerTest(fixtures.TestBase):
     """Test the db-agnostic quoting services of IdentifierPreparer."""
index a4c3ddf40cdd343fc2d634f36955881acc47386d..8f599f1d6dd675a4d3e99f74d1dc5a31d6410b77 100644 (file)
@@ -26,6 +26,11 @@ table2 = Table('table2', metadata,
     Column('coly', Integer),
 )
 
+keyed = Table('keyed', metadata,
+    Column('x', Integer, key='colx'),
+    Column('y', Integer, key='coly'),
+    Column('z', Integer),
+)
 
 class SelectableTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL):
     __dialect__ = 'default'
@@ -91,6 +96,24 @@ class SelectableTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiled
         assert sel3.corresponding_column(l1) is sel3.c.foo
         assert sel3.corresponding_column(l2) is sel3.c.bar
 
+    def test_keyed_gen(self):
+        s = select([keyed])
+        eq_(s.c.colx.key, 'colx')
+
+        # this would change to 'colx' 
+        # with #2397
+        eq_(s.c.colx.name, 'x')
+
+        assert s.corresponding_column(keyed.c.colx) is s.c.colx
+        assert s.corresponding_column(keyed.c.coly) is s.c.coly
+        assert s.corresponding_column(keyed.c.z) is s.c.z
+
+        sel2 = s.alias()
+        assert sel2.corresponding_column(keyed.c.colx) is sel2.c.colx
+        assert sel2.corresponding_column(keyed.c.coly) is sel2.c.coly
+        assert sel2.corresponding_column(keyed.c.z) is sel2.c.z
+
+
     def test_distance_on_aliases(self):
         a1 = table1.alias('a1')
         for s in (select([a1, table1], use_labels=True),