git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added C string dumpers
author Daniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 21 Nov 2020 23:09:16 +0000 (23:09 +0000)
committer Daniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 21 Nov 2020 23:15:47 +0000 (23:15 +0000)
psycopg3/psycopg3/types/text.py
psycopg3_c/psycopg3_c/types/numeric.pyx
psycopg3_c/psycopg3_c/types/singletons.pyx
psycopg3_c/psycopg3_c/types/text.pyx

index 0be59bef03554122239e22c389b91863573db084..df0993b61e5781f1a5b70706d8ca39d66520f82d 100644 (file)
@@ -30,6 +30,7 @@ class _StringDumper(Dumper):
 @Dumper.binary(str)
 class StringBinaryDumper(_StringDumper):
     def dump(self, obj: str) -> bytes:
+        # the server will raise DataError subclass if the string contains 0x00
         return obj.encode(self.encoding)
 
 
index a13b4c272cb9dcd0330c92af69018afc1050e6e3..fe97209e4b19d75ddbc4bf1012b4bdbf00400ca2 100644 (file)
@@ -23,13 +23,13 @@ cdef extern from "Python.h":
 cdef class IntDumper(CDumper):
     oid = oids.INT8_OID
 
-    def dump(self, obj: Any) -> bytes:
+    def dump(self, obj) -> bytes:
         cdef char buf[22]
         cdef long long val = PyLong_AsLongLong(obj)
         cdef int written = PyOS_snprintf(buf, sizeof(buf), "%lld", val)
         return buf[:written]
 
-    def quote(self, obj: Any) -> bytes:
+    def quote(self, obj) -> bytes:
         cdef char buf[23]
         cdef long long val = PyLong_AsLongLong(obj)
         cdef int written
@@ -42,7 +42,7 @@ cdef class IntDumper(CDumper):
 
 
 cdef class IntBinaryDumper(IntDumper):
-    def dump(self, obj: Any) -> bytes:
+    def dump(self, obj) -> bytes:
         cdef long long val = PyLong_AsLongLong(obj)
         cdef uint64_t *ptvar = <uint64_t *>(&val)
         cdef int64_t beval = htobe64(ptvar[0])
index 2377d4452f90afaf6f00dad75d77c8556594a46b..32aa1206b49564c66c08a6b40b940486cd20fa5d 100644 (file)
@@ -10,7 +10,7 @@ from psycopg3_c cimport oids
 cdef class BoolDumper(CDumper):
     oid = oids.BOOL_OID
 
-    def dump(self, obj: bool) -> bytes:
+    def dump(self, obj) -> bytes:
         # Fast paths, just a pointer comparison
         if obj is True:
             return b"t"
@@ -29,7 +29,7 @@ cdef class BoolDumper(CDumper):
 
 
 cdef class BoolBinaryDumper(BoolDumper):
-    def dump(self, obj: bool) -> bytes:
+    def dump(self, obj) -> bytes:
         if obj is True:
             return b"\x01"
         elif obj is False:
index ebc6bd13daafb249756d63c62e36f775898d96b1..b633ec08d1b5aef93197b7ddde0e9c579fdaac60 100644 (file)
@@ -4,10 +4,67 @@ Cython adapters for textual types.
 
 # Copyright (C) 2020 The Psycopg Team
 
+from cpython.bytes cimport PyBytes_AsString, PyBytes_AsStringAndSize
 from cpython.unicode cimport PyUnicode_Decode, PyUnicode_DecodeUTF8
+from cpython.unicode cimport PyUnicode_AsUTF8String, PyUnicode_AsEncodedString
+
 from psycopg3_c cimport libpq, oids
 
 
+cdef class _StringDumper(CDumper):
+    cdef int is_utf8
+    cdef char *encoding
+    cdef bytes _bytes_encoding  # needed to keep `encoding` alive
+
+    def __init__(self, src: type, context: AdaptContext):
+        super().__init__(src, context)
+
+        self.is_utf8 = 0
+        self.encoding = "utf-8"
+
+        conn = self._connection
+        if conn:
+            self._bytes_encoding = conn.client_encoding.encode("utf-8")
+            self.encoding = PyBytes_AsString(self._bytes_encoding)
+            if (
+                self._bytes_encoding == b"utf-8"
+                or self._bytes_encoding == b"ascii"
+            ):
+                self.is_utf8 = 1
+
+
+cdef class StringBinaryDumper(_StringDumper):
+    def dump(self, obj) -> bytes:
+        # the server will raise DataError subclass if the string contains 0x00
+        if self.is_utf8:
+            return PyUnicode_AsUTF8String(obj)
+        else:
+            return PyUnicode_AsEncodedString(obj, self.encoding, NULL)
+
+
+cdef class StringDumper(_StringDumper):
+    def dump(self, obj) -> bytes:
+        cdef bytes rv
+        cdef char *buf
+
+        if self.is_utf8:
+            rv = PyUnicode_AsUTF8String(obj)
+        else:
+            rv = PyUnicode_AsEncodedString(obj, self.encoding, NULL)
+
+        try:
+            # the function raises ValueError if the bytes contains 0x00
+            PyBytes_AsStringAndSize(rv, &buf, NULL)
+        except ValueError:
+            from psycopg3 import DataError
+
+            raise DataError(
+                "PostgreSQL text fields cannot contain NUL (0x00) bytes"
+            )
+
+        return rv
+
+
 cdef class TextLoader(CLoader):
     cdef int is_utf8
     cdef char *encoding
@@ -19,10 +76,10 @@ cdef class TextLoader(CLoader):
         self.is_utf8 = 0
         self.encoding = "utf-8"
 
-        conn = self.connection
+        conn = self._connection
         if conn:
             self._bytes_encoding = conn.client_encoding.encode("utf-8")
-            self.encoding = self._bytes_encoding
+            self.encoding = PyBytes_AsString(self._bytes_encoding)
             if self._bytes_encoding == b"utf-8":
                 self.is_utf8 = 1
             elif self._bytes_encoding == b"ascii":
@@ -60,6 +117,9 @@ cdef class ByteaBinaryLoader(CLoader):
 cdef void register_text_c_adapters():
     logger.debug("registering optimised text c adapters")
 
+    StringDumper.register(str)
+    StringBinaryDumper.register_binary(str)
+
     TextLoader.register(oids.INVALID_OID)
     TextLoader.register(oids.TEXT_OID)
     TextLoader.register_binary(oids.TEXT_OID)