From aabc72bd33ba445c0a207432acf0aa1cf25263cb Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 17 Mar 2021 08:39:45 -0400 Subject: [PATCH] Provide special row proxies for count and index The Python ``namedtuple()`` has the behavior such that the names ``count`` and ``index`` will be served as tuple values if the named tuple includes those names; if they are absent, then their behavior as methods of ``collections.abc.Sequence`` is maintained. Therefore the :class:`_result.Row` and :class:`_result.LegacyRow` classes have been fixed so that they work in this same way, maintaining the expected behavior for database rows that have columns named "index" or "count". Fixes: #6074 Change-Id: I49a093da02f33f231d22ed5999c09fcaa3a68601 --- doc/build/changelog/unreleased_14/6074.rst | 11 ++++++ lib/sqlalchemy/engine/row.py | 21 +++++++++++ test/sql/test_resultset.py | 43 ++++++++++++++++++++++ 3 files changed, 75 insertions(+) create mode 100644 doc/build/changelog/unreleased_14/6074.rst diff --git a/doc/build/changelog/unreleased_14/6074.rst b/doc/build/changelog/unreleased_14/6074.rst new file mode 100644 index 0000000000..88cb71eb1c --- /dev/null +++ b/doc/build/changelog/unreleased_14/6074.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: bug, engine + :tickets: 6074 + + The Python ``namedtuple()`` has the behavior such that the names ``count`` + and ``index`` will be served as tuple values if the named tuple includes + those names; if they are absent, then their behavior as methods of + ``collections.abc.Sequence`` is maintained. Therefore the + :class:`_result.Row` and :class:`_result.LegacyRow` classes have been fixed + so that they work in this same way, maintaining the expected behavior for + database rows that have columns named "index" or "count". diff --git a/lib/sqlalchemy/engine/row.py b/lib/sqlalchemy/engine/row.py index ac65d1b18e..b870e6534a 100644 --- a/lib/sqlalchemy/engine/row.py +++ b/lib/sqlalchemy/engine/row.py @@ -220,6 +220,27 @@ class Row(BaseRow, collections_abc.Sequence): self._data, ) + def _special_name_accessor(name): + """Handle ambiguous names such as "count" and "index" """ + + @property + def go(self): + if self._parent._has_key(name): + return self.__getattr__(name) + else: + + def meth(*arg, **kw): + return getattr(collections_abc.Sequence, name)( + self, *arg, **kw + ) + + return meth + + return go + + count = _special_name_accessor("count") + index = _special_name_accessor("index") + def __contains__(self, key): return key in self._data diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py index e99ce881ca..5439d63b57 100644 --- a/test/sql/test_resultset.py +++ b/test/sql/test_resultset.py @@ -28,6 +28,8 @@ from sqlalchemy import VARCHAR from sqlalchemy.engine import cursor as _cursor from sqlalchemy.engine import default from sqlalchemy.engine import Row +from sqlalchemy.engine.result import SimpleResultMetaData +from sqlalchemy.engine.row import LegacyRow from sqlalchemy.ext.compiler import compiles from sqlalchemy.sql import ColumnElement from sqlalchemy.sql import expression @@ -1324,6 +1326,47 @@ class CursorResultTest(fixtures.TablesTest): ) is_true(isinstance(row, collections_abc.Sequence)) + @testing.combinations((Row,), (LegacyRow,)) + def test_row_special_names(self, row_cls): + metadata = SimpleResultMetaData(["key", "count", "index"]) + row = row_cls( + metadata, + [None, None, None], + metadata._keymap, + Row._default_key_style, + ["kv", "cv", "iv"], + ) + is_true(isinstance(row, collections_abc.Sequence)) + + eq_(row.key, "kv") + eq_(row.count, "cv") + eq_(row.index, "iv") + + if isinstance(row, LegacyRow): + eq_(row["count"], "cv") + eq_(row["index"], "iv") + + eq_(row._mapping["count"], "cv") + eq_(row._mapping["index"], "iv") + + metadata = SimpleResultMetaData(["key", "q", "p"]) + + row = row_cls( + metadata, + [None, None, None], + metadata._keymap, + Row._default_key_style, + ["kv", "cv", "iv"], + ) + is_true(isinstance(row, collections_abc.Sequence)) + + eq_(row.key, "kv") + eq_(row.q, "cv") + eq_(row.p, "iv") + eq_(row.index("cv"), 1) + eq_(row.count("cv"), 1) + eq_(row.count("x"), 0) + def test_row_is_hashable(self): row = Row( -- 2.47.2