# Copyright (C) 2021 The Psycopg Team
-from typing import Any, Callable, Dict, Sequence
+import functools
+import re
+from collections import namedtuple
+from typing import Any, Callable, Dict, Sequence, Tuple, Type
from .cursor import BaseCursor
from .proto import ConnectionType
return dict(zip(titles, values))
return make_row
+
+
+def namedtuple_row(
+ cursor: BaseCursor[ConnectionType],
+) -> Callable[[Sequence[Any]], Tuple[Any, ...]]:
+ """Row factory to represent rows as `~collections.namedtuple`."""
+
+ def make_row(values: Sequence[Any]) -> Tuple[Any, ...]:
+ assert cursor.description
+ key = tuple(c.name for c in cursor.description)
+ nt = _make_nt(key)
+ rv = nt._make(values) # type: ignore[attr-defined]
+ return rv # type: ignore[no-any-return]
+
+ return make_row
+
+
+# ascii except alnum and underscore
+_re_clean = re.compile(
+ "[" + re.escape(" !\"#$%&'()*+,-./:;<=>?@[\\]^`{|}~") + "]"
+)
+
+
+@functools.lru_cache(512)
+def _make_nt(key: Sequence[str]) -> Type[Tuple[Any, ...]]:
+ fields = []
+ for s in key:
+ s = _re_clean.sub("_", s)
+ # Python identifier cannot start with numbers, namedtuple fields
+ # cannot start with underscore. So...
+ if s[0] == "_" or "0" <= s[0] <= "9":
+ s = "f" + s
+ fields.append(s)
+ return namedtuple("Row", fields)
assert cur.nextset()
assert cur.fetchall() == [{"number": 1}]
assert not cur.nextset()
+
+
+def test_namedtuple_row(conn):
+ cur = conn.cursor(row_factory=rows.namedtuple_row)
+ cur.execute("select 'bob' as name, 3 as id")
+ (person1,) = cur.fetchall()
+ assert f"{person1.name} {person1.id}" == "bob 3"
+
+ ci1 = rows._make_nt.cache_info()
+ assert ci1.hits == 0 and ci1.misses == 1
+
+ cur.execute("select 'alice' as name, 1 as id")
+ (person2,) = cur.fetchall()
+ assert type(person2) is type(person1)
+
+ ci2 = rows._make_nt.cache_info()
+ assert ci2.hits == 1 and ci2.misses == 1
+
+ cur.execute("select 'foo', 1 as id")
+ (r0,) = cur.fetchall()
+ assert r0.f_column_ == "foo"
+ assert r0.id == 1
+
+ cur.execute("select 'a' as letter; select 1 as number")
+ (r1,) = cur.fetchall()
+ assert r1.letter == "a"
+ assert cur.nextset()
+ (r2,) = cur.fetchall()
+ assert r2.number == 1
+ assert not cur.nextset()
+ assert type(r1) is not type(r2)