]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
query2pg returns formats too
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 27 Mar 2020 07:22:42 +0000 (20:22 +1300)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 27 Mar 2020 07:22:42 +0000 (20:22 +1300)
psycopg3/cursor.py
psycopg3/utils/queries.py
tests/test_query.py

index 667d9db7065d133d5f95771f0a29810ddf1f80b6..603ffdcf9b63c206a6ec19658694fc22e65feabc 100644 (file)
@@ -29,11 +29,11 @@ class BaseCursor:
 
         # process %% -> % only if there are paramters, even if empty list
         if vars is not None:
-            query, order = query2pg(query, vars, codec)
+            query, formats, order = query2pg(query, vars, codec)
         if vars:
             if order is not None:
                 vars = reorder_params(vars, order)
-            params, formats = self._adapt_sequence(vars)
+            params = self._adapt_sequence(vars, formats)
             self.conn.pgconn.send_query_params(
                 query, params, param_formats=formats
             )
@@ -83,14 +83,13 @@ class BaseCursor:
             self._result = self._results[self._iresult]
             return True
 
-    def _adapt_sequence(self, vars):
+    def _adapt_sequence(self, vars, formats):
         # TODO: stub. Need adaptation layer.
         codec = self.conn.codec
         out = [
             codec.encode(str(v))[0] if v is not None else None for v in vars
         ]
-        fmt = [0] * len(out)
-        return out, fmt
+        return out
 
 
 class Cursor(BaseCursor):
index f692294e02d86d4fc3511ccad3abb2edf1f4d926..823d37c889e4f7db169e2cdea4dd3e9f369e9244 100644 (file)
@@ -52,22 +52,33 @@ def query2pg(query, vars, codec):
         for part in parts[:-1]:
             name = codec.decode(part[1])[0]
             if name not in seen:
-                part[1] = seen[name] = len(seen)
+                n = len(seen)
+                part[1] = n
+                seen[name] = (n, part[2])
                 order.append(name)
             else:
-                part[1] = seen[name]
+                if seen[name][1] != part[2]:
+                    raise exc.ProgrammingError(
+                        f"placeholder '{name}' cannot have different formats"
+                    )
+                part[1] = seen[name][0]
 
     else:
-        raise TypeError("parameters should be a sequence or a mapping")
+        raise TypeError(
+            f"query parameters should be a sequence or a mapping,"
+            f" got {type(vars).__name__}"
+        )
 
     # Assemble query and parameters
     rv = []
+    formats = []
     for part in parts[:-1]:
         rv.append(part[0])
         rv.append(b"$%d" % (part[1] + 1))
+        formats.append(part[2])
     rv.append(parts[-1][0])
 
-    return b"".join(rv), order
+    return b"".join(rv), formats, order
 
 
 _re_placeholder = re.compile(
index 8ad6a342e4e1fb04cd999a7144ab4d2b2e863ead..8aa6ec9529ea0c91132c140a1f18362a8a6f431d 100644 (file)
@@ -62,41 +62,45 @@ def test_split_query_bad(input):
 
 
 @pytest.mark.parametrize(
-    "query, params, want",
+    "query, params, want, wformats",
     [
-        (b"", [], b""),
-        (b"%%", [], b"%"),
-        (b"select %s", (1,), b"select $1"),
-        (b"%s %% %s", (1, 2), b"$1 % $2"),
+        (b"", [], b"", []),
+        (b"%%", [], b"%", []),
+        (b"select %s", (1,), b"select $1", [False]),
+        (b"%s %% %s", (1, 2), b"$1 % $2", [False, False]),
+        (b"%b %% %s", (1, 2), b"$1 % $2", [True, False]),
     ],
 )
-def test_query2pg_seq(query, params, want):
-    out, order = query2pg(query, params, codecs.lookup("utf-8"))
+def test_query2pg_seq(query, params, want, wformats):
+    out, formats, order = query2pg(query, params, codecs.lookup("utf-8"))
     assert order is None
     assert out == want
+    assert formats == wformats
 
 
 @pytest.mark.parametrize(
-    "query, params, want, worder",
+    "query, params, want, wformats, worder",
     [
-        (b"", {}, b"", []),
-        (b"hello %%", {"a": 1}, b"hello %", []),
+        (b"", {}, b"", [], []),
+        (b"hello %%", {"a": 1}, b"hello %", [], []),
         (
             b"select %(hello)s",
             {"hello": 1, "world": 2},
             b"select $1",
+            [False],
             ["hello"],
         ),
         (
-            b"select %(hi)s %(there)s %(hi)s",
+            b"select %(hi)s %(there)b %(hi)s",
             {"hi": 1, "there": 2},
             b"select $1 $2 $1",
+            [False, True],
             ["hi", "there"],
         ),
     ],
 )
-def test_query2pg_map(query, params, want, worder):
-    out, order = query2pg(query, params, codecs.lookup("utf-8"))
+def test_query2pg_map(query, params, want, wformats, worder):
+    out, formats, order = query2pg(query, params, codecs.lookup("utf-8"))
     assert out == want
     assert order == worder
 
@@ -129,7 +133,8 @@ def test_query2pg_badtype(query, params):
         (b"select %(", {"a": 1}),
         (b"select %(a", {"a": 1}),
         (b"select %(a)", {"a": 1}),
-        (b"select %s %(hi)s", 1),
+        (b"select %s %(hi)s", [1]),
+        (b"select %(hi)s %(hi)b", {"hi": 1}),
     ],
 )
 def test_query2pg_badprog(query, params):