From: Daniele Varrazzo Date: Sat, 7 Nov 2020 01:58:06 +0000 (+0000) Subject: Dropped local variable micro-optimization X-Git-Tag: 3.0.dev0~394 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=9e1334350183141da00ff8053cf7e2e366910d26;p=thirdparty%2Fpsycopg.git Dropped local variable micro-optimization A long time ago it was a thing. timeit shows it isn't anymore. --- diff --git a/psycopg3/psycopg3/_transform.py b/psycopg3/psycopg3/_transform.py index 70d2f99b7..b10aa7cf4 100644 --- a/psycopg3/psycopg3/_transform.py +++ b/psycopg3/psycopg3/_transform.py @@ -4,7 +4,6 @@ Helper object to transform values between Python and PostgreSQL # Copyright (C) 2020 The Psycopg Team -import codecs from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple from typing import TYPE_CHECKING @@ -56,7 +55,7 @@ class Transformer: def _setup_context(self, context: AdaptContext) -> None: if context is None: self._connection = None - self._codec = codecs.lookup("utf8") + self._encoding = "utf-8" self._dumpers = {} self._loaders = {} self._dumpers_maps = [self._dumpers] @@ -66,7 +65,7 @@ class Transformer: # A transformer created from a transformers: usually it happens # for nested types: share the entire state of the parent self._connection = context.connection - self._codec = context.codec + self._encoding = context.encoding self._dumpers = context.dumpers self._loaders = context.loaders self._dumpers_maps.extend(context._dumpers_maps) @@ -76,7 +75,7 @@ class Transformer: elif isinstance(context, BaseCursor): self._connection = context.connection - self._codec = context.connection.codec + self._encoding = context.connection.pyenc self._dumpers = {} self._dumpers_maps.extend( (self._dumpers, context.dumpers, context.connection.dumpers) @@ -88,7 +87,7 @@ class Transformer: elif isinstance(context, BaseConnection): self._connection = context - self._codec = context.codec + self._encoding = context.pyenc self._dumpers = {} self._dumpers_maps.extend((self._dumpers, context.dumpers)) self._loaders = {} @@ -104,8 +103,8 @@ class Transformer: return self._connection @property - def codec(self) -> codecs.CodecInfo: - return self._codec + def encoding(self) -> str: + return self._encoding @property def pgresult(self) -> Optional[pq.proto.PGresult]: diff --git a/psycopg3/psycopg3/connection.py b/psycopg3/psycopg3/connection.py index 19231c056..a720528ab 100644 --- a/psycopg3/psycopg3/connection.py +++ b/psycopg3/psycopg3/connection.py @@ -4,7 +4,6 @@ psycopg3 connection objects # Copyright (C) 2020 The Psycopg Team -import codecs import logging import asyncio import threading @@ -83,8 +82,10 @@ class BaseConnection: self.loaders: proto.LoadersMap = {} self._notice_handlers: List[NoticeHandler] = [] self._notify_handlers: List[NotifyHandler] = [] - # name of the postgres encoding (in bytes) + # postgres name of the client encoding (in bytes) self._pgenc = b"" + # python name of the client encoding + self._pyenc = "utf-8" wself = ref(self) @@ -126,25 +127,19 @@ class BaseConnection: return self.cursor_factory(self, format=format) @property - def codec(self) -> codecs.CodecInfo: - # TODO: utf8 fastpath? 
+ def pyenc(self) -> str: pgenc = self.pgconn.parameter_status(b"client_encoding") or b"" if self._pgenc != pgenc: if pgenc: try: - pyenc = pq.py_codecs[pgenc] + self._pyenc = pq.py_codecs[pgenc] except KeyError: raise e.NotSupportedError( f"encoding {pgenc.decode('ascii')} not available in Python" ) - self._codec = codecs.lookup(pyenc) - else: - # fallback for a connection closed whose codec was never asked - if not hasattr(self, "_codec"): - self._codec = codecs.lookup("utf8") self._pgenc = pgenc - return self._codec + return self._pyenc @property def client_encoding(self) -> str: @@ -179,7 +174,7 @@ class BaseConnection: if self is None or not self._notice_handler: return - diag = e.Diagnostic(res, self.codec.name) + diag = e.Diagnostic(res, self._pyenc) for cb in self._notice_handlers: try: cb(diag) @@ -202,8 +197,11 @@ class BaseConnection: if self is None or not self._notify_handlers: return - decode = self.codec.decode - n = Notify(decode(pgn.relname)[0], decode(pgn.extra)[0], pgn.be_pid) + n = Notify( + pgn.relname.decode(self._pyenc), + pgn.extra.decode(self._pyenc), + pgn.be_pid, + ) for cb in self._notify_handlers: cb(n) @@ -261,7 +259,7 @@ class Connection(BaseConnection): if pgres.status != ExecStatus.COMMAND_OK: raise e.OperationalError( "error on begin:" - f" {pq.error_message(pgres, encoding=self.codec.name)}" + f" {pq.error_message(pgres, encoding=self._pyenc)}" ) def commit(self) -> None: @@ -283,7 +281,7 @@ class Connection(BaseConnection): if results[-1].status != ExecStatus.COMMAND_OK: raise e.OperationalError( f"error on {command.decode('utf8')}:" - f" {pq.error_message(results[-1], encoding=self.codec.name)}" + f" {pq.error_message(results[-1], encoding=self._pyenc)}" ) @classmethod @@ -301,16 +299,17 @@ class Connection(BaseConnection): gen = execute(self.pgconn) (result,) = self.wait(gen) if result.status != ExecStatus.TUPLES_OK: - raise e.error_from_result(result, encoding=self.codec.name) + raise e.error_from_result(result, encoding=self._pyenc) def notifies(self) -> Generator[Optional[Notify], bool, None]: - decode = self.codec.decode while 1: with self.lock: ns = self.wait(notifies(self.pgconn)) for pgn in ns: n = Notify( - decode(pgn.relname)[0], decode(pgn.extra)[0], pgn.be_pid + pgn.relname.decode(self._pyenc), + pgn.extra.decode(self._pyenc), + pgn.be_pid, ) if (yield n): yield None # for the send who stopped us @@ -375,7 +374,7 @@ class AsyncConnection(BaseConnection): if pgres.status != ExecStatus.COMMAND_OK: raise e.OperationalError( "error on begin:" - f" {pq.error_message(pgres, encoding=self.codec.name)}" + f" {pq.error_message(pgres, encoding=self._pyenc)}" ) async def commit(self) -> None: @@ -397,7 +396,7 @@ class AsyncConnection(BaseConnection): if pgres.status != ExecStatus.COMMAND_OK: raise e.OperationalError( f"error on {command.decode('utf8')}:" - f" {pq.error_message(pgres, encoding=self.codec.name)}" + f" {pq.error_message(pgres, encoding=self._pyenc)}" ) @classmethod @@ -419,16 +418,17 @@ class AsyncConnection(BaseConnection): gen = execute(self.pgconn) (result,) = await self.wait(gen) if result.status != ExecStatus.TUPLES_OK: - raise e.error_from_result(result, encoding=self.codec.name) + raise e.error_from_result(result, encoding=self._pyenc) async def notifies(self) -> AsyncGenerator[Optional[Notify], bool]: - decode = self.codec.decode while 1: async with self.lock: ns = await self.wait(notifies(self.pgconn)) for pgn in ns: n = Notify( - decode(pgn.relname)[0], decode(pgn.extra)[0], pgn.be_pid + pgn.relname.decode(self._pyenc), + 
pgn.extra.decode(self._pyenc), + pgn.be_pid, ) if (yield n): yield None diff --git a/psycopg3/psycopg3/copy.py b/psycopg3/psycopg3/copy.py index ff61cf988..6cd89e533 100644 --- a/psycopg3/psycopg3/copy.py +++ b/psycopg3/psycopg3/copy.py @@ -5,7 +5,6 @@ psycopg3 copy support # Copyright (C) 2020 The Psycopg Team import re -import codecs import struct from typing import TYPE_CHECKING, AsyncIterator, Iterator from typing import Any, Dict, List, Match, Optional, Sequence, Type, Union @@ -34,7 +33,7 @@ class BaseCopy: self.pgresult = result self._first_row = True self._finished = False - self._codec: Optional[codecs.CodecInfo] = None + self._encoding: str = "" if format == pq.Format.TEXT: self._format_row = self._format_row_text @@ -70,8 +69,8 @@ class BaseCopy: return data elif isinstance(data, str): - if self._codec is not None: - return self._codec.encode(data)[0] + if self._encoding: + return data.encode(self._encoding) if ( self.pgresult is None @@ -80,8 +79,8 @@ class BaseCopy: raise TypeError( "cannot copy str data in binary mode: use bytes instead" ) - self._codec = self.connection.codec - return self._codec.encode(data)[0] + self._encoding = self.connection.pyenc + return data.encode(self._encoding) else: raise TypeError(f"can't write {type(data).__name__}") @@ -180,13 +179,9 @@ class Copy(BaseCopy): data = self.format_row(row) self.write(data) - def finish(self, error: Optional[str] = None) -> None: + def finish(self, error: str = "") -> None: conn = self.connection - berr = ( - conn.codec.encode(error, "replace")[0] - if error is not None - else None - ) + berr = error.encode(conn.pyenc, "replace") if error else None conn.wait(copy_end(conn.pgconn, berr)) self._finished = True @@ -205,7 +200,7 @@ class Copy(BaseCopy): self.write(b"\xff\xff") self.finish() else: - self.finish(str(exc_val)) + self.finish(str(exc_val) or type(exc_val).__qualname__) def __iter__(self) -> Iterator[bytes]: while 1: @@ -241,13 +236,9 @@ class AsyncCopy(BaseCopy): data = self.format_row(row) await self.write(data) - async def finish(self, error: Optional[str] = None) -> None: + async def finish(self, error: str = "") -> None: conn = self.connection - berr = ( - conn.codec.encode(error, "replace")[0] - if error is not None - else None - ) + berr = error.encode(conn.pyenc, "replace") if error else None await conn.wait(copy_end(conn.pgconn, berr)) self._finished = True diff --git a/psycopg3/psycopg3/cursor.py b/psycopg3/psycopg3/cursor.py index 77b1ecd86..c062d2ae7 100644 --- a/psycopg3/psycopg3/cursor.py +++ b/psycopg3/psycopg3/cursor.py @@ -191,10 +191,8 @@ class BaseCursor: res = self.pgresult if res is None or res.status != self.ExecStatus.TUPLES_OK: return None - return [ - Column(res, i, self.connection.codec.name) - for i in range(res.nfields) - ] + encoding = self.connection.pyenc + return [Column(res, i, encoding) for i in range(res.nfields)] @property def rowcount(self) -> int: @@ -283,7 +281,7 @@ class BaseCursor: if results[-1].status == S.FATAL_ERROR: raise e.error_from_result( - results[-1], encoding=self.connection.codec.name + results[-1], encoding=self.connection.pyenc ) elif badstats & {S.COPY_IN, S.COPY_OUT, S.COPY_BOTH}: @@ -393,9 +391,7 @@ class BaseCursor: if status in (pq.ExecStatus.COPY_IN, pq.ExecStatus.COPY_OUT): return elif status == pq.ExecStatus.FATAL_ERROR: - raise e.error_from_result( - result, encoding=self.connection.codec.name - ) + raise e.error_from_result(result, encoding=self.connection.pyenc) else: raise e.ProgrammingError( "copy() should be used only with COPY ... 
TO STDOUT or COPY ..." @@ -450,7 +446,7 @@ class Cursor(BaseCursor): (result,) = self.connection.wait(gen) if result.status == self.ExecStatus.FATAL_ERROR: raise e.error_from_result( - result, encoding=self.connection.codec.name + result, encoding=self.connection.pyenc ) else: pgq.dump(vars) @@ -578,7 +574,7 @@ class AsyncCursor(BaseCursor): (result,) = await self.connection.wait(gen) if result.status == self.ExecStatus.FATAL_ERROR: raise e.error_from_result( - result, encoding=self.connection.codec.name + result, encoding=self.connection.pyenc ) else: pgq.dump(vars) diff --git a/psycopg3/psycopg3/proto.py b/psycopg3/psycopg3/proto.py index c8f04b24c..3a0130e4b 100644 --- a/psycopg3/psycopg3/proto.py +++ b/psycopg3/psycopg3/proto.py @@ -4,7 +4,6 @@ Protocol objects representing different implementations of the same classes. # Copyright (C) 2020 The Psycopg Team -import codecs from typing import Any, Callable, Dict, Generator, Mapping from typing import Optional, Sequence, Tuple, Type, TypeVar, Union from typing import TYPE_CHECKING @@ -52,7 +51,7 @@ class Transformer(Protocol): ... @property - def codec(self) -> codecs.CodecInfo: + def encoding(self) -> str: ... @property diff --git a/psycopg3/psycopg3/sql.py b/psycopg3/psycopg3/sql.py index d0f53067d..60681877c 100644 --- a/psycopg3/psycopg3/sql.py +++ b/psycopg3/psycopg3/sql.py @@ -340,9 +340,9 @@ class Identifier(Composable): raise ValueError(f"no connection in the context: {context}") esc = Escaping(conn.pgconn) - codec = conn.codec - escs = [esc.escape_identifier(codec.encode(s)[0]) for s in self._obj] - return codec.decode(b".".join(escs))[0] + enc = conn.pyenc + escs = [esc.escape_identifier(s.encode(enc)) for s in self._obj] + return b".".join(escs).decode(enc) class Literal(Composable): @@ -373,7 +373,7 @@ class Literal(Composable): tx = context if isinstance(context, Transformer) else Transformer(conn) dumper = tx.get_dumper(self._obj, Format.TEXT) quoted = dumper.quote(self._obj) - return conn.codec.decode(quoted)[0] if conn else quoted.decode("utf8") + return quoted.decode(conn.pyenc if conn else "utf-8") class Placeholder(Composable): diff --git a/psycopg3/psycopg3/types/date.py b/psycopg3/psycopg3/types/date.py index 9cf039353..d7aa36a78 100644 --- a/psycopg3/psycopg3/types/date.py +++ b/psycopg3/psycopg3/types/date.py @@ -13,7 +13,6 @@ from ..oids import builtins from ..adapt import Dumper, Loader from ..proto import AdaptContext from ..errors import InterfaceError, DataError -from ..utils.codecs import EncodeFunc, DecodeFunc, encode_ascii, decode_ascii @Dumper.text(date) @@ -21,10 +20,10 @@ class DateDumper(Dumper): oid = builtins["date"].oid - def dump(self, obj: date, __encode: EncodeFunc = encode_ascii) -> bytes: + def dump(self, obj: date) -> bytes: # NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD) # the YYYY-MM-DD is always understood correctly. 
- return __encode(str(obj))[0] + return str(obj).encode("utf8") @Dumper.text(time) @@ -32,8 +31,8 @@ class TimeDumper(Dumper): oid = builtins["timetz"].oid - def dump(self, obj: time, __encode: EncodeFunc = encode_ascii) -> bytes: - return __encode(str(obj))[0] + def dump(self, obj: time) -> bytes: + return str(obj).encode("utf8") @Dumper.text(datetime) @@ -41,10 +40,10 @@ class DateTimeDumper(Dumper): oid = builtins["timestamptz"].oid - def dump(self, obj: date, __encode: EncodeFunc = encode_ascii) -> bytes: + def dump(self, obj: date) -> bytes: # NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD) # the YYYY-MM-DD is always understood correctly. - return __encode(str(obj))[0] + return str(obj).encode("utf8") @Dumper.text(timedelta) @@ -61,10 +60,8 @@ class TimeDeltaDumper(Dumper): ): setattr(self, "dump", self._dump_sql) - def dump( - self, obj: timedelta, __encode: EncodeFunc = encode_ascii - ) -> bytes: - return __encode(str(obj))[0] + def dump(self, obj: timedelta) -> bytes: + return str(obj).encode("utf8") def _dump_sql(self, obj: timedelta) -> bytes: # sql_standard format needs explicit signs @@ -82,9 +79,9 @@ class DateLoader(Loader): super().__init__(oid, context) self._format = self._format_from_context() - def load(self, data: bytes, __decode: DecodeFunc = decode_ascii) -> date: + def load(self, data: bytes) -> date: try: - return datetime.strptime(__decode(data)[0], self._format).date() + return datetime.strptime(data.decode("utf8"), self._format).date() except ValueError as e: return self._raise_error(data, e) @@ -140,11 +137,11 @@ class TimeLoader(Loader): _format = "%H:%M:%S.%f" _format_no_micro = _format.replace(".%f", "") - def load(self, data: bytes, __decode: DecodeFunc = decode_ascii) -> time: + def load(self, data: bytes) -> time: # check if the data contains microseconds fmt = self._format if b"." in data else self._format_no_micro try: - return datetime.strptime(__decode(data)[0], fmt).time() + return datetime.strptime(data.decode("utf8"), fmt).time() except ValueError as e: return self._raise_error(data, e) @@ -170,22 +167,20 @@ class TimeTzLoader(TimeLoader): super().__init__(oid, context) - def load(self, data: bytes, __decode: DecodeFunc = decode_ascii) -> time: + def load(self, data: bytes) -> time: # Hack to convert +HH in +HHMM if data[-3] in (43, 45): data += b"00" fmt = self._format if b"." 
in data else self._format_no_micro try: - dt = datetime.strptime(__decode(data)[0], fmt) + dt = datetime.strptime(data.decode("utf8"), fmt) except ValueError as e: return self._raise_error(data, e) return dt.time().replace(tzinfo=dt.tzinfo) - def _load_py36( - self, data: bytes, __decode: DecodeFunc = decode_ascii - ) -> time: + def _load_py36(self, data: bytes) -> time: # Drop seconds from timezone for Python 3.6 # Also, Python 3.6 doesn't support HHMM, only HH:MM if data[-6] in (43, 45): # +-HH:MM -> +-HHMM @@ -202,15 +197,13 @@ class TimestampLoader(DateLoader): super().__init__(oid, context) self._format_no_micro = self._format.replace(".%f", "") - def load( - self, data: bytes, __decode: DecodeFunc = decode_ascii - ) -> datetime: + def load(self, data: bytes) -> datetime: # check if the data contains microseconds fmt = ( self._format if data.find(b".", 19) >= 0 else self._format_no_micro ) try: - return datetime.strptime(__decode(data)[0], fmt) + return datetime.strptime(data.decode("utf8"), fmt) except ValueError as e: return self._raise_error(data, e) @@ -284,18 +277,14 @@ class TimestamptzLoader(TimestampLoader): setattr(self, "load", self._load_notimpl) return "" - def load( - self, data: bytes, __decode: DecodeFunc = decode_ascii - ) -> datetime: + def load(self, data: bytes) -> datetime: # Hack to convert +HH in +HHMM if data[-3] in (43, 45): data += b"00" return super().load(data) - def _load_py36( - self, data: bytes, __decode: DecodeFunc = decode_ascii - ) -> datetime: + def _load_py36(self, data: bytes) -> datetime: # Drop seconds from timezone for Python 3.6 # Also, Python 3.6 doesn't support HHMM, only HH:MM tzsep = (43, 45) # + and - bytes diff --git a/psycopg3/psycopg3/types/json.py b/psycopg3/psycopg3/types/json.py index 2e7ee5ffd..74ec26f13 100644 --- a/psycopg3/psycopg3/types/json.py +++ b/psycopg3/psycopg3/types/json.py @@ -10,7 +10,6 @@ from typing import Any, Callable, Optional from ..oids import builtins from ..adapt import Dumper, Loader from ..errors import DataError -from ..utils.codecs import EncodeFunc, encode_utf8 JSON_OID = builtins["json"].oid JSONB_OID = builtins["jsonb"].oid @@ -36,10 +35,8 @@ class Jsonb(_JsonWrapper): class _JsonDumper(Dumper): - def dump( - self, obj: _JsonWrapper, __encode: EncodeFunc = encode_utf8 - ) -> bytes: - return __encode(obj.dumps())[0] + def dump(self, obj: _JsonWrapper) -> bytes: + return obj.dumps().encode("utf-8") @Dumper.text(Json) @@ -55,10 +52,8 @@ class JsonbDumper(_JsonDumper): @Dumper.binary(Jsonb) class JsonbBinaryDumper(JsonbDumper): - def dump( - self, obj: _JsonWrapper, __encode: EncodeFunc = encode_utf8 - ) -> bytes: - return b"\x01" + __encode(obj.dumps())[0] + def dump(self, obj: _JsonWrapper) -> bytes: + return b"\x01" + obj.dumps().encode("utf-8") @Loader.text(builtins["json"].oid) diff --git a/psycopg3/psycopg3/types/network.py b/psycopg3/psycopg3/types/network.py index 42a7ae3a5..b192aeee5 100644 --- a/psycopg3/psycopg3/types/network.py +++ b/psycopg3/psycopg3/types/network.py @@ -5,20 +5,24 @@ Adapters for network types. # Copyright (C) 2020 The Psycopg Team # TODO: consiter lazy dumper registration. 
-from ipaddress import ip_address, ip_interface, ip_network +import ipaddress from ipaddress import IPv4Address, IPv4Interface, IPv4Network from ipaddress import IPv6Address, IPv6Interface, IPv6Network -from typing import cast, Union +from typing import cast, Callable, Union from ..oids import builtins from ..adapt import Dumper, Loader -from ..utils.codecs import encode_ascii, decode_ascii Address = Union[IPv4Address, IPv6Address] Interface = Union[IPv4Interface, IPv6Interface] Network = Union[IPv4Network, IPv6Network] +# in typeshed these types are commented out +ip_address = cast(Callable[[str], Address], ipaddress.ip_address) +ip_interface = cast(Callable[[str], Interface], ipaddress.ip_interface) +ip_network = cast(Callable[[str], Network], ipaddress.ip_network) + @Dumper.text(IPv4Address) @Dumper.text(IPv6Address) @@ -29,7 +33,7 @@ class InterfaceDumper(Dumper): oid = builtins["inet"].oid def dump(self, obj: Interface) -> bytes: - return encode_ascii(str(obj))[0] + return str(obj).encode("utf8") @Dumper.text(IPv4Network) @@ -39,19 +43,19 @@ class NetworkDumper(Dumper): oid = builtins["cidr"].oid def dump(self, obj: Network) -> bytes: - return encode_ascii(str(obj))[0] + return str(obj).encode("utf8") @Loader.text(builtins["inet"].oid) class InetLoader(Loader): def load(self, data: bytes) -> Union[Address, Interface]: if b"/" in data: - return cast(Interface, ip_interface(decode_ascii(data)[0])) + return ip_interface(data.decode("utf8")) else: - return cast(Address, ip_address(decode_ascii(data)[0])) + return ip_address(data.decode("utf8")) @Loader.text(builtins["cidr"].oid) class CidrLoader(Loader): def load(self, data: bytes) -> Network: - return cast(Network, ip_network(decode_ascii(data)[0])) + return ip_network(data.decode("utf8")) diff --git a/psycopg3/psycopg3/types/numeric.py b/psycopg3/psycopg3/types/numeric.py index c239a47c6..054c9c949 100644 --- a/psycopg3/psycopg3/types/numeric.py +++ b/psycopg3/psycopg3/types/numeric.py @@ -10,16 +10,22 @@ from decimal import Decimal from ..oids import builtins from ..adapt import Dumper, Loader -from ..utils.codecs import DecodeFunc, decode_ascii -PackInt = Callable[[int], bytes] -UnpackInt = Callable[[bytes], Tuple[int]] -UnpackFloat = Callable[[bytes], Tuple[float]] +_PackInt = Callable[[int], bytes] +_UnpackInt = Callable[[bytes], Tuple[int]] +_UnpackFloat = Callable[[bytes], Tuple[float]] + +_pack_int2 = cast(_PackInt, struct.Struct("!h").pack) +_pack_int4 = cast(_PackInt, struct.Struct("!i").pack) +_pack_uint4 = cast(_PackInt, struct.Struct("!I").pack) +_pack_int8 = cast(_PackInt, struct.Struct("!q").pack) +_unpack_int2 = cast(_UnpackInt, struct.Struct("!h").unpack) +_unpack_int4 = cast(_UnpackInt, struct.Struct("!i").unpack) +_unpack_uint4 = cast(_UnpackInt, struct.Struct("!I").unpack) +_unpack_int8 = cast(_UnpackInt, struct.Struct("!q").unpack) +_unpack_float4 = cast(_UnpackFloat, struct.Struct("!f").unpack) +_unpack_float8 = cast(_UnpackFloat, struct.Struct("!d").unpack) -_pack_int2 = cast(PackInt, struct.Struct("!h").pack) -_pack_int4 = cast(PackInt, struct.Struct("!i").pack) -_pack_uint4 = cast(PackInt, struct.Struct("!I").pack) -_pack_int8 = cast(PackInt, struct.Struct("!q").pack) # Wrappers to force numbers to be cast as specific PostgreSQL types @@ -142,48 +148,32 @@ class OidBinaryDumper(OidDumper): @Loader.text(builtins["int8"].oid) @Loader.text(builtins["oid"].oid) class IntLoader(Loader): - def load(self, data: bytes, __decode: DecodeFunc = decode_ascii) -> int: - return int(__decode(data)[0]) + def load(self, data: 
bytes) -> int: + return int(data.decode("utf8")) @Loader.binary(builtins["int2"].oid) class Int2BinaryLoader(Loader): - def load( - self, - data: bytes, - __unpack: UnpackInt = cast(UnpackInt, struct.Struct("!h").unpack), - ) -> int: - return __unpack(data)[0] + def load(self, data: bytes) -> int: + return _unpack_int2(data)[0] @Loader.binary(builtins["int4"].oid) class Int4BinaryLoader(Loader): - def load( - self, - data: bytes, - __unpack: UnpackInt = cast(UnpackInt, struct.Struct("!i").unpack), - ) -> int: - return __unpack(data)[0] + def load(self, data: bytes) -> int: + return _unpack_int4(data)[0] @Loader.binary(builtins["int8"].oid) class Int8BinaryLoader(Loader): - def load( - self, - data: bytes, - __unpack: UnpackInt = cast(UnpackInt, struct.Struct("!q").unpack), - ) -> int: - return __unpack(data)[0] + def load(self, data: bytes) -> int: + return _unpack_int8(data)[0] @Loader.binary(builtins["oid"].oid) class OidBinaryLoader(Loader): - def load( - self, - data: bytes, - __unpack: UnpackInt = cast(UnpackInt, struct.Struct("!I").unpack), - ) -> int: - return __unpack(data)[0] + def load(self, data: bytes) -> int: + return _unpack_uint4(data)[0] @Loader.text(builtins["float4"].oid) @@ -196,27 +186,17 @@ class FloatLoader(Loader): @Loader.binary(builtins["float4"].oid) class Float4BinaryLoader(Loader): - def load( - self, - data: bytes, - __unpack: UnpackInt = cast(UnpackInt, struct.Struct("!f").unpack), - ) -> int: - return __unpack(data)[0] + def load(self, data: bytes) -> float: + return _unpack_float4(data)[0] @Loader.binary(builtins["float8"].oid) class Float8BinaryLoader(Loader): - def load( - self, - data: bytes, - __unpack: UnpackInt = cast(UnpackInt, struct.Struct("!d").unpack), - ) -> int: - return __unpack(data)[0] + def load(self, data: bytes) -> float: + return _unpack_float8(data)[0] @Loader.text(builtins["numeric"].oid) class NumericLoader(Loader): - def load( - self, data: bytes, __decode: DecodeFunc = decode_ascii - ) -> Decimal: - return Decimal(__decode(data)[0]) + def load(self, data: bytes) -> Decimal: + return Decimal(data.decode("utf8")) diff --git a/psycopg3/psycopg3/types/text.py b/psycopg3/psycopg3/types/text.py index 6ad6ca45c..f16478073 100644 --- a/psycopg3/psycopg3/types/text.py +++ b/psycopg3/psycopg3/types/text.py @@ -4,14 +4,13 @@ Adapters for textual types. 
# Copyright (C) 2020 The Psycopg Team -from typing import Optional, Union, TYPE_CHECKING +from typing import Union, TYPE_CHECKING from ..pq import Escaping from ..oids import builtins, INVALID_OID from ..adapt import Dumper, Loader from ..proto import AdaptContext from ..errors import DataError -from ..utils.codecs import EncodeFunc, DecodeFunc, encode_utf8, decode_utf8 if TYPE_CHECKING: from ..pq.proto import Escaping as EscapingProto @@ -21,20 +20,19 @@ class _StringDumper(Dumper): def __init__(self, src: type, context: AdaptContext): super().__init__(src, context) - self._encode: EncodeFunc if self.connection: if self.connection.client_encoding != "SQL_ASCII": - self._encode = self.connection.codec.encode + self.encoding = self.connection.pyenc else: - self._encode = encode_utf8 + self.encoding = "utf-8" else: - self._encode = encode_utf8 + self.encoding = "utf-8" @Dumper.binary(str) class StringBinaryDumper(_StringDumper): def dump(self, obj: str) -> bytes: - return self._encode(obj)[0] + return obj.encode(self.encoding) @Dumper.text(str) @@ -45,7 +43,7 @@ class StringDumper(_StringDumper): "PostgreSQL text fields cannot contain NUL (0x00) bytes" ) else: - return self._encode(obj)[0] + return obj.encode(self.encoding) @Loader.text(builtins["text"].oid) @@ -54,23 +52,20 @@ class StringDumper(_StringDumper): @Loader.binary(builtins["varchar"].oid) @Loader.text(INVALID_OID) class TextLoader(Loader): - - decode: Optional[DecodeFunc] - def __init__(self, oid: int, context: AdaptContext): super().__init__(oid, context) if self.connection is not None: if self.connection.client_encoding != "SQL_ASCII": - self.decode = self.connection.codec.decode + self.encoding = self.connection.pyenc else: - self.decode = None + self.encoding = "" else: - self.decode = decode_utf8 + self.encoding = "utf-8" def load(self, data: bytes) -> Union[bytes, str]: - if self.decode is not None: - return self.decode(data)[0] + if self.encoding: + return data.decode(self.encoding) else: # return bytes for SQL_ASCII db return data @@ -83,15 +78,10 @@ class TextLoader(Loader): class UnknownLoader(Loader): def __init__(self, oid: int, context: AdaptContext): super().__init__(oid, context) - - self.decode: DecodeFunc - if self.connection is not None: - self.decode = self.connection.codec.decode - else: - self.decode = decode_utf8 + self.encoding = self.connection.pyenc if self.connection else "utf-8" def load(self, data: bytes) -> str: - return self.decode(data)[0] + return data.decode(self.encoding) @Dumper.text(bytes) diff --git a/psycopg3/psycopg3/types/uuid.py b/psycopg3/psycopg3/types/uuid.py index 6076b12f4..73cb30ef2 100644 --- a/psycopg3/psycopg3/types/uuid.py +++ b/psycopg3/psycopg3/types/uuid.py @@ -10,7 +10,6 @@ from uuid import UUID from ..oids import builtins from ..adapt import Dumper, Loader -from ..utils.codecs import EncodeFunc, DecodeFunc, encode_ascii, decode_ascii @Dumper.text(UUID) @@ -18,23 +17,20 @@ class UUIDDumper(Dumper): oid = builtins["uuid"].oid - def dump(self, obj: UUID, __encode: EncodeFunc = encode_ascii) -> bytes: - return __encode(obj.hex)[0] + def dump(self, obj: UUID) -> bytes: + return obj.hex.encode("utf8") @Dumper.binary(UUID) -class UUIDBinaryDumper(Dumper): - - oid = builtins["uuid"].oid - +class UUIDBinaryDumper(UUIDDumper): def dump(self, obj: UUID) -> bytes: return obj.bytes @Loader.text(builtins["uuid"].oid) class UUIDLoader(Loader): - def load(self, data: bytes, __decode: DecodeFunc = decode_ascii) -> UUID: - return UUID(__decode(data)[0]) + def load(self, data: bytes) -> UUID: 
+ return UUID(data.decode("utf8")) @Loader.binary(builtins["uuid"].oid) diff --git a/psycopg3/psycopg3/utils/queries.py b/psycopg3/psycopg3/utils/queries.py index 01dc771f5..59550b9b5 100644 --- a/psycopg3/psycopg3/utils/queries.py +++ b/psycopg3/psycopg3/utils/queries.py @@ -53,14 +53,13 @@ class PostgresQuery: if isinstance(query, Composable): query = query.as_string(self._tx) - codec = self._tx.codec if vars is not None: self.query, self.formats, self._order, self._parts = _query2pg( - query, codec.name + query, self._tx.encoding ) else: if isinstance(query, str): - query = codec.encode(query)[0] + query = query.encode(self._tx.encoding) self.query = query self.formats = self._order = None diff --git a/psycopg3_c/psycopg3_c/_psycopg3.pyi b/psycopg3_c/psycopg3_c/_psycopg3.pyi index ff14ded20..2dddf25f9 100644 --- a/psycopg3_c/psycopg3_c/_psycopg3.pyi +++ b/psycopg3_c/psycopg3_c/_psycopg3.pyi @@ -7,7 +7,6 @@ information. Will submit a bug. # Copyright (C) 2020 The Psycopg Team -import codecs from typing import Any, Iterable, List, Optional, Sequence, Tuple from psycopg3.adapt import Dumper, Loader @@ -21,7 +20,7 @@ class Transformer: @property def connection(self) -> Optional[BaseConnection]: ... @property - def codec(self) -> codecs.CodecInfo: ... + def encoding(self) -> str: ... @property def dumpers(self) -> DumpersMap: ... @property diff --git a/psycopg3_c/psycopg3_c/transform.pyx b/psycopg3_c/psycopg3_c/transform.pyx index e79a32dba..a8cde1832 100644 --- a/psycopg3_c/psycopg3_c/transform.pyx +++ b/psycopg3_c/psycopg3_c/transform.pyx @@ -11,7 +11,6 @@ too many temporary Python objects and performing less memory copying. from cpython.ref cimport Py_INCREF from cpython.tuple cimport PyTuple_New, PyTuple_SET_ITEM -import codecs from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple from psycopg3_c cimport libpq @@ -39,9 +38,10 @@ cdef class Transformer: cdef list _dumpers_maps, _loaders_maps cdef dict _dumpers, _loaders, _dumpers_cache, _loaders_cache, _load_funcs - cdef object _connection, _codec + cdef object _connection cdef PGresult _pgresult cdef int _nfields, _ntuples + cdef str _encoding cdef list _row_loaders @@ -70,7 +70,7 @@ cdef class Transformer: cdef Transformer ctx if context is None: self._connection = None - self._codec = codecs.lookup("utf8") + self._encoding = "utf-8" self._dumpers = {} self._loaders = {} self._dumpers_maps = [self._dumpers] @@ -81,7 +81,7 @@ cdef class Transformer: # for nested types: share the entire state of the parent ctx = context self._connection = ctx._connection - self._codec = ctx._codec + self._encoding = ctx.encoding self._dumpers = ctx._dumpers self._loaders = ctx._loaders self._dumpers_maps.extend(ctx._dumpers_maps) @@ -91,7 +91,7 @@ cdef class Transformer: elif isinstance(context, BaseCursor): self._connection = context.connection - self._codec = context.connection.codec + self._encoding = context.connection.pyenc self._dumpers = {} self._dumpers_maps.extend( (self._dumpers, context.dumpers, self.connection.dumpers) @@ -103,7 +103,7 @@ cdef class Transformer: elif isinstance(context, BaseConnection): self._connection = context - self._codec = context.codec + self._encoding = context.pyenc self._dumpers = {} self._dumpers_maps.extend((self._dumpers, context.dumpers)) self._loaders = {} @@ -117,8 +117,8 @@ cdef class Transformer: return self._connection @property - def codec(self): - return self._codec + def encoding(self): + return self._encoding @property def dumpers(self): diff --git a/psycopg3_c/psycopg3_c/types/text.pyx 
b/psycopg3_c/psycopg3_c/types/text.pyx index a7f6f6f76..547fe8c64 100644 --- a/psycopg3_c/psycopg3_c/types/text.pyx +++ b/psycopg3_c/psycopg3_c/types/text.pyx @@ -4,38 +4,39 @@ Cython adapters for textual types. # Copyright (C) 2020 The Psycopg Team -from cpython.bytes cimport PyBytes_FromStringAndSize -from cpython.unicode cimport PyUnicode_DecodeUTF8 +from cpython.unicode cimport PyUnicode_Decode, PyUnicode_DecodeUTF8 from psycopg3_c cimport libpq cdef class TextLoader(CLoader): cdef int is_utf8 - cdef object pydecoder + cdef char *encoding + cdef bytes _bytes_encoding # needed to keep `encoding` alive def __init__(self, oid: int, context: "AdaptContext" = None): super().__init__(oid, context) self.is_utf8 = 0 - self.pydecoder = None + self.encoding = NULL + conn = self.connection if conn is not None: if conn.client_encoding == "UTF8": self.is_utf8 = 1 elif conn.client_encoding != "SQL_ASCII": - self.pydecoder = conn.codec.decode + self._bytes_encoding = conn.pyenc.encode("utf-8") + self.encoding = self._bytes_encoding else: - self.pydecoder = codecs.lookup("utf8").decode + self.encoding = "utf-8" cdef object cload(self, const char *data, size_t length): if self.is_utf8: return PyUnicode_DecodeUTF8(data, length, NULL) - b = PyBytes_FromStringAndSize(data, length) - if self.pydecoder is not None: - return self.pydecoder(b)[0] + if self.encoding: + return PyUnicode_Decode(data, length, self.encoding, NULL) else: - return b + return data[:length] cdef class ByteaLoader(CLoader): diff --git a/tests/test_connection.py b/tests/test_connection.py index 2d741a020..777a65ec6 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -190,7 +190,7 @@ def test_set_encoding(conn): def test_normalize_encoding(conn, enc, out, codec): conn.client_encoding = enc assert conn.client_encoding == out - assert conn.codec.name == codec + assert conn.pyenc == codec @pytest.mark.parametrize( @@ -207,7 +207,7 @@ def test_encoding_env_var(dsn, monkeypatch, enc, out, codec): monkeypatch.setenv("PGCLIENTENCODING", enc) conn = psycopg3.connect(dsn) assert conn.client_encoding == out - assert conn.codec.name == codec + assert conn.pyenc == codec def test_set_encoding_unsupported(conn): diff --git a/tests/test_connection_async.py b/tests/test_connection_async.py index daa8aad03..64109e5f9 100644 --- a/tests/test_connection_async.py +++ b/tests/test_connection_async.py @@ -205,7 +205,7 @@ async def test_set_encoding(aconn): async def test_normalize_encoding(aconn, enc, out, codec): await aconn.set_client_encoding(enc) assert aconn.client_encoding == out - assert aconn.codec.name == codec + assert aconn.pyenc == codec @pytest.mark.parametrize( @@ -222,7 +222,7 @@ async def test_encoding_env_var(dsn, monkeypatch, enc, out, codec): monkeypatch.setenv("PGCLIENTENCODING", enc) aconn = await psycopg3.AsyncConnection.connect(dsn) assert aconn.client_encoding == out - assert aconn.codec.name == codec + assert aconn.pyenc == codec async def test_set_encoding_unsupported(aconn): diff --git a/tests/test_cursor.py b/tests/test_cursor.py index 981515858..23506a8ca 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -219,7 +219,7 @@ def make_testfunc(conn): ) .format(sql.Identifier(procname), sql.Identifier(paramname)) .as_string(conn) - .encode(conn.codec.name) + .encode(conn.pyenc) ) # execute regardless of sync/async conn
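
A note on the dropped pattern: the micro-optimization removed here is the old CPython trick of pre-binding a function as a default argument (in this codebase, the codec's encode/decode functions, which return an (object, length) pair and therefore needed the [0] indexing that the diff removes), so that each call pays a fast local lookup instead of a global or attribute lookup. Below is a minimal standalone sketch of the kind of timeit comparison the commit message refers to; it assumes nothing from psycopg3, and the function and variable names are illustrative only.

    import codecs
    import timeit

    # What the old helpers pre-bound: a stateless decoder returning (str, length).
    decode_ascii = codecs.lookup("ascii").decode

    def load_bound(data, __decode=decode_ascii):
        # Old style: function pre-bound as a default argument, result indexed with [0].
        return int(__decode(data)[0])

    def load_plain(data):
        # New style: call the bytes.decode method directly.
        return int(data.decode("utf8"))

    data = b"123456"
    for fn in (load_bound, load_plain):
        t = timeit.timeit(lambda: fn(data), number=1_000_000)
        print(f"{fn.__name__}: {t:.3f}s per 1M calls")

Whether the pre-bound variant still wins is exactly what the timing shows; the commit message reports that on current CPython it no longer does, which is also why the call sites in the diff can switch from codec.encode(s)[0] / codec.decode(b)[0] to plain s.encode(encoding) / b.decode(encoding).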