]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
implement tuple-slices from .c collections
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 27 Jul 2022 15:36:57 +0000 (11:36 -0400)
committerFederico Caselli <cfederico87@gmail.com>
Mon, 1 Aug 2022 21:46:33 +0000 (21:46 +0000)
Added new syntax to the ``.c`` collection on all :class:`.FromClause`
objects allowing tuples of keys to be passed to ``__getitem__()``, along
with support for ``select()`` handling of ``.c`` collections directly,
allowing the syntax ``select(table.c['a', 'b', 'c'])`` to be possible. The
sub-collection returned is itself a :class:`.ColumnCollection` which is
also directly consumable by :func:`_sql.select` and similar now.

Fixes: #8285
Change-Id: I2236662c477ffc50af079310589e213323c960d1

doc/build/changelog/unreleased_20/8285.rst [new file with mode: 0644]
doc/build/core/metadata.rst
doc/build/tutorial/data_select.rst
lib/sqlalchemy/sql/base.py
test/base/test_utils.py
test/orm/test_core_compilation.py
test/sql/test_select.py

diff --git a/doc/build/changelog/unreleased_20/8285.rst b/doc/build/changelog/unreleased_20/8285.rst
new file mode 100644 (file)
index 0000000..e1a351b
--- /dev/null
@@ -0,0 +1,15 @@
+.. change::
+    :tags: feature, sql
+    :tickets: 8285
+
+    Added new syntax to the :attr:`.FromClause.c` collection on all
+    :class:`.FromClause` objects allowing tuples of keys to be passed to
+    ``__getitem__()``, along with support for the :func:`_sql.select` construct
+    to handle the resulting tuple-like collection directly, allowing the syntax
+    ``select(table.c['a', 'b', 'c'])`` to be possible. The sub-collection
+    returned is itself a :class:`.ColumnCollection` which is also directly
+    consumable by :func:`_sql.select` and similar now.
+
+    .. seealso::
+
+        :ref:`tutorial_selecting_columns`
index e000022a3d96fb7571ea2298a503594e6718662e..9cdc1e25629218350c14e01133c26e265d449404 100644 (file)
@@ -98,6 +98,10 @@ table include::
     # via string
     employees.c['employee_id']
 
+    # a tuple of columns may be returned using multiple strings
+    # (new in 2.0)
+    emp_id, name, type = employees.c['employee_id', "name", "type"]
+
     # iterate through all columns
     for c in employees.c:
         print(c)
@@ -113,9 +117,6 @@ table include::
     # access the table's MetaData:
     employees.metadata
 
-    # access the table's bound Engine or Connection, if its MetaData is bound:
-    employees.bind
-
     # access a column's name, type, nullable, primary key, foreign key
     employees.c.employee_id.name
     employees.c.employee_id.type
index 6d52b4627324508d27f22d78572453cfe67920af..b55113fd3f7dea647b2d537c65a1a0d82b977d68 100644 (file)
@@ -95,6 +95,7 @@ elements within each row:
 
 The following sections will discuss the SELECT construct in more detail.
 
+.. _tutorial_selecting_columns:
 
 Setting the COLUMNS and FROM clause
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -120,6 +121,18 @@ are represented by those columns::
     {opensql}SELECT user_account.name, user_account.fullname
     FROM user_account
 
+Alternatively, when using the :attr:`.FromClause.c` collection of any
+:class:`.FromClause` such as :class:`.Table`, multiple columns may be specified
+for a :func:`_sql.select` by using a tuple of string names::
+
+    >>> print(select(user_table.c['name', 'fullname']))
+    {opensql}SELECT user_account.name, user_account.fullname
+    FROM user_account
+
+.. versionadded:: 2.0 Added tuple-accessor capability to the
+   :attr`.FromClause.c` collection
+
+
 .. _tutorial_selecting_orm_entities:
 
 Selecting ORM Entities and Columns
index 70c01d8d3c94c8c9f3924114621ae3d1ce3e3cfd..fbbf9f7f73efb78f9b8342eda7a09a1a0afcb2f7 100644 (file)
@@ -32,6 +32,7 @@ from typing import Mapping
 from typing import MutableMapping
 from typing import NoReturn
 from typing import Optional
+from typing import overload
 from typing import Sequence
 from typing import Set
 from typing import Tuple
@@ -64,6 +65,7 @@ if TYPE_CHECKING:
     from . import elements
     from . import type_api
     from .elements import BindParameter
+    from .elements import ClauseList
     from .elements import ColumnClause  # noqa
     from .elements import ColumnElement
     from .elements import KeyedColumnElement
@@ -1396,7 +1398,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
     __slots__ = "_collection", "_index", "_colset"
 
     _collection: List[Tuple[_COLKEY, _COL_co]]
-    _index: Dict[Union[None, str, int], _COL_co]
+    _index: Dict[Union[None, str, int], Tuple[_COLKEY, _COL_co]]
     _colset: Set[_COL_co]
 
     def __init__(
@@ -1408,6 +1410,16 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
         if columns:
             self._initial_populate(columns)
 
+    @util.preload_module("sqlalchemy.sql.elements")
+    def __clause_element__(self) -> ClauseList:
+        elements = util.preloaded.sql_elements
+
+        return elements.ClauseList(
+            _literal_as_text_role=roles.ColumnsClauseRole,
+            group=False,
+            *self._all_columns,
+        )
+
     def _initial_populate(
         self, iter_: Iterable[Tuple[_COLKEY, _COL_co]]
     ) -> None:
@@ -1415,18 +1427,18 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
 
     @property
     def _all_columns(self) -> List[_COL_co]:
-        return [col for (k, col) in self._collection]
+        return [col for (_, col) in self._collection]
 
     def keys(self) -> List[_COLKEY]:
         """Return a sequence of string key names for all columns in this
         collection."""
-        return [k for (k, col) in self._collection]
+        return [k for (k, _) in self._collection]
 
     def values(self) -> List[_COL_co]:
         """Return a sequence of :class:`_sql.ColumnClause` or
         :class:`_schema.Column` objects for all columns in this
         collection."""
-        return [col for (k, col) in self._collection]
+        return [col for (_, col) in self._collection]
 
     def items(self) -> List[Tuple[_COLKEY, _COL_co]]:
         """Return a sequence of (key, column) tuples for all columns in this
@@ -1445,20 +1457,37 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
 
     def __iter__(self) -> Iterator[_COL_co]:
         # turn to a list first to maintain over a course of changes
-        return iter([col for k, col in self._collection])
+        return iter([col for _, col in self._collection])
 
+    @overload
     def __getitem__(self, key: Union[str, int]) -> _COL_co:
+        ...
+
+    @overload
+    def __getitem__(
+        self, key: Tuple[Union[str, int], ...]
+    ) -> ReadOnlyColumnCollection[_COLKEY, _COL_co]:
+        ...
+
+    def __getitem__(
+        self, key: Union[str, int, Tuple[Union[str, int], ...]]
+    ) -> Union[ReadOnlyColumnCollection[_COLKEY, _COL_co], _COL_co]:
         try:
-            return self._index[key]
+            if isinstance(key, tuple):
+                return ColumnCollection(  # type: ignore
+                    [self._index[sub_key] for sub_key in key]
+                ).as_readonly()
+            else:
+                return self._index[key][1]
         except KeyError as err:
-            if isinstance(key, int):
-                raise IndexError(key) from err
+            if isinstance(err.args[0], int):
+                raise IndexError(err.args[0]) from err
             else:
                 raise
 
     def __getattr__(self, key: str) -> _COL_co:
         try:
-            return self._index[key]
+            return self._index[key][1]
         except KeyError as err:
             raise AttributeError(key) from err
 
@@ -1493,7 +1522,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
         :class:`_expression.ColumnCollection`."""
 
         if key in self._index:
-            return self._index[key]
+            return self._index[key][1]
         else:
             return default
 
@@ -1537,9 +1566,11 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
         self._collection[:] = cols
         self._colset.update(c for k, c in self._collection)
         self._index.update(
-            (idx, c) for idx, (k, c) in enumerate(self._collection)
+            (idx, (k, c)) for idx, (k, c) in enumerate(self._collection)
+        )
+        self._index.update(
+            {k: (k, col) for k, col in reversed(self._collection)}
         )
-        self._index.update({k: col for k, col in reversed(self._collection)})
 
     def add(
         self, column: ColumnElement[Any], key: Optional[_COLKEY] = None
@@ -1571,12 +1602,15 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
 
         self._collection.append((colkey, _column))
         self._colset.add(_column)
-        self._index[l] = _column
+        self._index[l] = (colkey, _column)
         if colkey not in self._index:
-            self._index[colkey] = _column
+            self._index[colkey] = (colkey, _column)
 
     def __getstate__(self) -> Dict[str, Any]:
-        return {"_collection": self._collection, "_index": self._index}
+        return {
+            "_collection": self._collection,
+            "_index": self._index,
+        }
 
     def __setstate__(self, state: Dict[str, Any]) -> None:
         object.__setattr__(self, "_index", state["_index"])
@@ -1652,7 +1686,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
 
         col, intersect = None, None
         target_set = column.proxy_set
-        cols = [c for (k, c) in self._collection]
+        cols = [c for (_, c) in self._collection]
         for c in cols:
             expanded_proxy_set = set(_expand_cloned(c.proxy_set))
             i = target_set.intersection(expanded_proxy_set)
@@ -1739,7 +1773,7 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):
 
         if key in self._index:
 
-            existing = self._index[key]
+            existing = self._index[key][1]
 
             if existing is named_column:
                 return
@@ -1754,8 +1788,8 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):
             l = len(self._collection)
             self._collection.append((key, named_column))
             self._colset.add(named_column)
-            self._index[l] = named_column
-            self._index[key] = named_column
+            self._index[l] = (key, named_column)
+            self._index[key] = (key, named_column)
 
     def _populate_separate_keys(
         self, iter_: Iterable[Tuple[str, _NAMEDCOL]]
@@ -1775,12 +1809,12 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):
             elif col.key in self._index:
                 replace_col.append(col)
             else:
-                self._index[k] = col
+                self._index[k] = (k, col)
                 self._collection.append((k, col))
         self._colset.update(c for (k, c) in self._collection)
 
         self._index.update(
-            (idx, c) for idx, (k, c) in enumerate(self._collection)
+            (idx, (k, c)) for idx, (k, c) in enumerate(self._collection)
         )
         for col in replace_col:
             self.replace(col)
@@ -1801,7 +1835,7 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):
         ]
 
         self._index.update(
-            {idx: col for idx, (k, col) in enumerate(self._collection)}
+            {idx: (k, col) for idx, (k, col) in enumerate(self._collection)}
         )
         # delete higher index
         del self._index[len(self._collection)]
@@ -1826,12 +1860,12 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):
         remove_col = set()
         # remove up to two columns based on matches of name as well as key
         if column.name in self._index and column.key != column.name:
-            other = self._index[column.name]
+            other = self._index[column.name][1]
             if other.name == other.key:
                 remove_col.add(other)
 
         if column.key in self._index:
-            remove_col.add(self._index[column.key])
+            remove_col.add(self._index[column.key][1])
 
         new_cols: List[Tuple[str, _NAMEDCOL]] = []
         replaced = False
@@ -1855,9 +1889,9 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):
         self._index.clear()
 
         self._index.update(
-            {idx: col for idx, (k, col) in enumerate(self._collection)}
+            {idx: (k, col) for idx, (k, col) in enumerate(self._collection)}
         )
-        self._index.update(self._collection)
+        self._index.update({k: (k, col) for (k, col) in self._collection})
 
 
 class ReadOnlyColumnCollection(
index c5a47ddf97a8130600cda249de692f24731017e2..98451cc4f1716b1abe3ba683e6639e2a97cb8efc 100644 (file)
@@ -550,8 +550,10 @@ class ColumnCollectionCommon(testing.AssertsCompiledSQL):
         eq_(coll._colset, set(c for k, c in coll._collection))
         d = {}
         for k, col in coll._collection:
-            d.setdefault(k, col)
-        d.update({idx: col for idx, (k, col) in enumerate(coll._collection)})
+            d.setdefault(k, (k, col))
+        d.update(
+            {idx: (k, col) for idx, (k, col) in enumerate(coll._collection)}
+        )
         eq_(coll._index, d)
 
     def test_keys(self):
@@ -593,6 +595,27 @@ class ColumnCollectionCommon(testing.AssertsCompiledSQL):
         ci = cc.as_readonly()
         eq_(ci.items(), [("c1", c1), ("foo", c2), ("c3", c3)])
 
+    def test_getitem_tuple_str(self):
+        c1, c2, c3 = sql.column("c1"), sql.column("c2"), sql.column("c3")
+        c2.key = "foo"
+        cc = self._column_collection(
+            columns=[("c1", c1), ("foo", c2), ("c3", c3)]
+        )
+        sub_cc = cc["c3", "foo"]
+        is_(sub_cc.c3, c3)
+        eq_(list(sub_cc), [c3, c2])
+
+    def test_getitem_tuple_int(self):
+        c1, c2, c3 = sql.column("c1"), sql.column("c2"), sql.column("c3")
+        c2.key = "foo"
+        cc = self._column_collection(
+            columns=[("c1", c1), ("foo", c2), ("c3", c3)]
+        )
+
+        sub_cc = cc[2, 1]
+        is_(sub_cc.c3, c3)
+        eq_(list(sub_cc), [c3, c2])
+
     def test_key_index_error(self):
         cc = self._column_collection(
             columns=[
index 19b1c34b14e89b3eb074569876b6daefb78d05d6..5807b619f35ac3c7ef69fa0ca8cc44e9fb26123e 100644 (file)
@@ -71,6 +71,26 @@ class SelectableTest(QueryTest, AssertsCompiledSQL):
         eq_(s1.subquery().c.keys(), ["id"])
         eq_(s1.subquery().c.keys(), ["id"])
 
+    def test_integration_w_8285_subc(self):
+        Address = self.classes.Address
+
+        s1 = select(
+            Address.id, Address.__table__.c["user_id", "email_address"]
+        )
+        self.assert_compile(
+            s1,
+            "SELECT addresses.id, addresses.user_id, "
+            "addresses.email_address FROM addresses",
+        )
+
+        subq = s1.subquery()
+        self.assert_compile(
+            select(subq.c.user_id, subq.c.id),
+            "SELECT anon_1.user_id, anon_1.id FROM (SELECT addresses.id AS "
+            "id, addresses.user_id AS user_id, addresses.email_address "
+            "AS email_address FROM addresses) AS anon_1",
+        )
+
     def test_scalar_subquery_from_subq_same_source(self):
         """test #6394, ensure all_selected_columns is generated each time"""
         User = self.classes.User
index d91e50e6373271c317dadafee023ff96ad51d844..ad4b4db95916f19962876409b3e083899369621e 100644 (file)
@@ -16,8 +16,10 @@ from sqlalchemy.sql import literal
 from sqlalchemy.sql import table
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import AssertsCompiledSQL
+from sqlalchemy.testing import eq_
 from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import is_
 
 table1 = table(
     "mytable",
@@ -442,3 +444,72 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
             " %(joiner)s SELECT :param_2 AS anon_2"
             " %(joiner)s SELECT :param_3 AS anon_3" % {"joiner": joiner},
         )
+
+
+class ColumnCollectionAsSelectTest(fixtures.TestBase, AssertsCompiledSQL):
+    """tests related to #8285."""
+
+    __dialect__ = "default"
+
+    def test_c_collection_as_from(self):
+        stmt = select(parent.c)
+
+        # this works because _all_selected_columns expands out
+        # ClauseList.  it does so in the same way that it works for
+        # Table already.  so this is free
+        eq_(stmt._all_selected_columns, [parent.c.id, parent.c.data])
+
+        self.assert_compile(stmt, "SELECT parent.id, parent.data FROM parent")
+
+    def test_c_sub_collection_str_stmt(self):
+        stmt = select(table1.c["myid", "description"])
+
+        self.assert_compile(
+            stmt, "SELECT mytable.myid, mytable.description FROM mytable"
+        )
+
+        subq = stmt.subquery()
+        self.assert_compile(
+            select(subq.c[0]).where(subq.c.description == "x"),
+            "SELECT anon_1.myid FROM (SELECT mytable.myid AS myid, "
+            "mytable.description AS description FROM mytable) AS anon_1 "
+            "WHERE anon_1.description = :description_1",
+        )
+
+    def test_c_sub_collection_int_stmt(self):
+        stmt = select(table1.c[2, 0])
+
+        self.assert_compile(
+            stmt, "SELECT mytable.description, mytable.myid FROM mytable"
+        )
+
+        subq = stmt.subquery()
+        self.assert_compile(
+            select(subq.c.myid).where(subq.c[1] == "x"),
+            "SELECT anon_1.myid FROM (SELECT mytable.description AS "
+            "description, mytable.myid AS myid FROM mytable) AS anon_1 "
+            "WHERE anon_1.myid = :myid_1",
+        )
+
+    def test_c_sub_collection_str(self):
+        coll = table1.c["myid", "description"]
+        is_(coll.myid, table1.c.myid)
+
+        eq_(list(coll), [table1.c.myid, table1.c.description])
+
+    def test_c_sub_collection_int(self):
+        coll = table1.c[2, 0]
+
+        is_(coll.myid, table1.c.myid)
+
+        eq_(list(coll), [table1.c.description, table1.c.myid])
+
+    def test_missing_key(self):
+
+        with expect_raises_message(KeyError, "unknown"):
+            table1.c["myid", "unknown"]
+
+    def test_missing_index(self):
+
+        with expect_raises_message(IndexError, "5"):
+            table1.c["myid", 5]