From: Daniele Varrazzo Date: Sun, 20 Mar 2022 18:22:47 +0000 (+0100) Subject: fix(sql): represent array literals correctly X-Git-Tag: 3.1~109^2~2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=57b638957611dea7fbc56bc1490b5f0c717c5b5c;p=thirdparty%2Fpsycopg.git fix(sql): represent array literals correctly --- diff --git a/psycopg/psycopg/sql.py b/psycopg/psycopg/sql.py index f2b0ac1cc..9726af8b5 100644 --- a/psycopg/psycopg/sql.py +++ b/psycopg/psycopg/sql.py @@ -406,6 +406,8 @@ class Literal(Composable): except KeyError: type_name = ti.regtype.encode(tx.encoding) self._names_cache[ti.regtype, tx.encoding] = type_name + if dumper.oid == ti.array_oid: + type_name += b"[]" rv = b"%s::%s" % (rv, type_name) return rv diff --git a/tests/test_sql.py b/tests/test_sql.py index d8ae83703..21d0018c5 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -339,6 +339,12 @@ class TestLiteral: with pytest.raises(ProgrammingError): sql.Literal(Foo()).as_string(conn) + def test_array(self, conn): + assert ( + sql.Literal([dt.date(2000, 1, 1)]).as_string(conn) + == "'{2000-01-01}'::date[]" + ) + @pytest.mark.parametrize("name", ["a-b", f"{eur}", "order"]) def test_invalid_name(self, conn, name): conn.execute( @@ -368,6 +374,12 @@ class TestLiteral: cur = conn.execute(sql.SQL("select {}").format("hello")) assert cur.fetchone()[0] == "hello-inv" + assert ( + sql.Literal(["hello"]).as_string(conn) == f"'{{hello-inv}}'::\"{name}\"[]" + ) + cur = conn.execute(sql.SQL("select {}").format(["hello"])) + assert cur.fetchone()[0] == ["hello-inv"] + class TestSQL: def test_class(self):