]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added columns attributes
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 30 Oct 2020 01:57:24 +0000 (02:57 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 30 Oct 2020 02:14:12 +0000 (03:14 +0100)
psycopg3/psycopg3/cursor.py
tests/test_cursor.py

index 68c5c87fc32e297252ba053d2ccfef31740ab4db..b14670a9ceedf6d7d9fb8cb330dfb54ea76e9c53 100644 (file)
@@ -39,19 +39,39 @@ class Column(Sequence[Any]):
         self._encoding = encoding
 
     _attrs = tuple(
-        map(
-            attrgetter,
-            """
+        attrgetter(attr)
+        for attr in """
             name type_code display_size internal_size precision scale null_ok
-            """.split(),
-        )
+            """.split()
     )
 
+    def __repr__(self) -> str:
+        return f"<Column {self.name}, type: {self._type_display()}>"
+
     def __len__(self) -> int:
         return 7
 
+    def _type_display(self) -> str:
+        parts = []
+        t = builtins.get(self.type_code)
+        parts.append(t.name if t else str(self.type_code))
+
+        mod1 = self.precision
+        if mod1 is None:
+            mod1 = self.display_size
+        if mod1:
+            parts.append(f"({mod1}")
+            if self.scale:
+                parts.append(f", {self.scale}")
+            parts.append(")")
+
+        return "".join(parts)
+
     def __getitem__(self, index: Any) -> Any:
-        return self._attrs[index](self)
+        if isinstance(index, slice):
+            return tuple(getter(self) for getter in self._attrs[index])
+        else:
+            return self._attrs[index](self)
 
     @property
     def name(self) -> str:
@@ -67,6 +87,56 @@ class Column(Sequence[Any]):
     def type_code(self) -> int:
         return self._pgresult.ftype(self._index)
 
+    @property
+    def display_size(self) -> Optional[int]:
+        t = builtins.get(self.type_code)
+        if not t:
+            return None
+
+        if t.name in ("varchar", "char"):
+            fmod = self._pgresult.fmod(self._index)
+            if fmod >= 0:
+                return fmod - 4
+
+        return None
+
+    @property
+    def internal_size(self) -> Optional[int]:
+        fsize = self._pgresult.fsize(self._index)
+        return fsize if fsize >= 0 else None
+
+    @property
+    def precision(self) -> Optional[int]:
+        t = builtins.get(self.type_code)
+        if not t:
+            return None
+
+        dttypes = ("time", "timetz", "timestamp", "timestamptz", "interval")
+        if t.name == "numeric":
+            fmod = self._pgresult.fmod(self._index)
+            if fmod >= 0:
+                return fmod >> 16
+
+        elif t.name in dttypes:
+            fmod = self._pgresult.fmod(self._index)
+            if fmod >= 0:
+                return fmod & 0xFFFF
+
+        return None
+
+    @property
+    def scale(self) -> Optional[int]:
+        if self.type_code == builtins["numeric"].oid:
+            fmod = self._pgresult.fmod(self._index) - 4
+            if fmod >= 0:
+                return fmod & 0xFFFF
+
+        return None
+
+    @property
+    def null_ok(self) -> Optional[bool]:
+        return None
+
 
 class BaseCursor:
     ExecStatus = pq.ExecStatus
index 8d819842e4fc973fc6747a48485c7087b8305f09..4a943c0f1d3e7b6f39dbccd478177204a4a8a4b3 100644 (file)
@@ -200,3 +200,81 @@ def test_rowcount(conn):
 
     cur.close()
     assert cur.rowcount == -1
+
+
+class TestColumn:
+    def test_description_attribs(self, conn):
+        curs = conn.cursor()
+        curs.execute(
+            """select
+            3.14::decimal(10,2) as pi,
+            'hello'::text as hi,
+            '2010-02-18'::date as now
+            """
+        )
+        assert len(curs.description) == 3
+        for c in curs.description:
+            len(c) == 7  # DBAPI happy
+            for i, a in enumerate(
+                """
+                name type_code display_size internal_size precision scale null_ok
+                """.split()
+            ):
+                assert c[i] == getattr(c, a)
+
+            # Won't fill them up
+            assert c.null_ok is None
+
+        c = curs.description[0]
+        assert c.name == "pi"
+        assert c.type_code == builtins["numeric"].oid
+        assert c.display_size is None
+        assert c.internal_size is None
+        assert c.precision == 10
+        assert c.scale == 2
+
+        c = curs.description[1]
+        assert c.name == "hi"
+        assert c.type_code == builtins["text"].oid
+        assert c.display_size is None
+        assert c.internal_size is None
+        assert c.precision is None
+        assert c.scale is None
+
+        c = curs.description[2]
+        assert c.name == "now"
+        assert c.type_code == builtins["date"].oid
+        assert c.display_size is None
+        assert c.internal_size == 4
+        assert c.precision is None
+        assert c.scale is None
+
+    def test_description_slice(self, conn):
+        curs = conn.cursor()
+        curs.execute("select 1::int as a")
+        curs.description[0][0:2] == ("a", 23)
+
+    @pytest.mark.parametrize(
+        "type, precision, scale, dsize, isize",
+        [
+            ("text", None, None, None, None),
+            ("varchar", None, None, None, None),
+            ("varchar(42)", None, None, 42, None),
+            ("int4", None, None, None, 4),
+            ("numeric", None, None, None, None),
+            ("numeric(10)", 10, 0, None, None),
+            ("numeric(10, 3)", 10, 3, None, None),
+            ("time", None, None, None, 8),
+            ("time(4)", 4, None, None, 8),
+            ("time(10)", 6, None, None, 8),
+        ],
+    )
+    def test_details(self, conn, type, precision, scale, dsize, isize):
+        cur = conn.cursor()
+        cur.execute(f"select null::{type}")
+        col = cur.description[0]
+        repr(col)
+        assert col.precision == precision
+        assert col.scale == scale
+        assert col.display_size == dsize
+        assert col.internal_size == isize