]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Dropped excessive use of is [not] None
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 7 Nov 2020 02:24:09 +0000 (02:24 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 7 Nov 2020 13:19:10 +0000 (13:19 +0000)
It was done with the assumption that it is faster, but it is only in
C/Cython, not in pure Python.

17 files changed:
psycopg3/psycopg3/_transform.py
psycopg3/psycopg3/adapt.py
psycopg3/psycopg3/connection.py
psycopg3/psycopg3/conninfo.py
psycopg3/psycopg3/copy.py
psycopg3/psycopg3/cursor.py
psycopg3/psycopg3/errors.py
psycopg3/psycopg3/generators.py
psycopg3/psycopg3/pq/_pq_ctypes.py
psycopg3/psycopg3/pq/pq_ctypes.py
psycopg3/psycopg3/sql.py
psycopg3/psycopg3/types/array.py
psycopg3/psycopg3/types/composite.py
psycopg3/psycopg3/types/text.py
psycopg3/psycopg3/utils/queries.py
psycopg3/setup.py
tests/test_cursor.py

index b10aa7cf4feb209083f7f3dc6f68be4245cdb3c2..dfe79fbd4a8134f205231b03d939df5bbb01bea7 100644 (file)
@@ -53,7 +53,7 @@ class Transformer:
         self._row_loaders: List[LoadFunc] = []
 
     def _setup_context(self, context: AdaptContext) -> None:
-        if context is None:
+        if not context:
             self._connection = None
             self._encoding = "utf-8"
             self._dumpers = {}
@@ -117,7 +117,7 @@ class Transformer:
 
         self._ntuples: int
         self._nfields: int
-        if result is None:
+        if not result:
             self._nfields = self._ntuples = 0
             return
 
@@ -168,7 +168,7 @@ class Transformer:
 
     def load_row(self, row: int) -> Optional[Tuple[Any, ...]]:
         res = self.pgresult
-        if res is None:
+        if not res:
             return None
 
         if row >= self._ntuples:
index f0b9e14fd1c81bf71483293f978e81390a724812..6b40fee270f22da6a552f23c0f819b4630dc689e 100644 (file)
@@ -55,7 +55,7 @@ class Dumper:
                 f"dumpers should be registered on classes, got {src} instead"
             )
 
-        where = context.dumpers if context is not None else Dumper.globals
+        where = context.dumpers if context else Dumper.globals
         where[src, format] = cls
 
     @classmethod
@@ -103,7 +103,7 @@ class Loader:
                 f"loaders should be registered on oid, got {oid} instead"
             )
 
-        where = context.loaders if context is not None else Loader.globals
+        where = context.loaders if context else Loader.globals
         where[oid, format] = cls
 
     @classmethod
@@ -130,7 +130,7 @@ class Loader:
 def _connection_from_context(
     context: AdaptContext,
 ) -> Optional[BaseConnection]:
-    if context is None:
+    if not context:
         return None
     elif isinstance(context, BaseConnection):
         return context
index a720528ababedb31914527fa893a74a8690e7300..d63cb393bd96e3d50a441c14d5a6925645c5d34b 100644 (file)
@@ -120,9 +120,9 @@ class BaseConnection:
         self._autocommit = value
 
     def cursor(
-        self, name: Optional[str] = None, format: pq.Format = pq.Format.TEXT
+        self, name: str = "", format: pq.Format = pq.Format.TEXT
     ) -> cursor.BaseCursor:
-        if name is not None:
+        if name:
             raise NotImplementedError
         return self.cursor_factory(self, format=format)
 
@@ -144,10 +144,7 @@ class BaseConnection:
     @property
     def client_encoding(self) -> str:
         rv = self.pgconn.parameter_status(b"client_encoding")
-        if rv is not None:
-            return rv.decode("ascii")
-        else:
-            return "UTF8"
+        return rv.decode("utf-8") if rv else "UTF8"
 
     @client_encoding.setter
     def client_encoding(self, value: str) -> None:
@@ -171,7 +168,7 @@ class BaseConnection:
         wself: "ReferenceType[BaseConnection]", res: pq.proto.PGresult
     ) -> None:
         self = wself()
-        if self is None or not self._notice_handler:
+        if not (self and self._notice_handler):
             return
 
         diag = e.Diagnostic(res, self._pyenc)
@@ -194,7 +191,7 @@ class BaseConnection:
         wself: "ReferenceType[BaseConnection]", pgn: pq.PGnotify
     ) -> None:
         self = wself()
-        if self is None or not self._notify_handlers:
+        if not (self and self._notify_handlers):
             return
 
         n = Notify(
@@ -241,7 +238,7 @@ class Connection(BaseConnection):
         self.pgconn.finish()
 
     def cursor(
-        self, name: Optional[str] = None, format: pq.Format = pq.Format.TEXT
+        self, name: str = "", format: pq.Format = pq.Format.TEXT
     ) -> cursor.Cursor:
         cur = super().cursor(name, format=format)
         return cast(cursor.Cursor, cur)
@@ -356,7 +353,7 @@ class AsyncConnection(BaseConnection):
         self.pgconn.finish()
 
     def cursor(
-        self, name: Optional[str] = None, format: pq.Format = pq.Format.TEXT
+        self, name: str = "", format: pq.Format = pq.Format.TEXT
     ) -> cursor.AsyncCursor:
         cur = super().cursor(name, format=format)
         return cast(cursor.AsyncCursor, cur)
index f5f07f3616e453f370dd5e5faad34eafa91814c3..a5c431e58fa9d5bb30e2cba60dcc7fdd3e1c78d4 100644 (file)
@@ -24,7 +24,7 @@ def make_conninfo(conninfo: str = "", **kwargs: Any) -> str:
     # Drop the None arguments
     kwargs = {k: v for (k, v) in kwargs.items() if v is not None}
 
-    if conninfo is not None:
+    if conninfo:
         tmp = conninfo_to_dict(conninfo)
         tmp.update(kwargs)
         kwargs = tmp
index 6cd89e533d250d31d50e81918483103decd1e948..0c105850bd8baef4b584ae4a08f1702677261aed 100644 (file)
@@ -46,11 +46,11 @@ class BaseCopy:
 
     @property
     def connection(self) -> "BaseConnection":
-        if self._connection is not None:
+        if self._connection:
             return self._connection
 
         self._connection = conn = self._transformer.connection
-        if conn is not None:
+        if conn:
             return conn
 
         raise ValueError("no connection available")
index c062d2ae77aa01e97e07bc1bb2b1a4c7781462cd..fda21b1f82206d752aa37bd70f12614c1aba7b4f 100644 (file)
@@ -78,7 +78,7 @@ class Column(Sequence[Any]):
     @property
     def name(self) -> str:
         rv = self._pgresult.fname(self._index)
-        if rv is not None:
+        if rv:
             return rv.decode(self._encoding)
         else:
             raise e.InterfaceError(
@@ -170,10 +170,7 @@ class BaseCursor:
     @property
     def status(self) -> Optional[pq.ExecStatus]:
         res = self.pgresult
-        if res is not None:
-            return res.status
-        else:
-            return None
+        return res.status if res else None
 
     @property
     def pgresult(self) -> Optional[pq.proto.PGresult]:
@@ -182,14 +179,13 @@ class BaseCursor:
     @pgresult.setter
     def pgresult(self, result: Optional[pq.proto.PGresult]) -> None:
         self._pgresult = result
-        if result is not None:
-            if self._transformer is not None:
-                self._transformer.pgresult = result
+        if result and self._transformer:
+            self._transformer.pgresult = result
 
     @property
     def description(self) -> Optional[List[Column]]:
         res = self.pgresult
-        if res is None or res.status != self.ExecStatus.TUPLES_OK:
+        if not res or res.status != self.ExecStatus.TUPLES_OK:
             return None
         encoding = self.connection.pyenc
         return [Column(res, i, encoding) for i in range(res.nfields)]
@@ -319,7 +315,7 @@ class BaseCursor:
 
     def _check_result(self) -> None:
         res = self.pgresult
-        if res is None:
+        if not res:
             raise e.ProgrammingError("no result available")
         elif res.status != self.ExecStatus.TUPLES_OK:
             raise e.ProgrammingError(
@@ -474,9 +470,9 @@ class Cursor(BaseCursor):
             self._pos += 1
         return rv
 
-    def fetchmany(self, size: Optional[int] = None) -> List[Sequence[Any]]:
+    def fetchmany(self, size: int = 0) -> List[Sequence[Any]]:
         self._check_result()
-        if size is None:
+        if not size:
             size = self.arraysize
 
         rv: List[Sequence[Any]] = []
@@ -602,11 +598,9 @@ class AsyncCursor(BaseCursor):
             self._pos += 1
         return rv
 
-    async def fetchmany(
-        self, size: Optional[int] = None
-    ) -> List[Sequence[Any]]:
+    async def fetchmany(self, size: int = 0) -> List[Sequence[Any]]:
         self._check_result()
-        if size is None:
+        if not size:
             size = self.arraysize
 
         pos = self._pos
index edec693d9f0dcd4ccf6156e40da3fe9ba975f2c3..db610f69858c7d02e3334005da779bd1a9568876 100644 (file)
@@ -258,15 +258,11 @@ def _class_for_state(sqlstate: str) -> Type[Error]:
 
 
 def get_base_exception(sqlstate: str) -> Type[Error]:
-    exc = _base_exc_map.get(sqlstate[:2])
-    if exc is not None:
-        return exc
-
-    exc = _base_exc_map.get(sqlstate[0])
-    if exc is not None:
-        return exc
-
-    return DatabaseError
+    return (
+        _base_exc_map.get(sqlstate[:2])
+        or _base_exc_map.get(sqlstate[0])
+        or DatabaseError
+    )
 
 
 _base_exc_map = {
index 7ecb11ad04f54ba1f2681d9193b92641d8a1006d..0cd7abcf5f56caab9d22de75e3757d2495369387 100644 (file)
@@ -121,7 +121,7 @@ def fetch(pgconn: pq.proto.PGconn) -> PQGen[List[pq.proto.PGresult]]:
             n = pgconn.notifies()
             if n is None:
                 break
-            if pgconn.notify_handler is not None:
+            if pgconn.notify_handler:
                 pgconn.notify_handler(n)
 
         res = pgconn.get_result()
@@ -143,7 +143,7 @@ def notifies(pgconn: pq.proto.PGconn) -> PQGen[List[pq.PGnotify]]:
     ns = []
     while 1:
         n = pgconn.notifies()
-        if n is not None:
+        if n:
             ns.append(n)
         else:
             break
index 9aaa39fb893dd7ec7e6bda75955efad54b92ef26..88947ca14dce42366249a2f1ce24c6ee462b1d70 100644 (file)
@@ -13,7 +13,7 @@ from typing import List, Tuple
 from psycopg3.errors import NotSupportedError
 
 libname = ctypes.util.find_library("pq")
-if libname is None:
+if not libname:
     raise ImportError("libpq library not found")
 
 pq = ctypes.pydll.LoadLibrary(libname)
@@ -172,7 +172,7 @@ if libpq_version >= 120000:
 
 
 def PQhostaddr(pgconn: type) -> bytes:
-    if _PQhostaddr is not None:
+    if _PQhostaddr:
         return _PQhostaddr(pgconn)
     else:
         raise NotSupportedError(
index 4147436d9fa1c3b2a5ec97a09904f5d94de821ce..70cdedd756c1e95ed8ad53fd44347668175c7139 100644 (file)
@@ -47,7 +47,7 @@ def notice_receiver(
     arg: Any, result_ptr: impl.PGresult_struct, wconn: "ref[PGconn]"
 ) -> None:
     pgconn = wconn()
-    if pgconn is None or pgconn.notice_handler is None:
+    if not (pgconn and pgconn.notice_handler):
         return
 
     res = PGresult(result_ptr)
@@ -115,7 +115,7 @@ class PGconn:
 
     def finish(self) -> None:
         self.pgconn_ptr, p = None, self.pgconn_ptr
-        if p is not None:
+        if p:
             impl.PQfinish(p)
 
     @property
@@ -330,17 +330,20 @@ class PGconn:
         if not isinstance(command, bytes):
             raise TypeError(f"bytes expected, got {type(command)} instead")
 
-        nparams = len(param_values) if param_values is not None else 0
-        aparams: Optional[Array[c_char_p]] = None
-        alenghts: Optional[Array[c_int]] = None
+        aparams: Optional[Array[c_char_p]]
+        alenghts: Optional[Array[c_int]]
         if param_values:
+            nparams = len(param_values)
             aparams = (c_char_p * nparams)(*param_values)
             alenghts = (c_int * nparams)(
-                *(len(p) if p is not None else 0 for p in param_values)
+                *(len(p) if p else 0 for p in param_values)
             )
+        else:
+            nparams = 0
+            aparams = alenghts = None
 
         atypes: Optional[Array[impl.Oid]]
-        if param_types is None:
+        if not param_types:
             atypes = None
         else:
             if len(param_types) != nparams:
@@ -350,7 +353,7 @@ class PGconn:
                 )
             atypes = (impl.Oid * nparams)(*param_types)
 
-        if param_formats is None:
+        if not param_formats:
             aformats = None
         else:
             if len(param_formats) != nparams:
@@ -385,7 +388,7 @@ class PGconn:
                 f"'command' must be bytes, got {type(command)} instead"
             )
 
-        if param_types is None:
+        if not param_types:
             nparams = 0
             atypes = None
         else:
@@ -408,16 +411,19 @@ class PGconn:
         if not isinstance(name, bytes):
             raise TypeError(f"'name' must be bytes, got {type(name)} instead")
 
-        nparams = len(param_values) if param_values is not None else 0
-        aparams: Optional[Array[c_char_p]] = None
-        alenghts: Optional[Array[c_int]] = None
+        aparams: Optional[Array[c_char_p]]
+        alenghts: Optional[Array[c_int]]
         if param_values:
+            nparams = len(param_values)
             aparams = (c_char_p * nparams)(*param_values)
             alenghts = (c_int * nparams)(
-                *(len(p) if p is not None else 0 for p in param_values)
+                *(len(p) if p else 0 for p in param_values)
             )
+        else:
+            nparams = 0
+            aparams = alenghts = None
 
-        if param_formats is None:
+        if not param_formats:
             aformats = None
         else:
             if len(param_formats) != nparams:
@@ -575,7 +581,7 @@ class PGresult:
 
     def clear(self) -> None:
         self.pgresult_ptr, p = None, self.pgresult_ptr
-        if p is not None:
+        if p:
             impl.PQclear(p)
 
     @property
@@ -680,7 +686,7 @@ class PGcancel:
 
     def free(self) -> None:
         self.pgcancel_ptr, p = None, self.pgcancel_ptr
-        if p is not None:
+        if p:
             impl.PQfreeCancel(p)
 
     def cancel(self) -> None:
@@ -746,7 +752,7 @@ class Escaping:
         self.conn = conn
 
     def escape_literal(self, data: bytes) -> bytes:
-        if self.conn is not None:
+        if self.conn:
             self.conn._ensure_pgconn()
             out = impl.PQescapeLiteral(self.conn.pgconn_ptr, data, len(data))
             if not out:
@@ -761,7 +767,7 @@ class Escaping:
             raise PQerror("escape_literal failed: no connection provided")
 
     def escape_identifier(self, data: bytes) -> bytes:
-        if self.conn is not None:
+        if self.conn:
             self.conn._ensure_pgconn()
             out = impl.PQescapeIdentifier(
                 self.conn.pgconn_ptr, data, len(data)
@@ -778,7 +784,7 @@ class Escaping:
             raise PQerror("escape_identifier failed: no connection provided")
 
     def escape_string(self, data: bytes) -> bytes:
-        if self.conn is not None:
+        if self.conn:
             self.conn._ensure_pgconn()
             error = c_int()
             out = create_string_buffer(len(data) * 2 + 1)
@@ -807,7 +813,7 @@ class Escaping:
 
     def escape_bytea(self, data: bytes) -> bytes:
         len_out = c_size_t()
-        if self.conn is not None:
+        if self.conn:
             self.conn._ensure_pgconn()
             out = impl.PQescapeByteaConn(
                 self.conn.pgconn_ptr,
@@ -831,7 +837,7 @@ class Escaping:
     def unescape_bytea(self, data: bytes) -> bytes:
         # not needed, but let's keep it symmetric with the escaping:
         # if a connection is passed in, it must be valid.
-        if self.conn is not None:
+        if self.conn:
             self.conn._ensure_pgconn()
 
         len_out = c_size_t()
index 60681877c90030d7a8475147724a3312acb77f78..3b57d68cf10bec05b9e19006b521b211fb062823 100644 (file)
@@ -404,16 +404,13 @@ class Placeholder(Composable):
 
     """
 
-    def __init__(
-        self, name: Optional[str] = None, format: Format = Format.TEXT
-    ):
+    def __init__(self, name: str = "", format: Format = Format.TEXT):
         super().__init__(name)
-        if isinstance(name, str):
-            if ")" in name:
-                raise ValueError("invalid name: %r" % name)
+        if not isinstance(name, str):
+            raise TypeError(f"expected string as name, got {name!r}")
 
-        elif name is not None:
-            raise TypeError("expected string or None as name, got %r" % name)
+        if ")" in name:
+            raise ValueError(f"invalid name: {name!r}")
 
         self._format = format
 
@@ -428,10 +425,7 @@ class Placeholder(Composable):
 
     def as_string(self, context: AdaptContext) -> str:
         code = "s" if self._format == Format.TEXT else "b"
-        if self._obj is not None:
-            return f"%({self._obj}){code}"
-        else:
-            return f"%{code}"
+        return f"%({self._obj}){code}" if self._obj else f"%{code}"
 
 
 # Literals
index 955944eede5c884f66ad865f448608bbff830498..03fc21bdc0ea0f8c182afa94891fe09de9d18631 100644 (file)
@@ -38,7 +38,7 @@ class BaseListDumper(Dumper):
         oid = 0
         if base_oid:
             info = builtins.get(base_oid)
-            if info is not None:
+            if info:
                 oid = info.array_oid
 
         return oid or TEXT_ARRAY_OID
@@ -66,7 +66,7 @@ class ListDumper(BaseListDumper):
 
     def dump(self, obj: List[Any]) -> bytes:
         tokens: List[bytes] = []
-        oid: Optional[int] = None
+        oid = 0
 
         def dump_list(obj: List[Any]) -> None:
             nonlocal oid
@@ -82,10 +82,10 @@ class ListDumper(BaseListDumper):
                 elif item is not None:
                     dumper = self._tx.get_dumper(item, Format.TEXT)
                     ad = dumper.dump(item)
-                    if self._re_needs_quotes.search(ad) is not None:
+                    if self._re_needs_quotes.search(ad):
                         ad = b'"' + self._re_escape.sub(br"\\\1", ad) + b'"'
                     tokens.append(ad)
-                    if oid is None:
+                    if not oid:
                         oid = dumper.oid
                 else:
                     tokens.append(b"NULL")
@@ -96,7 +96,7 @@ class ListDumper(BaseListDumper):
 
         dump_list(obj)
 
-        if oid is not None:
+        if oid:
             self._array_oid = self._get_array_oid(oid)
 
         return b"".join(tokens)
index 11df459fc10ad915ac8d4314d3531f31e31d433c..c568438dce0b088f4b6ebe66dad0349d32059c80 100644 (file)
@@ -65,7 +65,7 @@ def register(
     context: AdaptContext = None,
     factory: Optional[Callable[..., Any]] = None,
 ) -> None:
-    if factory is None:
+    if not factory:
         factory = namedtuple(  # type: ignore
             info.name, [f.name for f in info.fields]
         )
@@ -142,7 +142,7 @@ class TupleDumper(Dumper):
 
             dumper = self._tx.get_dumper(item, Format.TEXT)
             ad = dumper.dump(item)
-            if self._re_needs_quotes.search(ad) is not None:
+            if self._re_needs_quotes.search(ad):
                 ad = b'"' + self._re_escape.sub(br"\1\1", ad) + b'"'
 
             parts.append(ad)
@@ -181,7 +181,7 @@ class RecordLoader(BaseCompositeLoader):
             return
 
         for m in self._re_tokenize.finditer(data):
-            if m.group(1) is not None:
+            if m.group(1):
                 yield None
             elif m.group(2) is not None:
                 yield self._re_undouble.sub(br"\1", m.group(2))
index f164780731a374d34137aba2cf60dc291bc451e7..1cac29ffc397f7cd9c9f53046c61c4a49c34cdd3 100644 (file)
@@ -55,7 +55,7 @@ class TextLoader(Loader):
     def __init__(self, oid: int, context: AdaptContext):
         super().__init__(oid, context)
 
-        if self.connection is not None:
+        if self.connection:
             if self.connection.client_encoding != "SQL_ASCII":
                 self.encoding = self.connection.pyenc
             else:
@@ -92,7 +92,7 @@ class BytesDumper(Dumper):
     def __init__(self, src: type, context: AdaptContext = None):
         super().__init__(src, context)
         self.esc = Escaping(
-            self.connection.pgconn if self.connection is not None else None
+            self.connection.pgconn if self.connection else None
         )
 
     def dump(self, obj: bytes) -> bytes:
index 59550b9b5629034a175284b4cf24131dc296b425..a20273dbc9e79e3b26bae4bcb3cd59d71222b74e 100644 (file)
@@ -80,8 +80,7 @@ class PostgresQuery:
 
             if self.types is None:
                 self.types = []
-                for i in range(len(params)):
-                    param = params[i]
+                for i, param in enumerate(params):
                     if param is not None:
                         dumper = self._tx.get_dumper(param, self.formats[i])
                         self.params.append(dumper.dump(param))
@@ -90,8 +89,7 @@ class PostgresQuery:
                         self.params.append(None)
                         self.types.append(UNKNOWN_OID)
             else:
-                for i in range(len(params)):
-                    param = params[i]
+                for i, param in enumerate(params):
                     if param is not None:
                         dumper = self._tx.get_dumper(param, self.formats[i])
                         self.params.append(dumper.dump(param))
@@ -225,10 +223,10 @@ def _split_query(query: bytes, encoding: str = "ascii") -> List[QueryPart]:
         pre = query[cur : m.span(0)[0]]
         parts.append((pre, m))
         cur = m.span(0)[1]
-    if m is None:
-        parts.append((query, None))
-    else:
+    if m:
         parts.append((query[cur:], None))
+    else:
+        parts.append((query, None))
 
     rv = []
 
@@ -269,15 +267,14 @@ def _split_query(query: bytes, encoding: str = "ascii") -> List[QueryPart]:
 
         # Index or name
         item: Union[int, str]
-        item = i if m.group(1) is None else m.group(1).decode(encoding)
+        item = m.group(1).decode(encoding) if m.group(1) else i
 
-        if phtype is None:
+        if not phtype:
             phtype = type(item)
-        else:
-            if phtype is not type(item):  # noqa
-                raise e.ProgrammingError(
-                    "positional and named placeholders cannot be mixed"
-                )
+        elif phtype is not type(item):
+            raise e.ProgrammingError(
+                "positional and named placeholders cannot be mixed"
+            )
 
         # Binary format
         format = Format(ph[-1:] == b"b")
index 80bb67feb08fb900d315f1dbd95e3679fbb317d7..b7b7ef26caa12519e0a5e4fdaf1c0dfb1423b129 100644 (file)
@@ -19,7 +19,7 @@ if os.path.abspath(os.getcwd()) != here:
 with open("psycopg3/version.py") as f:
     data = f.read()
     m = re.search(r"""(?m)^__version__\s*=\s*['"]([^'"]+)['"]""", data)
-    if m is None:
+    if not m:
         raise Exception(f"cannot find version in {f.name}")
     version = m.group(1)
 
index 23506a8cac6ca6edca9f43cabfc4c646c3eb83ec..c5b8fd9f6541d9b88f12daaf5ba750a8d0a5d6d6 100644 (file)
@@ -263,6 +263,10 @@ def test_callproc_dict_bad(conn, args, exc):
 
 def test_rowcount(conn):
     cur = conn.cursor()
+
+    cur.execute("select 1 from generate_series(1, 0)")
+    assert cur.rowcount == 0
+
     cur.execute("select 1 from generate_series(1, 42)")
     assert cur.rowcount == 42