]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Don't create a new transformer if the query is repeated on a cursor
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 14 Jan 2021 16:39:15 +0000 (17:39 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 14 Jan 2021 16:39:15 +0000 (17:39 +0100)
psycopg3/psycopg3/cursor.py
tests/types/test_array.py
tests/types/test_numeric.py

index 890c92c6ee91739086ad3fe72d721d804b099cd5..d14a04b0e344b594e85324db6baaf13082e7a907 100644 (file)
@@ -50,7 +50,7 @@ class BaseCursor(Generic[ConnectionType]):
     if sys.version_info >= (3, 7):
         __slots__ = """
             _conn format _adapters arraysize _closed _results _pgresult _pos
-            _iresult _rowcount _pgq _transformer
+            _iresult _rowcount _pgq _transformer _last_query
             __weakref__
             """.split()
 
@@ -68,6 +68,7 @@ class BaseCursor(Generic[ConnectionType]):
         self._adapters = adapt.AdaptersMap(connection.adapters)
         self.arraysize = 1
         self._closed = False
+        self._last_query: Optional[Query] = None
         self._reset()
 
     def _reset(self) -> None:
@@ -187,7 +188,7 @@ class BaseCursor(Generic[ConnectionType]):
         prepare: Optional[bool] = None,
     ) -> PQGen[None]:
         """Generator implementing `Cursor.execute()`."""
-        yield from self._start_query()
+        yield from self._start_query(query)
         pgq = self._convert_query(query, params)
 
         # Check if the query is prepared or needs preparing
@@ -220,12 +221,13 @@ class BaseCursor(Generic[ConnectionType]):
                 yield from self._conn._exec_command(cmd)
 
         self._execute_results(results)
+        self._last_query = query
 
     def _executemany_gen(
         self, query: Query, params_seq: Sequence[Params]
     ) -> PQGen[None]:
         """Generator implementing `Cursor.executemany()`."""
-        yield from self._start_query()
+        yield from self._start_query(query)
         first = True
         for params in params_seq:
             if first:
@@ -246,7 +248,9 @@ class BaseCursor(Generic[ConnectionType]):
             (result,) = yield from execute(self._conn.pgconn)
             self._execute_results((result,))
 
-    def _start_query(self) -> PQGen[None]:
+        self._last_query = query
+
+    def _start_query(self, query: Optional[Query] = None) -> PQGen[None]:
         """Generator to start the processing of a query.
 
         It is implemented as generator because it may send additional queries,
@@ -256,7 +260,9 @@ class BaseCursor(Generic[ConnectionType]):
             raise e.InterfaceError("the cursor is closed")
 
         self._reset()
-        self._transformer = adapt.Transformer(self)
+        if not self._last_query or (self._last_query is not query):
+            self._last_query = None
+            self._transformer = adapt.Transformer(self)
         yield from self._conn._start_query()
 
     def _start_copy_gen(self, statement: Query) -> PQGen[None]:
index 3866c72204ff2978e5e0ab3817ba433596638e43..b9e3b17e698854f965def8f52a405f1da502765e 100644 (file)
@@ -103,7 +103,7 @@ def test_load_list_int(conn, obj, want, fmt_out):
 def test_array_register(conn):
     cur = conn.cursor()
     cur.execute("create table mytype (data text)")
-    cur.execute("""select '(foo)'::mytype, '{"(foo)"}'::mytype[]""")
+    cur.execute("""select '(foo)'::mytype, '{"(foo)"}'::mytype[] -- 1""")
     res = cur.fetchone()
     assert res[0] == "(foo)"
     assert res[1] == "{(foo)}"
@@ -111,7 +111,7 @@ def test_array_register(conn):
     array.register(
         cur.description[1].type_code, cur.description[0].type_code, context=cur
     )
-    cur.execute("""select '(foo)'::mytype, '{"(foo)"}'::mytype[]""")
+    cur.execute("""select '(foo)'::mytype, '{"(foo)"}'::mytype[] -- 2""")
     res = cur.fetchone()
     assert res[0] == "(foo)"
     assert res[1] == ["(foo)"]
index 4a143ba39ff128d9a34889710f90935651660a5b..3befe2a3ed982688fed34df964ab6c75d4355f20 100644 (file)
@@ -356,7 +356,7 @@ def test_numeric_as_float(conn, val):
     FloatLoader.register(builtins["numeric"].oid, cur)
 
     val = Decimal(val)
-    cur.execute("select %s", (val,))
+    cur.execute("select %s as val", (val,))
     result = cur.fetchone()[0]
     assert isinstance(result, float)
     if val.is_nan():
@@ -365,7 +365,7 @@ def test_numeric_as_float(conn, val):
         assert result == pytest.approx(float(val))
 
     # the customization works with arrays too
-    cur.execute("select %s", ([val],))
+    cur.execute("select %s as arr", ([val],))
     result = cur.fetchone()[0]
     assert isinstance(result, list)
     assert isinstance(result[0], float)