git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Dropped local variable micro-optimization
author Daniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 7 Nov 2020 01:58:06 +0000 (01:58 +0000)
committer Daniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 7 Nov 2020 01:58:06 +0000 (01:58 +0000)
A long time ago it was a thing. timeit shows it isn't anymore.

19 files changed:
psycopg3/psycopg3/_transform.py
psycopg3/psycopg3/connection.py
psycopg3/psycopg3/copy.py
psycopg3/psycopg3/cursor.py
psycopg3/psycopg3/proto.py
psycopg3/psycopg3/sql.py
psycopg3/psycopg3/types/date.py
psycopg3/psycopg3/types/json.py
psycopg3/psycopg3/types/network.py
psycopg3/psycopg3/types/numeric.py
psycopg3/psycopg3/types/text.py
psycopg3/psycopg3/types/uuid.py
psycopg3/psycopg3/utils/queries.py
psycopg3_c/psycopg3_c/_psycopg3.pyi
psycopg3_c/psycopg3_c/transform.pyx
psycopg3_c/psycopg3_c/types/text.pyx
tests/test_connection.py
tests/test_connection_async.py
tests/test_cursor.py

index 70d2f99b768f16e00a92cb0aa4293b53aeabed9b..b10aa7cf4feb209083f7f3dc6f68be4245cdb3c2 100644 (file)
@@ -4,7 +4,6 @@ Helper object to transform values between Python and PostgreSQL
 
 # Copyright (C) 2020 The Psycopg Team
 
-import codecs
 from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
 from typing import TYPE_CHECKING
 
@@ -56,7 +55,7 @@ class Transformer:
     def _setup_context(self, context: AdaptContext) -> None:
         if context is None:
             self._connection = None
-            self._codec = codecs.lookup("utf8")
+            self._encoding = "utf-8"
             self._dumpers = {}
             self._loaders = {}
             self._dumpers_maps = [self._dumpers]
@@ -66,7 +65,7 @@ class Transformer:
             # A transformer created from a transformer: usually it happens
             # for nested types: share the entire state of the parent
             self._connection = context.connection
-            self._codec = context.codec
+            self._encoding = context.encoding
             self._dumpers = context.dumpers
             self._loaders = context.loaders
             self._dumpers_maps.extend(context._dumpers_maps)
@@ -76,7 +75,7 @@ class Transformer:
 
         elif isinstance(context, BaseCursor):
             self._connection = context.connection
-            self._codec = context.connection.codec
+            self._encoding = context.connection.pyenc
             self._dumpers = {}
             self._dumpers_maps.extend(
                 (self._dumpers, context.dumpers, context.connection.dumpers)
@@ -88,7 +87,7 @@ class Transformer:
 
         elif isinstance(context, BaseConnection):
             self._connection = context
-            self._codec = context.codec
+            self._encoding = context.pyenc
             self._dumpers = {}
             self._dumpers_maps.extend((self._dumpers, context.dumpers))
             self._loaders = {}
@@ -104,8 +103,8 @@ class Transformer:
         return self._connection
 
     @property
-    def codec(self) -> codecs.CodecInfo:
-        return self._codec
+    def encoding(self) -> str:
+        return self._encoding
 
     @property
     def pgresult(self) -> Optional[pq.proto.PGresult]:
index 19231c0564f0f061373b83648e10958c44a3401a..a720528ababedb31914527fa893a74a8690e7300 100644 (file)
@@ -4,7 +4,6 @@ psycopg3 connection objects
 
 # Copyright (C) 2020 The Psycopg Team
 
-import codecs
 import logging
 import asyncio
 import threading
@@ -83,8 +82,10 @@ class BaseConnection:
         self.loaders: proto.LoadersMap = {}
         self._notice_handlers: List[NoticeHandler] = []
         self._notify_handlers: List[NotifyHandler] = []
-        # name of the postgres encoding (in bytes)
+        # postgres name of the client encoding (in bytes)
         self._pgenc = b""
+        # python name of the client encoding
+        self._pyenc = "utf-8"
 
         wself = ref(self)
 
@@ -126,25 +127,19 @@ class BaseConnection:
         return self.cursor_factory(self, format=format)
 
     @property
-    def codec(self) -> codecs.CodecInfo:
-        # TODO: utf8 fastpath?
+    def pyenc(self) -> str:
         pgenc = self.pgconn.parameter_status(b"client_encoding") or b""
         if self._pgenc != pgenc:
             if pgenc:
                 try:
-                    pyenc = pq.py_codecs[pgenc]
+                    self._pyenc = pq.py_codecs[pgenc]
                 except KeyError:
                     raise e.NotSupportedError(
                         f"encoding {pgenc.decode('ascii')} not available in Python"
                     )
-                self._codec = codecs.lookup(pyenc)
-            else:
-                # fallback for a connection closed whose codec was never asked
-                if not hasattr(self, "_codec"):
-                    self._codec = codecs.lookup("utf8")
 
             self._pgenc = pgenc
-        return self._codec
+        return self._pyenc
 
     @property
     def client_encoding(self) -> str:
@@ -179,7 +174,7 @@ class BaseConnection:
         if self is None or not self._notice_handler:
             return
 
-        diag = e.Diagnostic(res, self.codec.name)
+        diag = e.Diagnostic(res, self._pyenc)
         for cb in self._notice_handlers:
             try:
                 cb(diag)
@@ -202,8 +197,11 @@ class BaseConnection:
         if self is None or not self._notify_handlers:
             return
 
-        decode = self.codec.decode
-        n = Notify(decode(pgn.relname)[0], decode(pgn.extra)[0], pgn.be_pid)
+        n = Notify(
+            pgn.relname.decode(self._pyenc),
+            pgn.extra.decode(self._pyenc),
+            pgn.be_pid,
+        )
         for cb in self._notify_handlers:
             cb(n)
 
@@ -261,7 +259,7 @@ class Connection(BaseConnection):
         if pgres.status != ExecStatus.COMMAND_OK:
             raise e.OperationalError(
                 "error on begin:"
-                f" {pq.error_message(pgres, encoding=self.codec.name)}"
+                f" {pq.error_message(pgres, encoding=self._pyenc)}"
             )
 
     def commit(self) -> None:
@@ -283,7 +281,7 @@ class Connection(BaseConnection):
         if results[-1].status != ExecStatus.COMMAND_OK:
             raise e.OperationalError(
                 f"error on {command.decode('utf8')}:"
-                f" {pq.error_message(results[-1], encoding=self.codec.name)}"
+                f" {pq.error_message(results[-1], encoding=self._pyenc)}"
             )
 
     @classmethod
@@ -301,16 +299,17 @@ class Connection(BaseConnection):
             gen = execute(self.pgconn)
             (result,) = self.wait(gen)
             if result.status != ExecStatus.TUPLES_OK:
-                raise e.error_from_result(result, encoding=self.codec.name)
+                raise e.error_from_result(result, encoding=self._pyenc)
 
     def notifies(self) -> Generator[Optional[Notify], bool, None]:
-        decode = self.codec.decode
         while 1:
             with self.lock:
                 ns = self.wait(notifies(self.pgconn))
             for pgn in ns:
                 n = Notify(
-                    decode(pgn.relname)[0], decode(pgn.extra)[0], pgn.be_pid
+                    pgn.relname.decode(self._pyenc),
+                    pgn.extra.decode(self._pyenc),
+                    pgn.be_pid,
                 )
                 if (yield n):
                     yield None  # for the send who stopped us
@@ -375,7 +374,7 @@ class AsyncConnection(BaseConnection):
         if pgres.status != ExecStatus.COMMAND_OK:
             raise e.OperationalError(
                 "error on begin:"
-                f" {pq.error_message(pgres, encoding=self.codec.name)}"
+                f" {pq.error_message(pgres, encoding=self._pyenc)}"
             )
 
     async def commit(self) -> None:
@@ -397,7 +396,7 @@ class AsyncConnection(BaseConnection):
         if pgres.status != ExecStatus.COMMAND_OK:
             raise e.OperationalError(
                 f"error on {command.decode('utf8')}:"
-                f" {pq.error_message(pgres, encoding=self.codec.name)}"
+                f" {pq.error_message(pgres, encoding=self._pyenc)}"
             )
 
     @classmethod
@@ -419,16 +418,17 @@ class AsyncConnection(BaseConnection):
             gen = execute(self.pgconn)
             (result,) = await self.wait(gen)
             if result.status != ExecStatus.TUPLES_OK:
-                raise e.error_from_result(result, encoding=self.codec.name)
+                raise e.error_from_result(result, encoding=self._pyenc)
 
     async def notifies(self) -> AsyncGenerator[Optional[Notify], bool]:
-        decode = self.codec.decode
         while 1:
             async with self.lock:
                 ns = await self.wait(notifies(self.pgconn))
             for pgn in ns:
                 n = Notify(
-                    decode(pgn.relname)[0], decode(pgn.extra)[0], pgn.be_pid
+                    pgn.relname.decode(self._pyenc),
+                    pgn.extra.decode(self._pyenc),
+                    pgn.be_pid,
                 )
                 if (yield n):
                     yield None
index ff61cf988519f3c71793a843908da1a3cdfe1324..6cd89e533d250d31d50e81918483103decd1e948 100644 (file)
@@ -5,7 +5,6 @@ psycopg3 copy support
 # Copyright (C) 2020 The Psycopg Team
 
 import re
-import codecs
 import struct
 from typing import TYPE_CHECKING, AsyncIterator, Iterator
 from typing import Any, Dict, List, Match, Optional, Sequence, Type, Union
@@ -34,7 +33,7 @@ class BaseCopy:
         self.pgresult = result
         self._first_row = True
         self._finished = False
-        self._codec: Optional[codecs.CodecInfo] = None
+        self._encoding: str = ""
 
         if format == pq.Format.TEXT:
             self._format_row = self._format_row_text
@@ -70,8 +69,8 @@ class BaseCopy:
             return data
 
         elif isinstance(data, str):
-            if self._codec is not None:
-                return self._codec.encode(data)[0]
+            if self._encoding:
+                return data.encode(self._encoding)
 
             if (
                 self.pgresult is None
@@ -80,8 +79,8 @@ class BaseCopy:
                 raise TypeError(
                     "cannot copy str data in binary mode: use bytes instead"
                 )
-            self._codec = self.connection.codec
-            return self._codec.encode(data)[0]
+            self._encoding = self.connection.pyenc
+            return data.encode(self._encoding)
 
         else:
             raise TypeError(f"can't write {type(data).__name__}")
@@ -180,13 +179,9 @@ class Copy(BaseCopy):
         data = self.format_row(row)
         self.write(data)
 
-    def finish(self, error: Optional[str] = None) -> None:
+    def finish(self, error: str = "") -> None:
         conn = self.connection
-        berr = (
-            conn.codec.encode(error, "replace")[0]
-            if error is not None
-            else None
-        )
+        berr = error.encode(conn.pyenc, "replace") if error else None
         conn.wait(copy_end(conn.pgconn, berr))
         self._finished = True
 
@@ -205,7 +200,7 @@ class Copy(BaseCopy):
                 self.write(b"\xff\xff")
             self.finish()
         else:
-            self.finish(str(exc_val))
+            self.finish(str(exc_val) or type(exc_val).__qualname__)
 
     def __iter__(self) -> Iterator[bytes]:
         while 1:
@@ -241,13 +236,9 @@ class AsyncCopy(BaseCopy):
         data = self.format_row(row)
         await self.write(data)
 
-    async def finish(self, error: Optional[str] = None) -> None:
+    async def finish(self, error: str = "") -> None:
         conn = self.connection
-        berr = (
-            conn.codec.encode(error, "replace")[0]
-            if error is not None
-            else None
-        )
+        berr = error.encode(conn.pyenc, "replace") if error else None
         await conn.wait(copy_end(conn.pgconn, berr))
         self._finished = True
 
index 77b1ecd867f0b19d2295747f8af298d1b27fb1de..c062d2ae77aa01e97e07bc1bb2b1a4c7781462cd 100644 (file)
@@ -191,10 +191,8 @@ class BaseCursor:
         res = self.pgresult
         if res is None or res.status != self.ExecStatus.TUPLES_OK:
             return None
-        return [
-            Column(res, i, self.connection.codec.name)
-            for i in range(res.nfields)
-        ]
+        encoding = self.connection.pyenc
+        return [Column(res, i, encoding) for i in range(res.nfields)]
 
     @property
     def rowcount(self) -> int:
@@ -283,7 +281,7 @@ class BaseCursor:
 
         if results[-1].status == S.FATAL_ERROR:
             raise e.error_from_result(
-                results[-1], encoding=self.connection.codec.name
+                results[-1], encoding=self.connection.pyenc
             )
 
         elif badstats & {S.COPY_IN, S.COPY_OUT, S.COPY_BOTH}:
@@ -393,9 +391,7 @@ class BaseCursor:
         if status in (pq.ExecStatus.COPY_IN, pq.ExecStatus.COPY_OUT):
             return
         elif status == pq.ExecStatus.FATAL_ERROR:
-            raise e.error_from_result(
-                result, encoding=self.connection.codec.name
-            )
+            raise e.error_from_result(result, encoding=self.connection.pyenc)
         else:
             raise e.ProgrammingError(
                 "copy() should be used only with COPY ... TO STDOUT or COPY ..."
@@ -450,7 +446,7 @@ class Cursor(BaseCursor):
                     (result,) = self.connection.wait(gen)
                     if result.status == self.ExecStatus.FATAL_ERROR:
                         raise e.error_from_result(
-                            result, encoding=self.connection.codec.name
+                            result, encoding=self.connection.pyenc
                         )
                 else:
                     pgq.dump(vars)
@@ -578,7 +574,7 @@ class AsyncCursor(BaseCursor):
                     (result,) = await self.connection.wait(gen)
                     if result.status == self.ExecStatus.FATAL_ERROR:
                         raise e.error_from_result(
-                            result, encoding=self.connection.codec.name
+                            result, encoding=self.connection.pyenc
                         )
                 else:
                     pgq.dump(vars)
index c8f04b24cdf9f9cdf777f62603d1748bef1cfc78..3a0130e4b266084389c73f22bf73941f90c64624 100644 (file)
@@ -4,7 +4,6 @@ Protocol objects representing different implementations of the same classes.
 
 # Copyright (C) 2020 The Psycopg Team
 
-import codecs
 from typing import Any, Callable, Dict, Generator, Mapping
 from typing import Optional, Sequence, Tuple, Type, TypeVar, Union
 from typing import TYPE_CHECKING
@@ -52,7 +51,7 @@ class Transformer(Protocol):
         ...
 
     @property
-    def codec(self) -> codecs.CodecInfo:
+    def encoding(self) -> str:
         ...
 
     @property
index d0f53067d0bef86fa4604e0ac3471d603594a6f7..60681877c90030d7a8475147724a3312acb77f78 100644 (file)
@@ -340,9 +340,9 @@ class Identifier(Composable):
             raise ValueError(f"no connection in the context: {context}")
 
         esc = Escaping(conn.pgconn)
-        codec = conn.codec
-        escs = [esc.escape_identifier(codec.encode(s)[0]) for s in self._obj]
-        return codec.decode(b".".join(escs))[0]
+        enc = conn.pyenc
+        escs = [esc.escape_identifier(s.encode(enc)) for s in self._obj]
+        return b".".join(escs).decode(enc)
 
 
 class Literal(Composable):
@@ -373,7 +373,7 @@ class Literal(Composable):
         tx = context if isinstance(context, Transformer) else Transformer(conn)
         dumper = tx.get_dumper(self._obj, Format.TEXT)
         quoted = dumper.quote(self._obj)
-        return conn.codec.decode(quoted)[0] if conn else quoted.decode("utf8")
+        return quoted.decode(conn.pyenc if conn else "utf-8")
 
 
 class Placeholder(Composable):
index 9cf039353417dcef7cf1b72ec112305079ba4fb8..d7aa36a7820b8fecd6d038f1d5f75ac79800a7af 100644 (file)
@@ -13,7 +13,6 @@ from ..oids import builtins
 from ..adapt import Dumper, Loader
 from ..proto import AdaptContext
 from ..errors import InterfaceError, DataError
-from ..utils.codecs import EncodeFunc, DecodeFunc, encode_ascii, decode_ascii
 
 
 @Dumper.text(date)
@@ -21,10 +20,10 @@ class DateDumper(Dumper):
 
     oid = builtins["date"].oid
 
-    def dump(self, obj: date, __encode: EncodeFunc = encode_ascii) -> bytes:
+    def dump(self, obj: date) -> bytes:
         # NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD)
         # the YYYY-MM-DD is always understood correctly.
-        return __encode(str(obj))[0]
+        return str(obj).encode("utf8")
 
 
 @Dumper.text(time)
@@ -32,8 +31,8 @@ class TimeDumper(Dumper):
 
     oid = builtins["timetz"].oid
 
-    def dump(self, obj: time, __encode: EncodeFunc = encode_ascii) -> bytes:
-        return __encode(str(obj))[0]
+    def dump(self, obj: time) -> bytes:
+        return str(obj).encode("utf8")
 
 
 @Dumper.text(datetime)
@@ -41,10 +40,10 @@ class DateTimeDumper(Dumper):
 
     oid = builtins["timestamptz"].oid
 
-    def dump(self, obj: date, __encode: EncodeFunc = encode_ascii) -> bytes:
+    def dump(self, obj: date) -> bytes:
         # NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD)
         # the YYYY-MM-DD is always understood correctly.
-        return __encode(str(obj))[0]
+        return str(obj).encode("utf8")
 
 
 @Dumper.text(timedelta)
@@ -61,10 +60,8 @@ class TimeDeltaDumper(Dumper):
             ):
                 setattr(self, "dump", self._dump_sql)
 
-    def dump(
-        self, obj: timedelta, __encode: EncodeFunc = encode_ascii
-    ) -> bytes:
-        return __encode(str(obj))[0]
+    def dump(self, obj: timedelta) -> bytes:
+        return str(obj).encode("utf8")
 
     def _dump_sql(self, obj: timedelta) -> bytes:
         # sql_standard format needs explicit signs
@@ -82,9 +79,9 @@ class DateLoader(Loader):
         super().__init__(oid, context)
         self._format = self._format_from_context()
 
-    def load(self, data: bytes, __decode: DecodeFunc = decode_ascii) -> date:
+    def load(self, data: bytes) -> date:
         try:
-            return datetime.strptime(__decode(data)[0], self._format).date()
+            return datetime.strptime(data.decode("utf8"), self._format).date()
         except ValueError as e:
             return self._raise_error(data, e)
 
@@ -140,11 +137,11 @@ class TimeLoader(Loader):
     _format = "%H:%M:%S.%f"
     _format_no_micro = _format.replace(".%f", "")
 
-    def load(self, data: bytes, __decode: DecodeFunc = decode_ascii) -> time:
+    def load(self, data: bytes) -> time:
         # check if the data contains microseconds
         fmt = self._format if b"." in data else self._format_no_micro
         try:
-            return datetime.strptime(__decode(data)[0], fmt).time()
+            return datetime.strptime(data.decode("utf8"), fmt).time()
         except ValueError as e:
             return self._raise_error(data, e)
 
@@ -170,22 +167,20 @@ class TimeTzLoader(TimeLoader):
 
         super().__init__(oid, context)
 
-    def load(self, data: bytes, __decode: DecodeFunc = decode_ascii) -> time:
+    def load(self, data: bytes) -> time:
         # Hack to convert +HH in +HHMM
         if data[-3] in (43, 45):
             data += b"00"
 
         fmt = self._format if b"." in data else self._format_no_micro
         try:
-            dt = datetime.strptime(__decode(data)[0], fmt)
+            dt = datetime.strptime(data.decode("utf8"), fmt)
         except ValueError as e:
             return self._raise_error(data, e)
 
         return dt.time().replace(tzinfo=dt.tzinfo)
 
-    def _load_py36(
-        self, data: bytes, __decode: DecodeFunc = decode_ascii
-    ) -> time:
+    def _load_py36(self, data: bytes) -> time:
         # Drop seconds from timezone for Python 3.6
         # Also, Python 3.6 doesn't support HHMM, only HH:MM
         if data[-6] in (43, 45):  # +-HH:MM -> +-HHMM
@@ -202,15 +197,13 @@ class TimestampLoader(DateLoader):
         super().__init__(oid, context)
         self._format_no_micro = self._format.replace(".%f", "")
 
-    def load(
-        self, data: bytes, __decode: DecodeFunc = decode_ascii
-    ) -> datetime:
+    def load(self, data: bytes) -> datetime:
         # check if the data contains microseconds
         fmt = (
             self._format if data.find(b".", 19) >= 0 else self._format_no_micro
         )
         try:
-            return datetime.strptime(__decode(data)[0], fmt)
+            return datetime.strptime(data.decode("utf8"), fmt)
         except ValueError as e:
             return self._raise_error(data, e)
 
@@ -284,18 +277,14 @@ class TimestamptzLoader(TimestampLoader):
             setattr(self, "load", self._load_notimpl)
             return ""
 
-    def load(
-        self, data: bytes, __decode: DecodeFunc = decode_ascii
-    ) -> datetime:
+    def load(self, data: bytes) -> datetime:
         # Hack to convert +HH in +HHMM
         if data[-3] in (43, 45):
             data += b"00"
 
         return super().load(data)
 
-    def _load_py36(
-        self, data: bytes, __decode: DecodeFunc = decode_ascii
-    ) -> datetime:
+    def _load_py36(self, data: bytes) -> datetime:
         # Drop seconds from timezone for Python 3.6
         # Also, Python 3.6 doesn't support HHMM, only HH:MM
         tzsep = (43, 45)  # + and - bytes
index 2e7ee5ffd99382de0e249afcfc7d96d99be09624..74ec26f132c4af5cd1864b8403b4c01ddd706c5c 100644 (file)
@@ -10,7 +10,6 @@ from typing import Any, Callable, Optional
 from ..oids import builtins
 from ..adapt import Dumper, Loader
 from ..errors import DataError
-from ..utils.codecs import EncodeFunc, encode_utf8
 
 JSON_OID = builtins["json"].oid
 JSONB_OID = builtins["jsonb"].oid
@@ -36,10 +35,8 @@ class Jsonb(_JsonWrapper):
 
 
 class _JsonDumper(Dumper):
-    def dump(
-        self, obj: _JsonWrapper, __encode: EncodeFunc = encode_utf8
-    ) -> bytes:
-        return __encode(obj.dumps())[0]
+    def dump(self, obj: _JsonWrapper) -> bytes:
+        return obj.dumps().encode("utf-8")
 
 
 @Dumper.text(Json)
@@ -55,10 +52,8 @@ class JsonbDumper(_JsonDumper):
 
 @Dumper.binary(Jsonb)
 class JsonbBinaryDumper(JsonbDumper):
-    def dump(
-        self, obj: _JsonWrapper, __encode: EncodeFunc = encode_utf8
-    ) -> bytes:
-        return b"\x01" + __encode(obj.dumps())[0]
+    def dump(self, obj: _JsonWrapper) -> bytes:
+        return b"\x01" + obj.dumps().encode("utf-8")
 
 
 @Loader.text(builtins["json"].oid)
index 42a7ae3a578243ff97e7ba19842b0359e5998d1f..b192aeee58a46a671db32d855e3a636e24c5ed85 100644 (file)
@@ -5,20 +5,24 @@ Adapters for network types.
 # Copyright (C) 2020 The Psycopg Team
 
 # TODO: consider lazy dumper registration.
-from ipaddress import ip_address, ip_interface, ip_network
+import ipaddress
 from ipaddress import IPv4Address, IPv4Interface, IPv4Network
 from ipaddress import IPv6Address, IPv6Interface, IPv6Network
 
-from typing import cast, Union
+from typing import cast, Callable, Union
 
 from ..oids import builtins
 from ..adapt import Dumper, Loader
-from ..utils.codecs import encode_ascii, decode_ascii
 
 Address = Union[IPv4Address, IPv6Address]
 Interface = Union[IPv4Interface, IPv6Interface]
 Network = Union[IPv4Network, IPv6Network]
 
+# in typeshed these types are commented out
+ip_address = cast(Callable[[str], Address], ipaddress.ip_address)
+ip_interface = cast(Callable[[str], Interface], ipaddress.ip_interface)
+ip_network = cast(Callable[[str], Network], ipaddress.ip_network)
+
 
 @Dumper.text(IPv4Address)
 @Dumper.text(IPv6Address)
@@ -29,7 +33,7 @@ class InterfaceDumper(Dumper):
     oid = builtins["inet"].oid
 
     def dump(self, obj: Interface) -> bytes:
-        return encode_ascii(str(obj))[0]
+        return str(obj).encode("utf8")
 
 
 @Dumper.text(IPv4Network)
@@ -39,19 +43,19 @@ class NetworkDumper(Dumper):
     oid = builtins["cidr"].oid
 
     def dump(self, obj: Network) -> bytes:
-        return encode_ascii(str(obj))[0]
+        return str(obj).encode("utf8")
 
 
 @Loader.text(builtins["inet"].oid)
 class InetLoader(Loader):
     def load(self, data: bytes) -> Union[Address, Interface]:
         if b"/" in data:
-            return cast(Interface, ip_interface(decode_ascii(data)[0]))
+            return ip_interface(data.decode("utf8"))
         else:
-            return cast(Address, ip_address(decode_ascii(data)[0]))
+            return ip_address(data.decode("utf8"))
 
 
 @Loader.text(builtins["cidr"].oid)
 class CidrLoader(Loader):
     def load(self, data: bytes) -> Network:
-        return cast(Network, ip_network(decode_ascii(data)[0]))
+        return ip_network(data.decode("utf8"))
index c239a47c697ac832e2acfc3ccafb90639b210d94..054c9c949d367bce587d7baeb4ecf07e83bfe225 100644 (file)
@@ -10,16 +10,22 @@ from decimal import Decimal
 
 from ..oids import builtins
 from ..adapt import Dumper, Loader
-from ..utils.codecs import DecodeFunc, decode_ascii
 
-PackInt = Callable[[int], bytes]
-UnpackInt = Callable[[bytes], Tuple[int]]
-UnpackFloat = Callable[[bytes], Tuple[float]]
+_PackInt = Callable[[int], bytes]
+_UnpackInt = Callable[[bytes], Tuple[int]]
+_UnpackFloat = Callable[[bytes], Tuple[float]]
+
+_pack_int2 = cast(_PackInt, struct.Struct("!h").pack)
+_pack_int4 = cast(_PackInt, struct.Struct("!i").pack)
+_pack_uint4 = cast(_PackInt, struct.Struct("!I").pack)
+_pack_int8 = cast(_PackInt, struct.Struct("!q").pack)
+_unpack_int2 = cast(_UnpackInt, struct.Struct("!h").unpack)
+_unpack_int4 = cast(_UnpackInt, struct.Struct("!i").unpack)
+_unpack_uint4 = cast(_UnpackInt, struct.Struct("!I").unpack)
+_unpack_int8 = cast(_UnpackInt, struct.Struct("!q").unpack)
+_unpack_float4 = cast(_UnpackFloat, struct.Struct("!f").unpack)
+_unpack_float8 = cast(_UnpackFloat, struct.Struct("!d").unpack)
 
-_pack_int2 = cast(PackInt, struct.Struct("!h").pack)
-_pack_int4 = cast(PackInt, struct.Struct("!i").pack)
-_pack_uint4 = cast(PackInt, struct.Struct("!I").pack)
-_pack_int8 = cast(PackInt, struct.Struct("!q").pack)
 
 # Wrappers to force numbers to be cast as specific PostgreSQL types
 
@@ -142,48 +148,32 @@ class OidBinaryDumper(OidDumper):
 @Loader.text(builtins["int8"].oid)
 @Loader.text(builtins["oid"].oid)
 class IntLoader(Loader):
-    def load(self, data: bytes, __decode: DecodeFunc = decode_ascii) -> int:
-        return int(__decode(data)[0])
+    def load(self, data: bytes) -> int:
+        return int(data.decode("utf8"))
 
 
 @Loader.binary(builtins["int2"].oid)
 class Int2BinaryLoader(Loader):
-    def load(
-        self,
-        data: bytes,
-        __unpack: UnpackInt = cast(UnpackInt, struct.Struct("!h").unpack),
-    ) -> int:
-        return __unpack(data)[0]
+    def load(self, data: bytes) -> int:
+        return _unpack_int2(data)[0]
 
 
 @Loader.binary(builtins["int4"].oid)
 class Int4BinaryLoader(Loader):
-    def load(
-        self,
-        data: bytes,
-        __unpack: UnpackInt = cast(UnpackInt, struct.Struct("!i").unpack),
-    ) -> int:
-        return __unpack(data)[0]
+    def load(self, data: bytes) -> int:
+        return _unpack_int4(data)[0]
 
 
 @Loader.binary(builtins["int8"].oid)
 class Int8BinaryLoader(Loader):
-    def load(
-        self,
-        data: bytes,
-        __unpack: UnpackInt = cast(UnpackInt, struct.Struct("!q").unpack),
-    ) -> int:
-        return __unpack(data)[0]
+    def load(self, data: bytes) -> int:
+        return _unpack_int8(data)[0]
 
 
 @Loader.binary(builtins["oid"].oid)
 class OidBinaryLoader(Loader):
-    def load(
-        self,
-        data: bytes,
-        __unpack: UnpackInt = cast(UnpackInt, struct.Struct("!I").unpack),
-    ) -> int:
-        return __unpack(data)[0]
+    def load(self, data: bytes) -> int:
+        return _unpack_uint4(data)[0]
 
 
 @Loader.text(builtins["float4"].oid)
@@ -196,27 +186,17 @@ class FloatLoader(Loader):
 
 @Loader.binary(builtins["float4"].oid)
 class Float4BinaryLoader(Loader):
-    def load(
-        self,
-        data: bytes,
-        __unpack: UnpackInt = cast(UnpackInt, struct.Struct("!f").unpack),
-    ) -> int:
-        return __unpack(data)[0]
+    def load(self, data: bytes) -> float:
+        return _unpack_float4(data)[0]
 
 
 @Loader.binary(builtins["float8"].oid)
 class Float8BinaryLoader(Loader):
-    def load(
-        self,
-        data: bytes,
-        __unpack: UnpackInt = cast(UnpackInt, struct.Struct("!d").unpack),
-    ) -> int:
-        return __unpack(data)[0]
+    def load(self, data: bytes) -> float:
+        return _unpack_float8(data)[0]
 
 
 @Loader.text(builtins["numeric"].oid)
 class NumericLoader(Loader):
-    def load(
-        self, data: bytes, __decode: DecodeFunc = decode_ascii
-    ) -> Decimal:
-        return Decimal(__decode(data)[0])
+    def load(self, data: bytes) -> Decimal:
+        return Decimal(data.decode("utf8"))
index 6ad6ca45ce54ea7795d02f96a0f94a755b0c27d9..f164780731a374d34137aba2cf60dc291bc451e7 100644 (file)
@@ -4,14 +4,13 @@ Adapters for textual types.
 
 # Copyright (C) 2020 The Psycopg Team
 
-from typing import Optional, Union, TYPE_CHECKING
+from typing import Union, TYPE_CHECKING
 
 from ..pq import Escaping
 from ..oids import builtins, INVALID_OID
 from ..adapt import Dumper, Loader
 from ..proto import AdaptContext
 from ..errors import DataError
-from ..utils.codecs import EncodeFunc, DecodeFunc, encode_utf8, decode_utf8
 
 if TYPE_CHECKING:
     from ..pq.proto import Escaping as EscapingProto
@@ -21,20 +20,19 @@ class _StringDumper(Dumper):
     def __init__(self, src: type, context: AdaptContext):
         super().__init__(src, context)
 
-        self._encode: EncodeFunc
         if self.connection:
             if self.connection.client_encoding != "SQL_ASCII":
-                self._encode = self.connection.codec.encode
+                self.encoding = self.connection.pyenc
             else:
-                self._encode = encode_utf8
+                self.encoding = "utf-8"
         else:
-            self._encode = encode_utf8
+            self.encoding = "utf-8"
 
 
 @Dumper.binary(str)
 class StringBinaryDumper(_StringDumper):
     def dump(self, obj: str) -> bytes:
-        return self._encode(obj)[0]
+        return obj.encode(self.encoding)
 
 
 @Dumper.text(str)
@@ -45,7 +43,7 @@ class StringDumper(_StringDumper):
                 "PostgreSQL text fields cannot contain NUL (0x00) bytes"
             )
         else:
-            return self._encode(obj)[0]
+            return obj.encode(self.encoding)
 
 
 @Loader.text(builtins["text"].oid)
@@ -54,23 +52,20 @@ class StringDumper(_StringDumper):
 @Loader.binary(builtins["varchar"].oid)
 @Loader.text(INVALID_OID)
 class TextLoader(Loader):
-
-    decode: Optional[DecodeFunc]
-
     def __init__(self, oid: int, context: AdaptContext):
         super().__init__(oid, context)
 
         if self.connection is not None:
             if self.connection.client_encoding != "SQL_ASCII":
-                self.decode = self.connection.codec.decode
+                self.encoding = self.connection.pyenc
             else:
-                self.decode = None
+                self.encoding = ""
         else:
-            self.decode = decode_utf8
+            self.encoding = "utf-8"
 
     def load(self, data: bytes) -> Union[bytes, str]:
-        if self.decode is not None:
-            return self.decode(data)[0]
+        if self.encoding:
+            return data.decode(self.encoding)
         else:
             # return bytes for SQL_ASCII db
             return data
@@ -83,15 +78,10 @@ class TextLoader(Loader):
 class UnknownLoader(Loader):
     def __init__(self, oid: int, context: AdaptContext):
         super().__init__(oid, context)
-
-        self.decode: DecodeFunc
-        if self.connection is not None:
-            self.decode = self.connection.codec.decode
-        else:
-            self.decode = decode_utf8
+        self.encoding = self.connection.pyenc if self.connection else "utf-8"
 
     def load(self, data: bytes) -> str:
-        return self.decode(data)[0]
+        return data.decode(self.encoding)
 
 
 @Dumper.text(bytes)
index 6076b12f46396ed2953dcf8c78ef3cb5e108b9d1..73cb30ef2a89adeec3d2c8cd9efcb8a53aa02b1b 100644 (file)
@@ -10,7 +10,6 @@ from uuid import UUID
 
 from ..oids import builtins
 from ..adapt import Dumper, Loader
-from ..utils.codecs import EncodeFunc, DecodeFunc, encode_ascii, decode_ascii
 
 
 @Dumper.text(UUID)
@@ -18,23 +17,20 @@ class UUIDDumper(Dumper):
 
     oid = builtins["uuid"].oid
 
-    def dump(self, obj: UUID, __encode: EncodeFunc = encode_ascii) -> bytes:
-        return __encode(obj.hex)[0]
+    def dump(self, obj: UUID) -> bytes:
+        return obj.hex.encode("utf8")
 
 
 @Dumper.binary(UUID)
-class UUIDBinaryDumper(Dumper):
-
-    oid = builtins["uuid"].oid
-
+class UUIDBinaryDumper(UUIDDumper):
     def dump(self, obj: UUID) -> bytes:
         return obj.bytes
 
 
 @Loader.text(builtins["uuid"].oid)
 class UUIDLoader(Loader):
-    def load(self, data: bytes, __decode: DecodeFunc = decode_ascii) -> UUID:
-        return UUID(__decode(data)[0])
+    def load(self, data: bytes) -> UUID:
+        return UUID(data.decode("utf8"))
 
 
 @Loader.binary(builtins["uuid"].oid)
index 01dc771f582416601d874a614a5c93df0bd4f840..59550b9b5629034a175284b4cf24131dc296b425 100644 (file)
@@ -53,14 +53,13 @@ class PostgresQuery:
         if isinstance(query, Composable):
             query = query.as_string(self._tx)
 
-        codec = self._tx.codec
         if vars is not None:
             self.query, self.formats, self._order, self._parts = _query2pg(
-                query, codec.name
+                query, self._tx.encoding
             )
         else:
             if isinstance(query, str):
-                query = codec.encode(query)[0]
+                query = query.encode(self._tx.encoding)
             self.query = query
             self.formats = self._order = None
 
index ff14ded2064f336260a8249ba29bd6d10bfdbe7a..2dddf25f921f4aa175d67978e02b9d02381076b1 100644 (file)
@@ -7,7 +7,6 @@ information. Will submit a bug.
 
 # Copyright (C) 2020 The Psycopg Team
 
-import codecs
 from typing import Any, Iterable, List, Optional, Sequence, Tuple
 
 from psycopg3.adapt import Dumper, Loader
@@ -21,7 +20,7 @@ class Transformer:
     @property
     def connection(self) -> Optional[BaseConnection]: ...
     @property
-    def codec(self) -> codecs.CodecInfo: ...
+    def encoding(self) -> str: ...
     @property
     def dumpers(self) -> DumpersMap: ...
     @property
index e79a32dbad4092e6c812de3a58af9bb22e133070..a8cde18326103d9c0a3b2b399493829173cffd32 100644 (file)
@@ -11,7 +11,6 @@ too many temporary Python objects and performing less memory copying.
 from cpython.ref cimport Py_INCREF
 from cpython.tuple cimport PyTuple_New, PyTuple_SET_ITEM
 
-import codecs
 from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
 
 from psycopg3_c cimport libpq
@@ -39,9 +38,10 @@ cdef class Transformer:
 
     cdef list _dumpers_maps, _loaders_maps
     cdef dict _dumpers, _loaders, _dumpers_cache, _loaders_cache, _load_funcs
-    cdef object _connection, _codec
+    cdef object _connection
     cdef PGresult _pgresult
     cdef int _nfields, _ntuples
+    cdef str _encoding
 
     cdef list _row_loaders
 
@@ -70,7 +70,7 @@ cdef class Transformer:
         cdef Transformer ctx
         if context is None:
             self._connection = None
-            self._codec = codecs.lookup("utf8")
+            self._encoding = "utf-8"
             self._dumpers = {}
             self._loaders = {}
             self._dumpers_maps = [self._dumpers]
@@ -81,7 +81,7 @@ cdef class Transformer:
             # for nested types: share the entire state of the parent
             ctx = context
             self._connection = ctx._connection
-            self._codec = ctx._codec
+            self._encoding = ctx.encoding
             self._dumpers = ctx._dumpers
             self._loaders = ctx._loaders
             self._dumpers_maps.extend(ctx._dumpers_maps)
@@ -91,7 +91,7 @@ cdef class Transformer:
 
         elif isinstance(context, BaseCursor):
             self._connection = context.connection
-            self._codec = context.connection.codec
+            self._encoding = context.connection.pyenc
             self._dumpers = {}
             self._dumpers_maps.extend(
                 (self._dumpers, context.dumpers, self.connection.dumpers)
@@ -103,7 +103,7 @@ cdef class Transformer:
 
         elif isinstance(context, BaseConnection):
             self._connection = context
-            self._codec = context.codec
+            self._encoding = context.pyenc
             self._dumpers = {}
             self._dumpers_maps.extend((self._dumpers, context.dumpers))
             self._loaders = {}
@@ -117,8 +117,8 @@ cdef class Transformer:
         return self._connection
 
     @property
-    def codec(self):
-        return self._codec
+    def encoding(self):
+        return self._encoding
 
     @property
     def dumpers(self):
index a7f6f6f76850b7b868d81e4bd51c5a92dffcce46..547fe8c6428e0af1c72426c822c662fca81751da 100644 (file)
@@ -4,38 +4,39 @@ Cython adapters for textual types.
 
 # Copyright (C) 2020 The Psycopg Team
 
-from cpython.bytes cimport PyBytes_FromStringAndSize
-from cpython.unicode cimport PyUnicode_DecodeUTF8
+from cpython.unicode cimport PyUnicode_Decode, PyUnicode_DecodeUTF8
 from psycopg3_c cimport libpq
 
 
 cdef class TextLoader(CLoader):
     cdef int is_utf8
-    cdef object pydecoder
+    cdef char *encoding
+    cdef bytes _bytes_encoding  # needed to keep `encoding` alive
 
     def __init__(self, oid: int, context: "AdaptContext" = None):
         super().__init__(oid, context)
 
         self.is_utf8 = 0
-        self.pydecoder = None
+        self.encoding = NULL
+
         conn = self.connection
         if conn is not None:
             if conn.client_encoding == "UTF8":
                 self.is_utf8 = 1
             elif conn.client_encoding != "SQL_ASCII":
-                self.pydecoder = conn.codec.decode
+                self._bytes_encoding = conn.pyenc.encode("utf-8")
+                self.encoding = self._bytes_encoding
         else:
-            self.pydecoder = codecs.lookup("utf8").decode
+            self.encoding = "utf-8"
 
     cdef object cload(self, const char *data, size_t length):
         if self.is_utf8:
             return PyUnicode_DecodeUTF8(<char *>data, length, NULL)
 
-        b = PyBytes_FromStringAndSize(data, length)
-        if self.pydecoder is not None:
-            return self.pydecoder(b)[0]
+        if self.encoding:
+            return PyUnicode_Decode(<char *>data, length, self.encoding, NULL)
         else:
-            return b
+            return data[:length]
 
 
 cdef class ByteaLoader(CLoader):
index 2d741a020e344a7d76860fc7b9cb8634cc986f62..777a65ec6f3f945bfbbf087a123c66c07e870eab 100644 (file)
@@ -190,7 +190,7 @@ def test_set_encoding(conn):
 def test_normalize_encoding(conn, enc, out, codec):
     conn.client_encoding = enc
     assert conn.client_encoding == out
-    assert conn.codec.name == codec
+    assert conn.pyenc == codec
 
 
 @pytest.mark.parametrize(
@@ -207,7 +207,7 @@ def test_encoding_env_var(dsn, monkeypatch, enc, out, codec):
     monkeypatch.setenv("PGCLIENTENCODING", enc)
     conn = psycopg3.connect(dsn)
     assert conn.client_encoding == out
-    assert conn.codec.name == codec
+    assert conn.pyenc == codec
 
 
 def test_set_encoding_unsupported(conn):
index daa8aad034175de814f623b057659c2ebdbb4818..64109e5f9827d7152789227cd9d7d33634edc9a6 100644 (file)
@@ -205,7 +205,7 @@ async def test_set_encoding(aconn):
 async def test_normalize_encoding(aconn, enc, out, codec):
     await aconn.set_client_encoding(enc)
     assert aconn.client_encoding == out
-    assert aconn.codec.name == codec
+    assert aconn.pyenc == codec
 
 
 @pytest.mark.parametrize(
@@ -222,7 +222,7 @@ async def test_encoding_env_var(dsn, monkeypatch, enc, out, codec):
     monkeypatch.setenv("PGCLIENTENCODING", enc)
     aconn = await psycopg3.AsyncConnection.connect(dsn)
     assert aconn.client_encoding == out
-    assert aconn.codec.name == codec
+    assert aconn.pyenc == codec
 
 
 async def test_set_encoding_unsupported(aconn):
index 981515858d5fd683f162558edf7556cb59d09bb0..23506a8cac6ca6edca9f43cabfc4c646c3eb83ec 100644 (file)
@@ -219,7 +219,7 @@ def make_testfunc(conn):
         )
         .format(sql.Identifier(procname), sql.Identifier(paramname))
         .as_string(conn)
-        .encode(conn.codec.name)
+        .encode(conn.pyenc)
     )
 
     # execute regardless of sync/async conn