self._row_loaders: List[LoadFunc] = []
def _setup_context(self, context: AdaptContext) -> None:
- if context is None:
+ if not context:
self._connection = None
self._encoding = "utf-8"
self._dumpers = {}
self._ntuples: int
self._nfields: int
- if result is None:
+ if not result:
self._nfields = self._ntuples = 0
return
def load_row(self, row: int) -> Optional[Tuple[Any, ...]]:
res = self.pgresult
- if res is None:
+ if not res:
return None
if row >= self._ntuples:
f"dumpers should be registered on classes, got {src} instead"
)
- where = context.dumpers if context is not None else Dumper.globals
+ where = context.dumpers if context else Dumper.globals
where[src, format] = cls
@classmethod
f"loaders should be registered on oid, got {oid} instead"
)
- where = context.loaders if context is not None else Loader.globals
+ where = context.loaders if context else Loader.globals
where[oid, format] = cls
@classmethod
def _connection_from_context(
context: AdaptContext,
) -> Optional[BaseConnection]:
- if context is None:
+ if not context:
return None
elif isinstance(context, BaseConnection):
return context
self._autocommit = value
def cursor(
- self, name: Optional[str] = None, format: pq.Format = pq.Format.TEXT
+ self, name: str = "", format: pq.Format = pq.Format.TEXT
) -> cursor.BaseCursor:
- if name is not None:
+ if name:
raise NotImplementedError
return self.cursor_factory(self, format=format)
@property
def client_encoding(self) -> str:
rv = self.pgconn.parameter_status(b"client_encoding")
- if rv is not None:
- return rv.decode("ascii")
- else:
- return "UTF8"
+ return rv.decode("utf-8") if rv else "UTF8"
@client_encoding.setter
def client_encoding(self, value: str) -> None:
wself: "ReferenceType[BaseConnection]", res: pq.proto.PGresult
) -> None:
self = wself()
- if self is None or not self._notice_handler:
+ if not (self and self._notice_handler):
return
diag = e.Diagnostic(res, self._pyenc)
wself: "ReferenceType[BaseConnection]", pgn: pq.PGnotify
) -> None:
self = wself()
- if self is None or not self._notify_handlers:
+ if not (self and self._notify_handlers):
return
n = Notify(
self.pgconn.finish()
def cursor(
- self, name: Optional[str] = None, format: pq.Format = pq.Format.TEXT
+ self, name: str = "", format: pq.Format = pq.Format.TEXT
) -> cursor.Cursor:
cur = super().cursor(name, format=format)
return cast(cursor.Cursor, cur)
self.pgconn.finish()
def cursor(
- self, name: Optional[str] = None, format: pq.Format = pq.Format.TEXT
+ self, name: str = "", format: pq.Format = pq.Format.TEXT
) -> cursor.AsyncCursor:
cur = super().cursor(name, format=format)
return cast(cursor.AsyncCursor, cur)
# Drop the None arguments
kwargs = {k: v for (k, v) in kwargs.items() if v is not None}
- if conninfo is not None:
+ if conninfo:
tmp = conninfo_to_dict(conninfo)
tmp.update(kwargs)
kwargs = tmp
@property
def connection(self) -> "BaseConnection":
- if self._connection is not None:
+ if self._connection:
return self._connection
self._connection = conn = self._transformer.connection
- if conn is not None:
+ if conn:
return conn
raise ValueError("no connection available")
@property
def name(self) -> str:
rv = self._pgresult.fname(self._index)
- if rv is not None:
+ if rv:
return rv.decode(self._encoding)
else:
raise e.InterfaceError(
@property
def status(self) -> Optional[pq.ExecStatus]:
res = self.pgresult
- if res is not None:
- return res.status
- else:
- return None
+ return res.status if res else None
@property
def pgresult(self) -> Optional[pq.proto.PGresult]:
@pgresult.setter
def pgresult(self, result: Optional[pq.proto.PGresult]) -> None:
self._pgresult = result
- if result is not None:
- if self._transformer is not None:
- self._transformer.pgresult = result
+ if result and self._transformer:
+ self._transformer.pgresult = result
@property
def description(self) -> Optional[List[Column]]:
res = self.pgresult
- if res is None or res.status != self.ExecStatus.TUPLES_OK:
+ if not res or res.status != self.ExecStatus.TUPLES_OK:
return None
encoding = self.connection.pyenc
return [Column(res, i, encoding) for i in range(res.nfields)]
def _check_result(self) -> None:
res = self.pgresult
- if res is None:
+ if not res:
raise e.ProgrammingError("no result available")
elif res.status != self.ExecStatus.TUPLES_OK:
raise e.ProgrammingError(
self._pos += 1
return rv
- def fetchmany(self, size: Optional[int] = None) -> List[Sequence[Any]]:
+ def fetchmany(self, size: int = 0) -> List[Sequence[Any]]:
self._check_result()
- if size is None:
+ if not size:
size = self.arraysize
rv: List[Sequence[Any]] = []
self._pos += 1
return rv
- async def fetchmany(
- self, size: Optional[int] = None
- ) -> List[Sequence[Any]]:
+ async def fetchmany(self, size: int = 0) -> List[Sequence[Any]]:
self._check_result()
- if size is None:
+ if not size:
size = self.arraysize
pos = self._pos
def get_base_exception(sqlstate: str) -> Type[Error]:
- exc = _base_exc_map.get(sqlstate[:2])
- if exc is not None:
- return exc
-
- exc = _base_exc_map.get(sqlstate[0])
- if exc is not None:
- return exc
-
- return DatabaseError
+ return (
+ _base_exc_map.get(sqlstate[:2])
+ or _base_exc_map.get(sqlstate[0])
+ or DatabaseError
+ )
_base_exc_map = {
n = pgconn.notifies()
if n is None:
break
- if pgconn.notify_handler is not None:
+ if pgconn.notify_handler:
pgconn.notify_handler(n)
res = pgconn.get_result()
ns = []
while 1:
n = pgconn.notifies()
- if n is not None:
+ if n:
ns.append(n)
else:
break
from psycopg3.errors import NotSupportedError
libname = ctypes.util.find_library("pq")
-if libname is None:
+if not libname:
raise ImportError("libpq library not found")
pq = ctypes.pydll.LoadLibrary(libname)
def PQhostaddr(pgconn: type) -> bytes:
- if _PQhostaddr is not None:
+ if _PQhostaddr:
return _PQhostaddr(pgconn)
else:
raise NotSupportedError(
arg: Any, result_ptr: impl.PGresult_struct, wconn: "ref[PGconn]"
) -> None:
pgconn = wconn()
- if pgconn is None or pgconn.notice_handler is None:
+ if not (pgconn and pgconn.notice_handler):
return
res = PGresult(result_ptr)
def finish(self) -> None:
self.pgconn_ptr, p = None, self.pgconn_ptr
- if p is not None:
+ if p:
impl.PQfinish(p)
@property
if not isinstance(command, bytes):
raise TypeError(f"bytes expected, got {type(command)} instead")
- nparams = len(param_values) if param_values is not None else 0
- aparams: Optional[Array[c_char_p]] = None
- alenghts: Optional[Array[c_int]] = None
+ aparams: Optional[Array[c_char_p]]
+ alenghts: Optional[Array[c_int]]
if param_values:
+ nparams = len(param_values)
aparams = (c_char_p * nparams)(*param_values)
alenghts = (c_int * nparams)(
- *(len(p) if p is not None else 0 for p in param_values)
+ *(len(p) if p else 0 for p in param_values)
)
+ else:
+ nparams = 0
+ aparams = alenghts = None
atypes: Optional[Array[impl.Oid]]
- if param_types is None:
+ if not param_types:
atypes = None
else:
if len(param_types) != nparams:
)
atypes = (impl.Oid * nparams)(*param_types)
- if param_formats is None:
+ if not param_formats:
aformats = None
else:
if len(param_formats) != nparams:
f"'command' must be bytes, got {type(command)} instead"
)
- if param_types is None:
+ if not param_types:
nparams = 0
atypes = None
else:
if not isinstance(name, bytes):
raise TypeError(f"'name' must be bytes, got {type(name)} instead")
- nparams = len(param_values) if param_values is not None else 0
- aparams: Optional[Array[c_char_p]] = None
- alenghts: Optional[Array[c_int]] = None
+ aparams: Optional[Array[c_char_p]]
+ alenghts: Optional[Array[c_int]]
if param_values:
+ nparams = len(param_values)
aparams = (c_char_p * nparams)(*param_values)
alenghts = (c_int * nparams)(
- *(len(p) if p is not None else 0 for p in param_values)
+ *(len(p) if p else 0 for p in param_values)
)
+ else:
+ nparams = 0
+ aparams = alenghts = None
- if param_formats is None:
+ if not param_formats:
aformats = None
else:
if len(param_formats) != nparams:
def clear(self) -> None:
self.pgresult_ptr, p = None, self.pgresult_ptr
- if p is not None:
+ if p:
impl.PQclear(p)
@property
def free(self) -> None:
self.pgcancel_ptr, p = None, self.pgcancel_ptr
- if p is not None:
+ if p:
impl.PQfreeCancel(p)
def cancel(self) -> None:
self.conn = conn
def escape_literal(self, data: bytes) -> bytes:
- if self.conn is not None:
+ if self.conn:
self.conn._ensure_pgconn()
out = impl.PQescapeLiteral(self.conn.pgconn_ptr, data, len(data))
if not out:
raise PQerror("escape_literal failed: no connection provided")
def escape_identifier(self, data: bytes) -> bytes:
- if self.conn is not None:
+ if self.conn:
self.conn._ensure_pgconn()
out = impl.PQescapeIdentifier(
self.conn.pgconn_ptr, data, len(data)
raise PQerror("escape_identifier failed: no connection provided")
def escape_string(self, data: bytes) -> bytes:
- if self.conn is not None:
+ if self.conn:
self.conn._ensure_pgconn()
error = c_int()
out = create_string_buffer(len(data) * 2 + 1)
def escape_bytea(self, data: bytes) -> bytes:
len_out = c_size_t()
- if self.conn is not None:
+ if self.conn:
self.conn._ensure_pgconn()
out = impl.PQescapeByteaConn(
self.conn.pgconn_ptr,
def unescape_bytea(self, data: bytes) -> bytes:
# not needed, but let's keep it symmetric with the escaping:
# if a connection is passed in, it must be valid.
- if self.conn is not None:
+ if self.conn:
self.conn._ensure_pgconn()
len_out = c_size_t()
"""
- def __init__(
- self, name: Optional[str] = None, format: Format = Format.TEXT
- ):
+ def __init__(self, name: str = "", format: Format = Format.TEXT):
super().__init__(name)
- if isinstance(name, str):
- if ")" in name:
- raise ValueError("invalid name: %r" % name)
+ if not isinstance(name, str):
+ raise TypeError(f"expected string as name, got {name!r}")
- elif name is not None:
- raise TypeError("expected string or None as name, got %r" % name)
+ if ")" in name:
+ raise ValueError(f"invalid name: {name!r}")
self._format = format
def as_string(self, context: AdaptContext) -> str:
code = "s" if self._format == Format.TEXT else "b"
- if self._obj is not None:
- return f"%({self._obj}){code}"
- else:
- return f"%{code}"
+ return f"%({self._obj}){code}" if self._obj else f"%{code}"
# Literals
oid = 0
if base_oid:
info = builtins.get(base_oid)
- if info is not None:
+ if info:
oid = info.array_oid
return oid or TEXT_ARRAY_OID
def dump(self, obj: List[Any]) -> bytes:
tokens: List[bytes] = []
- oid: Optional[int] = None
+ oid = 0
def dump_list(obj: List[Any]) -> None:
nonlocal oid
elif item is not None:
dumper = self._tx.get_dumper(item, Format.TEXT)
ad = dumper.dump(item)
- if self._re_needs_quotes.search(ad) is not None:
+ if self._re_needs_quotes.search(ad):
ad = b'"' + self._re_escape.sub(br"\\\1", ad) + b'"'
tokens.append(ad)
- if oid is None:
+ if not oid:
oid = dumper.oid
else:
tokens.append(b"NULL")
dump_list(obj)
- if oid is not None:
+ if oid:
self._array_oid = self._get_array_oid(oid)
return b"".join(tokens)
context: AdaptContext = None,
factory: Optional[Callable[..., Any]] = None,
) -> None:
- if factory is None:
+ if not factory:
factory = namedtuple( # type: ignore
info.name, [f.name for f in info.fields]
)
dumper = self._tx.get_dumper(item, Format.TEXT)
ad = dumper.dump(item)
- if self._re_needs_quotes.search(ad) is not None:
+ if self._re_needs_quotes.search(ad):
ad = b'"' + self._re_escape.sub(br"\1\1", ad) + b'"'
parts.append(ad)
return
for m in self._re_tokenize.finditer(data):
- if m.group(1) is not None:
+ if m.group(1):
yield None
elif m.group(2) is not None:
yield self._re_undouble.sub(br"\1", m.group(2))
def __init__(self, oid: int, context: AdaptContext):
super().__init__(oid, context)
- if self.connection is not None:
+ if self.connection:
if self.connection.client_encoding != "SQL_ASCII":
self.encoding = self.connection.pyenc
else:
def __init__(self, src: type, context: AdaptContext = None):
super().__init__(src, context)
self.esc = Escaping(
- self.connection.pgconn if self.connection is not None else None
+ self.connection.pgconn if self.connection else None
)
def dump(self, obj: bytes) -> bytes:
if self.types is None:
self.types = []
- for i in range(len(params)):
- param = params[i]
+ for i, param in enumerate(params):
if param is not None:
dumper = self._tx.get_dumper(param, self.formats[i])
self.params.append(dumper.dump(param))
self.params.append(None)
self.types.append(UNKNOWN_OID)
else:
- for i in range(len(params)):
- param = params[i]
+ for i, param in enumerate(params):
if param is not None:
dumper = self._tx.get_dumper(param, self.formats[i])
self.params.append(dumper.dump(param))
pre = query[cur : m.span(0)[0]]
parts.append((pre, m))
cur = m.span(0)[1]
- if m is None:
- parts.append((query, None))
- else:
+ if m:
parts.append((query[cur:], None))
+ else:
+ parts.append((query, None))
rv = []
# Index or name
item: Union[int, str]
- item = i if m.group(1) is None else m.group(1).decode(encoding)
+ item = m.group(1).decode(encoding) if m.group(1) else i
- if phtype is None:
+ if not phtype:
phtype = type(item)
- else:
- if phtype is not type(item): # noqa
- raise e.ProgrammingError(
- "positional and named placeholders cannot be mixed"
- )
+ elif phtype is not type(item):
+ raise e.ProgrammingError(
+ "positional and named placeholders cannot be mixed"
+ )
# Binary format
format = Format(ph[-1:] == b"b")
with open("psycopg3/version.py") as f:
data = f.read()
m = re.search(r"""(?m)^__version__\s*=\s*['"]([^'"]+)['"]""", data)
- if m is None:
+ if not m:
raise Exception(f"cannot find version in {f.name}")
version = m.group(1)
def test_rowcount(conn):
cur = conn.cursor()
+
+ cur.execute("select 1 from generate_series(1, 0)")
+ assert cur.rowcount == 0
+
cur.execute("select 1 from generate_series(1, 42)")
assert cur.rowcount == 42