From 1ecbf14cc24aa0b1d303926178941c1f7f9fe93b Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 27 Jul 2022 11:36:57 -0400 Subject: [PATCH] implement tuple-slices from .c collections Added new syntax to the ``.c`` collection on all :class:`.FromClause` objects allowing tuples of keys to be passed to ``__getitem__()``, along with support for ``select()`` handling of ``.c`` collections directly, allowing the syntax ``select(table.c['a', 'b', 'c'])`` to be possible. The sub-collection returned is itself a :class:`.ColumnCollection` which is also directly consumable by :func:`_sql.select` and similar now. Fixes: #8285 Change-Id: I2236662c477ffc50af079310589e213323c960d1 --- doc/build/changelog/unreleased_20/8285.rst | 15 ++++ doc/build/core/metadata.rst | 7 +- doc/build/tutorial/data_select.rst | 13 ++++ lib/sqlalchemy/sql/base.py | 86 +++++++++++++++------- test/base/test_utils.py | 27 ++++++- test/orm/test_core_compilation.py | 20 +++++ test/sql/test_select.py | 71 ++++++++++++++++++ 7 files changed, 208 insertions(+), 31 deletions(-) create mode 100644 doc/build/changelog/unreleased_20/8285.rst diff --git a/doc/build/changelog/unreleased_20/8285.rst b/doc/build/changelog/unreleased_20/8285.rst new file mode 100644 index 0000000000..e1a351b9d3 --- /dev/null +++ b/doc/build/changelog/unreleased_20/8285.rst @@ -0,0 +1,15 @@ +.. change:: + :tags: feature, sql + :tickets: 8285 + + Added new syntax to the :attr:`.FromClause.c` collection on all + :class:`.FromClause` objects allowing tuples of keys to be passed to + ``__getitem__()``, along with support for the :func:`_sql.select` construct + to handle the resulting tuple-like collection directly, allowing the syntax + ``select(table.c['a', 'b', 'c'])`` to be possible. The sub-collection + returned is itself a :class:`.ColumnCollection` which is also directly + consumable by :func:`_sql.select` and similar now. + + .. 
seealso:: + + :ref:`tutorial_selecting_columns` diff --git a/doc/build/core/metadata.rst b/doc/build/core/metadata.rst index e000022a3d..9cdc1e2562 100644 --- a/doc/build/core/metadata.rst +++ b/doc/build/core/metadata.rst @@ -98,6 +98,10 @@ table include:: # via string employees.c['employee_id'] + # a tuple of columns may be returned using multiple strings + # (new in 2.0) + emp_id, name, type = employees.c['employee_id', "name", "type"] + # iterate through all columns for c in employees.c: print(c) @@ -113,9 +117,6 @@ table include:: # access the table's MetaData: employees.metadata - # access the table's bound Engine or Connection, if its MetaData is bound: - employees.bind - # access a column's name, type, nullable, primary key, foreign key employees.c.employee_id.name employees.c.employee_id.type diff --git a/doc/build/tutorial/data_select.rst b/doc/build/tutorial/data_select.rst index 6d52b46273..b55113fd3f 100644 --- a/doc/build/tutorial/data_select.rst +++ b/doc/build/tutorial/data_select.rst @@ -95,6 +95,7 @@ elements within each row: The following sections will discuss the SELECT construct in more detail. +.. _tutorial_selecting_columns: Setting the COLUMNS and FROM clause ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -120,6 +121,18 @@ are represented by those columns:: {opensql}SELECT user_account.name, user_account.fullname FROM user_account +Alternatively, when using the :attr:`.FromClause.c` collection of any +:class:`.FromClause` such as :class:`.Table`, multiple columns may be specified +for a :func:`_sql.select` by using a tuple of string names:: + + >>> print(select(user_table.c['name', 'fullname'])) + {opensql}SELECT user_account.name, user_account.fullname + FROM user_account + +.. versionadded:: 2.0 Added tuple-accessor capability to the + :attr:`.FromClause.c` collection + + .. 
_tutorial_selecting_orm_entities: Selecting ORM Entities and Columns diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 70c01d8d3c..fbbf9f7f73 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -32,6 +32,7 @@ from typing import Mapping from typing import MutableMapping from typing import NoReturn from typing import Optional +from typing import overload from typing import Sequence from typing import Set from typing import Tuple @@ -64,6 +65,7 @@ if TYPE_CHECKING: from . import elements from . import type_api from .elements import BindParameter + from .elements import ClauseList from .elements import ColumnClause # noqa from .elements import ColumnElement from .elements import KeyedColumnElement @@ -1396,7 +1398,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]): __slots__ = "_collection", "_index", "_colset" _collection: List[Tuple[_COLKEY, _COL_co]] - _index: Dict[Union[None, str, int], _COL_co] + _index: Dict[Union[None, str, int], Tuple[_COLKEY, _COL_co]] _colset: Set[_COL_co] def __init__( @@ -1408,6 +1410,16 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]): if columns: self._initial_populate(columns) + @util.preload_module("sqlalchemy.sql.elements") + def __clause_element__(self) -> ClauseList: + elements = util.preloaded.sql_elements + + return elements.ClauseList( + _literal_as_text_role=roles.ColumnsClauseRole, + group=False, + *self._all_columns, + ) + def _initial_populate( self, iter_: Iterable[Tuple[_COLKEY, _COL_co]] ) -> None: @@ -1415,18 +1427,18 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]): @property def _all_columns(self) -> List[_COL_co]: - return [col for (k, col) in self._collection] + return [col for (_, col) in self._collection] def keys(self) -> List[_COLKEY]: """Return a sequence of string key names for all columns in this collection.""" - return [k for (k, col) in self._collection] + return [k for (k, _) in self._collection] def values(self) -> List[_COL_co]: """Return a 
sequence of :class:`_sql.ColumnClause` or :class:`_schema.Column` objects for all columns in this collection.""" - return [col for (k, col) in self._collection] + return [col for (_, col) in self._collection] def items(self) -> List[Tuple[_COLKEY, _COL_co]]: """Return a sequence of (key, column) tuples for all columns in this @@ -1445,20 +1457,37 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]): def __iter__(self) -> Iterator[_COL_co]: # turn to a list first to maintain over a course of changes - return iter([col for k, col in self._collection]) + return iter([col for _, col in self._collection]) + @overload def __getitem__(self, key: Union[str, int]) -> _COL_co: + ... + + @overload + def __getitem__( + self, key: Tuple[Union[str, int], ...] + ) -> ReadOnlyColumnCollection[_COLKEY, _COL_co]: + ... + + def __getitem__( + self, key: Union[str, int, Tuple[Union[str, int], ...]] + ) -> Union[ReadOnlyColumnCollection[_COLKEY, _COL_co], _COL_co]: try: - return self._index[key] + if isinstance(key, tuple): + return ColumnCollection( # type: ignore + [self._index[sub_key] for sub_key in key] + ).as_readonly() + else: + return self._index[key][1] except KeyError as err: - if isinstance(key, int): - raise IndexError(key) from err + if isinstance(err.args[0], int): + raise IndexError(err.args[0]) from err else: raise def __getattr__(self, key: str) -> _COL_co: try: - return self._index[key] + return self._index[key][1] except KeyError as err: raise AttributeError(key) from err @@ -1493,7 +1522,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]): :class:`_expression.ColumnCollection`.""" if key in self._index: - return self._index[key] + return self._index[key][1] else: return default @@ -1537,9 +1566,11 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]): self._collection[:] = cols self._colset.update(c for k, c in self._collection) self._index.update( - (idx, c) for idx, (k, c) in enumerate(self._collection) + (idx, (k, c)) for idx, (k, c) in 
enumerate(self._collection) + ) + self._index.update( + {k: (k, col) for k, col in reversed(self._collection)} ) - self._index.update({k: col for k, col in reversed(self._collection)}) def add( self, column: ColumnElement[Any], key: Optional[_COLKEY] = None @@ -1571,12 +1602,15 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]): self._collection.append((colkey, _column)) self._colset.add(_column) - self._index[l] = _column + self._index[l] = (colkey, _column) if colkey not in self._index: - self._index[colkey] = _column + self._index[colkey] = (colkey, _column) def __getstate__(self) -> Dict[str, Any]: - return {"_collection": self._collection, "_index": self._index} + return { + "_collection": self._collection, + "_index": self._index, + } def __setstate__(self, state: Dict[str, Any]) -> None: object.__setattr__(self, "_index", state["_index"]) @@ -1652,7 +1686,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]): col, intersect = None, None target_set = column.proxy_set - cols = [c for (k, c) in self._collection] + cols = [c for (_, c) in self._collection] for c in cols: expanded_proxy_set = set(_expand_cloned(c.proxy_set)) i = target_set.intersection(expanded_proxy_set) @@ -1739,7 +1773,7 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): if key in self._index: - existing = self._index[key] + existing = self._index[key][1] if existing is named_column: return @@ -1754,8 +1788,8 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): l = len(self._collection) self._collection.append((key, named_column)) self._colset.add(named_column) - self._index[l] = named_column - self._index[key] = named_column + self._index[l] = (key, named_column) + self._index[key] = (key, named_column) def _populate_separate_keys( self, iter_: Iterable[Tuple[str, _NAMEDCOL]] @@ -1775,12 +1809,12 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): elif col.key in self._index: replace_col.append(col) else: - self._index[k] = col + self._index[k] 
= (k, col) self._collection.append((k, col)) self._colset.update(c for (k, c) in self._collection) self._index.update( - (idx, c) for idx, (k, c) in enumerate(self._collection) + (idx, (k, c)) for idx, (k, c) in enumerate(self._collection) ) for col in replace_col: self.replace(col) @@ -1801,7 +1835,7 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): ] self._index.update( - {idx: col for idx, (k, col) in enumerate(self._collection)} + {idx: (k, col) for idx, (k, col) in enumerate(self._collection)} ) # delete higher index del self._index[len(self._collection)] @@ -1826,12 +1860,12 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): remove_col = set() # remove up to two columns based on matches of name as well as key if column.name in self._index and column.key != column.name: - other = self._index[column.name] + other = self._index[column.name][1] if other.name == other.key: remove_col.add(other) if column.key in self._index: - remove_col.add(self._index[column.key]) + remove_col.add(self._index[column.key][1]) new_cols: List[Tuple[str, _NAMEDCOL]] = [] replaced = False @@ -1855,9 +1889,9 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): self._index.clear() self._index.update( - {idx: col for idx, (k, col) in enumerate(self._collection)} + {idx: (k, col) for idx, (k, col) in enumerate(self._collection)} ) - self._index.update(self._collection) + self._index.update({k: (k, col) for (k, col) in self._collection}) class ReadOnlyColumnCollection( diff --git a/test/base/test_utils.py b/test/base/test_utils.py index c5a47ddf97..98451cc4f1 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -550,8 +550,10 @@ class ColumnCollectionCommon(testing.AssertsCompiledSQL): eq_(coll._colset, set(c for k, c in coll._collection)) d = {} for k, col in coll._collection: - d.setdefault(k, col) - d.update({idx: col for idx, (k, col) in enumerate(coll._collection)}) + d.setdefault(k, (k, col)) + d.update( + {idx: (k, col) 
for idx, (k, col) in enumerate(coll._collection)} + ) eq_(coll._index, d) def test_keys(self): @@ -593,6 +595,27 @@ class ColumnCollectionCommon(testing.AssertsCompiledSQL): ci = cc.as_readonly() eq_(ci.items(), [("c1", c1), ("foo", c2), ("c3", c3)]) + def test_getitem_tuple_str(self): + c1, c2, c3 = sql.column("c1"), sql.column("c2"), sql.column("c3") + c2.key = "foo" + cc = self._column_collection( + columns=[("c1", c1), ("foo", c2), ("c3", c3)] + ) + sub_cc = cc["c3", "foo"] + is_(sub_cc.c3, c3) + eq_(list(sub_cc), [c3, c2]) + + def test_getitem_tuple_int(self): + c1, c2, c3 = sql.column("c1"), sql.column("c2"), sql.column("c3") + c2.key = "foo" + cc = self._column_collection( + columns=[("c1", c1), ("foo", c2), ("c3", c3)] + ) + + sub_cc = cc[2, 1] + is_(sub_cc.c3, c3) + eq_(list(sub_cc), [c3, c2]) + def test_key_index_error(self): cc = self._column_collection( columns=[ diff --git a/test/orm/test_core_compilation.py b/test/orm/test_core_compilation.py index 19b1c34b14..5807b619f3 100644 --- a/test/orm/test_core_compilation.py +++ b/test/orm/test_core_compilation.py @@ -71,6 +71,26 @@ class SelectableTest(QueryTest, AssertsCompiledSQL): eq_(s1.subquery().c.keys(), ["id"]) eq_(s1.subquery().c.keys(), ["id"]) + def test_integration_w_8285_subc(self): + Address = self.classes.Address + + s1 = select( + Address.id, Address.__table__.c["user_id", "email_address"] + ) + self.assert_compile( + s1, + "SELECT addresses.id, addresses.user_id, " + "addresses.email_address FROM addresses", + ) + + subq = s1.subquery() + self.assert_compile( + select(subq.c.user_id, subq.c.id), + "SELECT anon_1.user_id, anon_1.id FROM (SELECT addresses.id AS " + "id, addresses.user_id AS user_id, addresses.email_address " + "AS email_address FROM addresses) AS anon_1", + ) + def test_scalar_subquery_from_subq_same_source(self): """test #6394, ensure all_selected_columns is generated each time""" User = self.classes.User diff --git a/test/sql/test_select.py b/test/sql/test_select.py index 
d91e50e637..ad4b4db959 100644 --- a/test/sql/test_select.py +++ b/test/sql/test_select.py @@ -16,8 +16,10 @@ from sqlalchemy.sql import literal from sqlalchemy.sql import table from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL +from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures +from sqlalchemy.testing import is_ table1 = table( "mytable", @@ -442,3 +444,72 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): " %(joiner)s SELECT :param_2 AS anon_2" " %(joiner)s SELECT :param_3 AS anon_3" % {"joiner": joiner}, ) + + +class ColumnCollectionAsSelectTest(fixtures.TestBase, AssertsCompiledSQL): + """tests related to #8285.""" + + __dialect__ = "default" + + def test_c_collection_as_from(self): + stmt = select(parent.c) + + # this works because _all_selected_columns expands out + # ClauseList. it does so in the same way that it works for + # Table already. so this is free + eq_(stmt._all_selected_columns, [parent.c.id, parent.c.data]) + + self.assert_compile(stmt, "SELECT parent.id, parent.data FROM parent") + + def test_c_sub_collection_str_stmt(self): + stmt = select(table1.c["myid", "description"]) + + self.assert_compile( + stmt, "SELECT mytable.myid, mytable.description FROM mytable" + ) + + subq = stmt.subquery() + self.assert_compile( + select(subq.c[0]).where(subq.c.description == "x"), + "SELECT anon_1.myid FROM (SELECT mytable.myid AS myid, " + "mytable.description AS description FROM mytable) AS anon_1 " + "WHERE anon_1.description = :description_1", + ) + + def test_c_sub_collection_int_stmt(self): + stmt = select(table1.c[2, 0]) + + self.assert_compile( + stmt, "SELECT mytable.description, mytable.myid FROM mytable" + ) + + subq = stmt.subquery() + self.assert_compile( + select(subq.c.myid).where(subq.c[1] == "x"), + "SELECT anon_1.myid FROM (SELECT mytable.description AS " + "description, mytable.myid AS myid FROM mytable) 
AS anon_1 " + "WHERE anon_1.myid = :myid_1", + ) + + def test_c_sub_collection_str(self): + coll = table1.c["myid", "description"] + is_(coll.myid, table1.c.myid) + + eq_(list(coll), [table1.c.myid, table1.c.description]) + + def test_c_sub_collection_int(self): + coll = table1.c[2, 0] + + is_(coll.myid, table1.c.myid) + + eq_(list(coll), [table1.c.description, table1.c.myid]) + + def test_missing_key(self): + + with expect_raises_message(KeyError, "unknown"): + table1.c["myid", "unknown"] + + def test_missing_index(self): + + with expect_raises_message(IndexError, "5"): + table1.c["myid", 5] -- 2.47.2