From: Daniele Varrazzo Date: Fri, 27 Mar 2020 12:26:57 +0000 (+1300) Subject: Added basic typecasting and cursor.fetchone() X-Git-Tag: 3.0.dev0~665 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=4b5751d7328d151a177547a3effb707abcec6fb4;p=thirdparty%2Fpsycopg.git Added basic typecasting and cursor.fetchone() --- diff --git a/psycopg3/adaptation.py b/psycopg3/adaptation.py index afc33c7d8..978bb98ba 100644 --- a/psycopg3/adaptation.py +++ b/psycopg3/adaptation.py @@ -5,6 +5,7 @@ Entry point into the adaptation system. # Copyright (C) 2020 The Psycopg Team import codecs +from functools import partial from . import exceptions as exc from .pq import Format @@ -15,10 +16,14 @@ NUMERIC_OID = 1700 FLOAT8_INT = 701 ascii_encode = codecs.lookup("ascii").encode +ascii_decode = codecs.lookup("ascii").decode utf8_codec = codecs.lookup("utf-8") +global_adapters = {} +global_casters = {} + -class ValuesAdapter: +class ValuesTransformer: """ An object that can adapt efficiently a number of value. @@ -43,11 +48,38 @@ class ValuesAdapter: else: raise TypeError( f"the context should be a connection or cursor," - f" got {type(context).__name__}") + f" got {type(context).__name__}" + ) - # mapping class -> adaptation function + # mapping class, fmt -> adaptation function self._adapt_funcs = {} + # mapping oid, fmt -> cast function + self._cast_funcs = {} + + # The result to return values from + self._result = None + + # sequence of cast function from value to python + # the length of the result columns + self._row_casters = None + + @property + def result(self): + return self._result + + @result.setter + def result(self, result): + if self._result is result: + return + + rc = self._row_casters = [] + for c in range(result.nfields): + oid = result.ftype(c) + fmt = result.fformat(c) + func = self.get_cast_function(oid, fmt) + rc.append(func) + def adapt_sequence(self, objs, fmts): out = [] types = [] @@ -64,39 +96,35 @@ class ValuesAdapter: return None, TEXT_OID cls = type(obj) + func = self.get_adapt_function(cls, fmt) + return func(obj) + + def get_adapt_function(self, cls, fmt): try: - func = self._adapt_funcs[cls, fmt] + return self._adapt_funcs[cls, fmt] except KeyError: pass - else: - return func(obj) - adapter = self.lookup_adapter(cls) + xf = self.lookup_adapter(cls) if fmt == Format.TEXT: - func = self._adapt_funcs[cls, fmt] = adapter.get_text_adapter( + func = self._adapt_funcs[cls, fmt] = xf.get_text_adapter( cls, self.connection ) else: assert fmt == Format.BINARY - func = self._adapt_funcs[cls, fmt] = adapter.get_binary_adapter( + func = self._adapt_funcs[cls, fmt] = xf.get_binary_adapter( cls, self.connection ) - return func(obj) + return func def lookup_adapter(self, cls): cur = self.cursor - if ( - cur is not None - and cls in cur.adapters - ): + if cur is not None and cls in cur.adapters: return cur.adapters[cls] conn = self.connection - if ( - conn is not None - and cls in conn.adapters - ): + if conn is not None and cls in conn.adapters: return conn.adapters[cls] if cls in global_adapters: @@ -104,11 +132,50 @@ class ValuesAdapter: raise exc.ProgrammingError(f"cannot adapt type {cls.__name__}") + def cast_row(self, result, n): + self.result = result -global_adapters = {} + for col, func in enumerate(self._row_casters): + v = result.get_value(n, col) + if v is not None: + v = func(v) + yield v + + def get_cast_function(self, oid, fmt): + try: + return self._cast_funcs[oid, fmt] + except KeyError: + pass + + xf = self.lookup_caster(oid) + if fmt == Format.TEXT: + func = self._cast_funcs[oid, fmt] = xf.get_text_caster( + oid, self.connection + ) + else: + assert fmt == Format.BINARY + func = self._cast_funcs[oid, fmt] = xf.get_binary_caster( + oid, self.connection + ) + + return func + def lookup_caster(self, oid): + cur = self.cursor + if cur is not None and oid in cur.casters: + return cur.casters[oid] + + conn = self.connection + if conn is not None and oid in conn.casters: + return conn.casters[oid] -class Adapter: + if oid in global_casters: + return global_casters[oid] + else: + return UnknownCaster() + + +class Transformer: def get_text_adapter(self, cls, conn): raise exc.NotSupportedError( f"the type {cls.__name__} doesn't support text adaptation" @@ -119,29 +186,85 @@ class Adapter: f"the type {cls.__name__} doesn't support binary adaptation" ) + def get_text_caster(self, oid, conn): + raise exc.NotSupportedError( + f"the PostgreSQL type {oid} doesn't support cast from text" + ) + + def get_binary_caster(self, oid, conn): + raise exc.NotSupportedError( + f"the PostgreSQL type {oid} doesn't support cast from binary" + ) + + @staticmethod + def cast_to_bytes(value): + return value + + @staticmethod + def cast_to_str(codec, value): + return codec.decode(value)[0] -class StringAdapter(Adapter): + +class StringTransformer(Transformer): def get_text_adapter(self, cls, conn): codec = conn.codec if conn is not None else utf8_codec + return partial(self.adapt_str, codec) - def adapt_text(value): - return codec.encode(value)[0], TEXT_OID + def get_text_caster(self, oid, conn): + if conn is None or conn.pgenc == b"SQL_ASCII": + # we don't have enough info to decode bytes + return self.unparsed_bytes - return adapt_text + codec = conn.codec + return partial(self.cast_to_str, codec) # format is the same in binary and text get_binary_adapter = get_text_adapter + get_binary_caster = get_text_caster + + @staticmethod + def adapt_str(codec, value): + return codec.encode(value)[0], TEXT_OID -global_adapters[str] = StringAdapter() +global_adapters[str] = global_casters[TEXT_OID] = StringTransformer() -class IntAdapter(Adapter): +class IntTransformer(Transformer): def get_text_adapter(self, cls, conn): return self.adapt_int - def adapt_int(self, value): + def get_text_caster(self, oid, conn): + return self.cast_int + + @staticmethod + def adapt_int(value): return ascii_encode(str(value))[0], NUMERIC_OID + @staticmethod + def cast_int(value): + return int(ascii_decode(value)[0]) + + +global_casters[NUMERIC_OID] = global_adapters[int] = IntTransformer() + + +class UnknownCaster(Transformer): + """ + Fallback object to convert unknown types to Python + """ + + def get_text_caster(self, oid, conn): + if conn is None: + # we don't have enough info to decode bytes + return self.cast_to_bytes + + codec = conn.codec + return partial(self.cast_to_str, codec) + + def get_binary_caster(self, oid, conn): + return self.cast_to_bytes -global_adapters[int] = IntAdapter() + @staticmethod + def cast_to_str(codec, value): + return codec.decode(value)[0] diff --git a/psycopg3/connection.py b/psycopg3/connection.py index 874e2e448..eff6503e4 100644 --- a/psycopg3/connection.py +++ b/psycopg3/connection.py @@ -30,8 +30,9 @@ class BaseConnection: self.pgconn = pgconn self.cursor_factory = None self.adapters = {} + self.casters = {} # name of the postgres encoding (in bytes) - self._pgenc = None + self.pgenc = None def cursor(self, name=None): return self.cursor_factory(self) @@ -40,10 +41,11 @@ class BaseConnection: def codec(self): # TODO: utf8 fastpath? pgenc = self.pgconn.parameter_status(b"client_encoding") - if self._pgenc != pgenc: + if self.pgenc != pgenc: # for unknown encodings and SQL_ASCII be strict and use ascii pyenc = pq.py_codecs.get(pgenc.decode("ascii"), "ascii") self._codec = codecs.lookup(pyenc) + self.pgenc = pgenc return self._codec def encode(self, s): diff --git a/psycopg3/cursor.py b/psycopg3/cursor.py index 3bbb8a96c..030948d74 100644 --- a/psycopg3/cursor.py +++ b/psycopg3/cursor.py @@ -6,7 +6,7 @@ psycopg3 cursor objects from . import exceptions as exc from .pq import error_message, DiagnosticField, ExecStatus -from .adaptation import ValuesAdapter +from .adaptation import ValuesTransformer from .utils.queries import query2pg, reorder_params @@ -15,15 +15,20 @@ class BaseCursor: self.conn = conn self.binary = binary self.adapters = {} + self.casters = {} + self._reset() + + def _reset(self): self._results = [] self._result = None + self._pos = 0 self._iresult = 0 + self._transformer = ValuesTransformer(self) def _execute_send(self, query, vars): # Implement part of execute() before waiting common to sync and async - self._results = [] - self._result = None - self._iresult = 0 + self._reset() + codec = self.conn.codec if isinstance(query, str): @@ -35,8 +40,7 @@ class BaseCursor: if vars: if order is not None: vars = reorder_params(vars, order) - adapter = ValuesAdapter(self) - params, types = adapter.adapt_sequence(vars, formats) + params, types = self._transformer.adapt_sequence(vars, formats) self.conn.pgconn.send_query_params( query, params, param_formats=formats, param_types=types ) @@ -84,8 +88,23 @@ class BaseCursor: self._iresult += 1 if self._iresult < len(self._results): self._result = self._results[self._iresult] + self._pos = 0 return True + def fetchone(self): + rv = self._cast_row(self._pos) + if rv is not None: + self._pos += 1 + return rv + + def _cast_row(self, n): + if self._result is None: + return None + if n >= self._result.ntuples: + return None + + return tuple(self._transformer.cast_row(self._result, n)) + class Cursor(BaseCursor): def execute(self, query, vars=None): diff --git a/tests/test_cursor.py b/tests/test_cursor.py index 52fc3f817..f99c6868e 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -18,3 +18,14 @@ def test_execute_sequence(conn): assert cur._result.get_value(0, 1) == b"foo" assert cur._result.get_value(0, 2) is None assert cur.nextset() is None + + +def test_fetchone(conn): + cur = conn.cursor() + cur.execute("select %s, %s, %s", [1, "foo", None]) + row = cur.fetchone() + assert row[0] == 1 + assert row[1] == "foo" + assert row[2] is None + row = cur.fetchone() + assert row is None