From: Daniele Varrazzo Date: Sat, 11 Apr 2020 04:54:47 +0000 (+1200) Subject: Added barebone implementation of dbapi 2.0 X-Git-Tag: 3.0.dev0~582 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=be4e3ee85dc773878d67edf0b7880824bf7059aa;p=thirdparty%2Fpsycopg.git Added barebone implementation of dbapi 2.0 Several methods not tested yet, fetch method not added yet to async cursor. --- diff --git a/psycopg3/__init__.py b/psycopg3/__init__.py index 291c32519..23488315c 100644 --- a/psycopg3/__init__.py +++ b/psycopg3/__init__.py @@ -23,6 +23,9 @@ from .errors import ( # register default adapters from . import types # noqa +from .dbapi20 import BINARY, DATETIME, NUMBER, ROWID, STRING +from .dbapi20 import Binary, Date, DateFromTicks, Time, TimeFromTicks +from .dbapi20 import Timestamp, TimestampFromTicks # DBAPI compliancy connect = Connection.connect @@ -30,18 +33,12 @@ apilevel = "2.0" threadsafety = 2 paramstyle = "pyformat" -__all__ = [ - "Warning", - "Error", - "InterfaceError", - "DatabaseError", - "DataError", - "OperationalError", - "IntegrityError", - "InternalError", - "ProgrammingError", - "NotSupportedError", - "AsyncConnection", - "Connection", - "connect", -] +__all__ = ( + ["Warning", "Error", "InterfaceError", "DatabaseError", "DataError"] + + ["OperationalError", "IntegrityError", "InternalError"] + + ["ProgrammingError", "NotSupportedError"] + + ["AsyncConnection", "Connection", "connect"] + + ["BINARY", "DATETIME", "NUMBER", "ROWID", "STRING"] + + ["Binary", "Date", "DateFromTicks", "Time", "TimeFromTicks"] + + ["Timestamp", "TimestampFromTicks"] +) diff --git a/psycopg3/connection.py b/psycopg3/connection.py index 9976fe014..1ce4f62b2 100644 --- a/psycopg3/connection.py +++ b/psycopg3/connection.py @@ -35,6 +35,18 @@ class BaseConnection: allow different interfaces (sync/async). """ + # DBAPI2 exposed exceptions + Warning = e.Warning + Error = e.Error + InterfaceError = e.InterfaceError + DatabaseError = e.DatabaseError + DataError = e.DataError + OperationalError = e.OperationalError + IntegrityError = e.IntegrityError + InternalError = e.InternalError + ProgrammingError = e.ProgrammingError + NotSupportedError = e.NotSupportedError + def __init__(self, pgconn: pq.PGconn): self.pgconn = pgconn self.cursor_factory = cursor.BaseCursor @@ -43,6 +55,9 @@ class BaseConnection: # name of the postgres encoding (in bytes) self._pgenc = b"" + def close(self) -> None: + self.pgconn.finish() + def cursor( self, name: Optional[str] = None, binary: bool = False ) -> cursor.BaseCursor: diff --git a/psycopg3/cursor.py b/psycopg3/cursor.py index c6ae1fea2..d4309b7c2 100644 --- a/psycopg3/cursor.py +++ b/psycopg3/cursor.py @@ -4,23 +4,59 @@ psycopg3 cursor objects # Copyright (C) 2020 The Psycopg Team +import codecs +from operator import attrgetter from typing import Any, List, Mapping, Optional, Sequence, Tuple, TYPE_CHECKING from . import errors as e -from .pq import ExecStatus, PGresult, Format +from .pq import ConnStatus, ExecStatus, PGresult, Format from .utils.queries import query2pg, reorder_params from .utils.typing import Query, Params if TYPE_CHECKING: - from .connection import ( - BaseConnection, - Connection, - AsyncConnection, - QueryGen, - ) + from .connection import BaseConnection, Connection, AsyncConnection + from .connection import QueryGen from .adapt import DumpersMap, LoadersMap +class Column(Sequence[Any]): + def __init__( + self, pgresult: PGresult, index: int, codec: codecs.CodecInfo + ): + self._pgresult = pgresult + self._index = index + self._codec = codec + + _attrs = tuple( + map( + attrgetter, + """ + name type_code display_size internal_size precision scale null_ok + """.split(), + ) + ) + + def __len__(self) -> int: + return 7 + + def __getitem__(self, index: Any) -> Any: + return self._attrs[index](self) + + @property + def name(self) -> str: + rv = self._pgresult.fname(self._index) + if rv is not None: + return self._codec.decode(rv)[0] + else: + raise e.InterfaceError( + f"no name available for column {self._index}" + ) + + @property + def type_code(self) -> int: + return self._pgresult.ftype(self._index) + + class BaseCursor: def __init__(self, conn: "BaseConnection", binary: bool = False): self.conn = conn @@ -28,6 +64,7 @@ class BaseCursor: self.dumpers: DumpersMap = {} self.loaders: LoadersMap = {} self._reset() + self.arraysize = 1 def _reset(self) -> None: from .adapt import Transformer @@ -51,10 +88,44 @@ class BaseCursor: for i in range(result.nfields) ) + @property + def description(self) -> Optional[List[Column]]: + res = self.pgresult + if res is None or res.status != ExecStatus.TUPLES_OK: + return None + return [Column(res, i, self.conn.codec) for i in range(res.nfields)] + + @property + def rowcount(self) -> int: + res = self.pgresult + if res is None or res.status != ExecStatus.TUPLES_OK: + return -1 + else: + return res.ntuples + + def setinputsizes(self, sizes: Sequence[Any]) -> None: + # no-op + pass + + def setoutputsize(self, size: Any, column: Optional[int] = None) -> None: + # no-op + pass + def _execute_send( self, query: Query, vars: Optional[Params] ) -> "QueryGen": # Implement part of execute() before waiting common to sync and async + if self.conn.pgconn.status != ConnStatus.OK: + if self.conn.pgconn.status == ConnStatus.BAD: + raise e.InterfaceError( + "cannot execute operations: the connection is closed" + ) + else: + raise e.InterfaceError( + f"cannot execute operations: the connection is" + f" in status {self.conn.pgconn.status}" + ) + self._reset() codec = self.conn.codec @@ -134,7 +205,12 @@ class BaseCursor: def _load_row(self, n: int) -> Optional[Tuple[Any, ...]]: res = self.pgresult if res is None: - return None + raise e.ProgrammingError("no result available") + elif res.status != ExecStatus.TUPLES_OK: + raise e.ProgrammingError( + "the last operation didn't produce a result" + ) + if n >= res.ntuples: return None @@ -158,12 +234,48 @@ class Cursor(BaseCursor): self._execute_results(results) return self + def executemany( + self, query: Query, vars_seq: Sequence[Params] + ) -> "Cursor": + with self.conn.lock: + # TODO: trivial implementation; use prepare + for vars in vars_seq: + gen = self._execute_send(query, vars) + results = self.conn.wait(gen) + self._execute_results(results) + return self + def fetchone(self) -> Optional[Sequence[Any]]: rv = self._load_row(self._pos) if rv is not None: self._pos += 1 return rv + def fetchmany(self, size: Optional[int] = None) -> List[Sequence[Any]]: + if size is None: + size = self.arraysize + + rv: List[Sequence[Any]] = [] + while len(rv) < size: + row = self._load_row(self._pos) + if row is None: + break + self._pos += 1 + rv.append(row) + + return rv + + def fetchall(self) -> List[Sequence[Any]]: + rv: List[Sequence[Any]] = [] + while 1: + row = self._load_row(self._pos) + if row is None: + break + self._pos += 1 + rv.append(row) + + return rv + class AsyncCursor(BaseCursor): conn: "AsyncConnection" diff --git a/psycopg3/dbapi20.py b/psycopg3/dbapi20.py new file mode 100644 index 000000000..527a02159 --- /dev/null +++ b/psycopg3/dbapi20.py @@ -0,0 +1,90 @@ +""" +Compatibility objects with DBAPI 2.0 +""" + +# Copyright (C) 2020 The Psycopg Team + +import time +import datetime as dt +from math import floor +from typing import Any, Sequence, Tuple + +from .types.oids import builtins +from .adapt import Dumper + + +class DBAPITypeObject: + def __init__(self, name: str, type_names: Sequence[str]): + self.name = name + self.values = tuple(builtins[n].oid for n in type_names) + + def __repr__(self) -> str: + return f"psycopg3.{self.name}" + + def __eq__(self, other: Any) -> bool: + if isinstance(other, int): + return other in self.values + else: + return NotImplemented + + def __ne__(self, other: Any) -> bool: + if isinstance(other, int): + return other not in self.values + else: + return NotImplemented + + +BINARY = DBAPITypeObject("BINARY", ("bytea",)) +DATETIME = DBAPITypeObject( + "DATETIME", "timestamp timestamptz date time timetz interval".split() +) +NUMBER = DBAPITypeObject( + "NUMBER", "int2 int4 int8 float4 float8 numeric".split() +) +ROWID = DBAPITypeObject("ROWID", ("oid",)) +STRING = DBAPITypeObject("STRING", "text varchar bpchar".split()) + + +class Binary: + def __init__(self, obj: Any): + self.obj = obj + + +@Dumper.text(Binary) +def dump_Binary(obj: Binary) -> Tuple[bytes, int]: + rv = obj.obj + if not isinstance(rv, bytes): + rv = bytes(rv) + + return rv, builtins["bytea"].oid + + +def Date(year: int, month: int, day: int) -> dt.date: + return dt.date(year, month, day) + + +def DateFromTicks(ticks: float) -> dt.date: + return TimestampFromTicks(ticks).date() + + +def Time(hour: int, minute: int, second: int) -> dt.time: + return dt.time(hour, minute, second) + + +def TimeFromTicks(ticks: float) -> dt.time: + return TimestampFromTicks(ticks).time() + + +def Timestamp( + year: int, month: int, day: int, hour: int, minute: int, second: int +) -> dt.datetime: + return dt.datetime(year, month, day, hour, minute, second) + + +def TimestampFromTicks(ticks: float) -> dt.datetime: + secs = floor(ticks) + frac = ticks - secs + t = time.localtime(ticks) + tzinfo = dt.timezone(dt.timedelta(seconds=t.tm_gmtoff)) + rv = dt.datetime(*t[:6], round(frac * 1_000_000), tzinfo=tzinfo) + return rv diff --git a/tests/dbapi20.py b/tests/dbapi20.py new file mode 100644 index 000000000..86e42b53e --- /dev/null +++ b/tests/dbapi20.py @@ -0,0 +1,870 @@ +#!/usr/bin/env python +# flake8: noqa +# fmt: off +''' Python DB API 2.0 driver compliance unit test suite. + + This software is Public Domain and may be used without restrictions. + + "Now we have booze and barflies entering the discussion, plus rumours of + DBAs on drugs... and I won't tell you what flashes through my mind each + time I read the subject line with 'Anal Compliance' in it. All around + this is turning out to be a thoroughly unwholesome unit test." + + -- Ian Bicking +''' + +__rcs_id__ = '$Id: dbapi20.py,v 1.11 2005/01/02 02:41:01 zenzen Exp $' +__version__ = '$Revision: 1.12 $'[11:-2] +__author__ = 'Stuart Bishop ' + +import unittest +import time +import sys + + +# Revision 1.12 2009/02/06 03:35:11 kf7xm +# Tested okay with Python 3.0, includes last minute patches from Mark H. +# +# Revision 1.1.1.1.2.1 2008/09/20 19:54:59 rupole +# Include latest changes from main branch +# Updates for py3k +# +# Revision 1.11 2005/01/02 02:41:01 zenzen +# Update author email address +# +# Revision 1.10 2003/10/09 03:14:14 zenzen +# Add test for DB API 2.0 optional extension, where database exceptions +# are exposed as attributes on the Connection object. +# +# Revision 1.9 2003/08/13 01:16:36 zenzen +# Minor tweak from Stefan Fleiter +# +# Revision 1.8 2003/04/10 00:13:25 zenzen +# Changes, as per suggestions by M.-A. Lemburg +# - Add a table prefix, to ensure namespace collisions can always be avoided +# +# Revision 1.7 2003/02/26 23:33:37 zenzen +# Break out DDL into helper functions, as per request by David Rushby +# +# Revision 1.6 2003/02/21 03:04:33 zenzen +# Stuff from Henrik Ekelund: +# added test_None +# added test_nextset & hooks +# +# Revision 1.5 2003/02/17 22:08:43 zenzen +# Implement suggestions and code from Henrik Eklund - test that cursor.arraysize +# defaults to 1 & generic cursor.callproc test added +# +# Revision 1.4 2003/02/15 00:16:33 zenzen +# Changes, as per suggestions and bug reports by M.-A. Lemburg, +# Matthew T. Kromer, Federico Di Gregorio and Daniel Dittmar +# - Class renamed +# - Now a subclass of TestCase, to avoid requiring the driver stub +# to use multiple inheritance +# - Reversed the polarity of buggy test in test_description +# - Test exception hierarchy correctly +# - self.populate is now self._populate(), so if a driver stub +# overrides self.ddl1 this change propagates +# - VARCHAR columns now have a width, which will hopefully make the +# DDL even more portible (this will be reversed if it causes more problems) +# - cursor.rowcount being checked after various execute and fetchXXX methods +# - Check for fetchall and fetchmany returning empty lists after results +# are exhausted (already checking for empty lists if select retrieved +# nothing +# - Fix bugs in test_setoutputsize_basic and test_setinputsizes +# + +class DatabaseAPI20Test(unittest.TestCase): + ''' Test a database self.driver for DB API 2.0 compatibility. + This implementation tests Gadfly, but the TestCase + is structured so that other self.drivers can subclass this + test case to ensure compiliance with the DB-API. It is + expected that this TestCase may be expanded in the future + if ambiguities or edge conditions are discovered. + + The 'Optional Extensions' are not yet being tested. + + self.drivers should subclass this test, overriding setUp, tearDown, + self.driver, connect_args and connect_kw_args. Class specification + should be as follows: + + from . import dbapi20 + class mytest(dbapi20.DatabaseAPI20Test): + [...] + + Don't 'from .dbapi20 import DatabaseAPI20Test', or you will + confuse the unit tester - just 'from . import dbapi20'. + ''' + + # The self.driver module. This should be the module where the 'connect' + # method is to be found + driver = None + connect_args = () # List of arguments to pass to connect + connect_kw_args = {} # Keyword arguments for connect + table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables + + ddl1 = 'create table %sbooze (name varchar(20))' % table_prefix + ddl2 = 'create table %sbarflys (name varchar(20))' % table_prefix + xddl1 = 'drop table %sbooze' % table_prefix + xddl2 = 'drop table %sbarflys' % table_prefix + + lowerfunc = 'lower' # Name of stored procedure to convert string->lowercase + + # Some drivers may need to override these helpers, for example adding + # a 'commit' after the execute. + def executeDDL1(self,cursor): + cursor.execute(self.ddl1) + + def executeDDL2(self,cursor): + cursor.execute(self.ddl2) + + def setUp(self): + ''' self.drivers should override this method to perform required setup + if any is necessary, such as creating the database. + ''' + pass + + def tearDown(self): + ''' self.drivers should override this method to perform required cleanup + if any is necessary, such as deleting the test database. + The default drops the tables that may be created. + ''' + con = self._connect() + try: + cur = con.cursor() + for ddl in (self.xddl1,self.xddl2): + try: + cur.execute(ddl) + con.commit() + except self.driver.Error: + # Assume table didn't exist. Other tests will check if + # execute is busted. + pass + finally: + con.close() + + def _connect(self): + try: + return self.driver.connect( + *self.connect_args,**self.connect_kw_args + ) + except AttributeError: + self.fail("No connect method found in self.driver module") + + def test_connect(self): + con = self._connect() + con.close() + + def test_apilevel(self): + try: + # Must exist + apilevel = self.driver.apilevel + # Must equal 2.0 + self.assertEqual(apilevel,'2.0') + except AttributeError: + self.fail("Driver doesn't define apilevel") + + def test_threadsafety(self): + try: + # Must exist + threadsafety = self.driver.threadsafety + # Must be a valid value + self.failUnless(threadsafety in (0,1,2,3)) + except AttributeError: + self.fail("Driver doesn't define threadsafety") + + def test_paramstyle(self): + try: + # Must exist + paramstyle = self.driver.paramstyle + # Must be a valid value + self.failUnless(paramstyle in ( + 'qmark','numeric','named','format','pyformat' + )) + except AttributeError: + self.fail("Driver doesn't define paramstyle") + + def test_Exceptions(self): + # Make sure required exceptions exist, and are in the + # defined hierarchy. + if sys.version[0] == '3': #under Python 3 StardardError no longer exists + self.failUnless(issubclass(self.driver.Warning,Exception)) + self.failUnless(issubclass(self.driver.Error,Exception)) + else: + self.failUnless(issubclass(self.driver.Warning,StandardError)) + self.failUnless(issubclass(self.driver.Error,StandardError)) + + self.failUnless( + issubclass(self.driver.InterfaceError,self.driver.Error) + ) + self.failUnless( + issubclass(self.driver.DatabaseError,self.driver.Error) + ) + self.failUnless( + issubclass(self.driver.OperationalError,self.driver.Error) + ) + self.failUnless( + issubclass(self.driver.IntegrityError,self.driver.Error) + ) + self.failUnless( + issubclass(self.driver.InternalError,self.driver.Error) + ) + self.failUnless( + issubclass(self.driver.ProgrammingError,self.driver.Error) + ) + self.failUnless( + issubclass(self.driver.NotSupportedError,self.driver.Error) + ) + + def test_ExceptionsAsConnectionAttributes(self): + # OPTIONAL EXTENSION + # Test for the optional DB API 2.0 extension, where the exceptions + # are exposed as attributes on the Connection object + # I figure this optional extension will be implemented by any + # driver author who is using this test suite, so it is enabled + # by default. + con = self._connect() + drv = self.driver + self.failUnless(con.Warning is drv.Warning) + self.failUnless(con.Error is drv.Error) + self.failUnless(con.InterfaceError is drv.InterfaceError) + self.failUnless(con.DatabaseError is drv.DatabaseError) + self.failUnless(con.OperationalError is drv.OperationalError) + self.failUnless(con.IntegrityError is drv.IntegrityError) + self.failUnless(con.InternalError is drv.InternalError) + self.failUnless(con.ProgrammingError is drv.ProgrammingError) + self.failUnless(con.NotSupportedError is drv.NotSupportedError) + + + def test_commit(self): + con = self._connect() + try: + # Commit must work, even if it doesn't do anything + con.commit() + finally: + con.close() + + def test_rollback(self): + con = self._connect() + # If rollback is defined, it should either work or throw + # the documented exception + if hasattr(con,'rollback'): + try: + con.rollback() + except self.driver.NotSupportedError: + pass + + def test_cursor(self): + con = self._connect() + try: + cur = con.cursor() + finally: + con.close() + + def test_cursor_isolation(self): + con = self._connect() + try: + # Make sure cursors created from the same connection have + # the documented transaction isolation level + cur1 = con.cursor() + cur2 = con.cursor() + self.executeDDL1(cur1) + cur1.execute("insert into %sbooze values ('Victoria Bitter')" % ( + self.table_prefix + )) + cur2.execute("select name from %sbooze" % self.table_prefix) + booze = cur2.fetchall() + self.assertEqual(len(booze),1) + self.assertEqual(len(booze[0]),1) + self.assertEqual(booze[0][0],'Victoria Bitter') + finally: + con.close() + + def test_description(self): + con = self._connect() + try: + cur = con.cursor() + self.executeDDL1(cur) + self.assertEqual(cur.description,None, + 'cursor.description should be none after executing a ' + 'statement that can return no rows (such as DDL)' + ) + cur.execute('select name from %sbooze' % self.table_prefix) + self.assertEqual(len(cur.description),1, + 'cursor.description describes too many columns' + ) + self.assertEqual(len(cur.description[0]),7, + 'cursor.description[x] tuples must have 7 elements' + ) + self.assertEqual(cur.description[0][0].lower(),'name', + 'cursor.description[x][0] must return column name' + ) + self.assertEqual(cur.description[0][1],self.driver.STRING, + 'cursor.description[x][1] must return column type. Got %r' + % cur.description[0][1] + ) + + # Make sure self.description gets reset + self.executeDDL2(cur) + self.assertEqual(cur.description,None, + 'cursor.description not being set to None when executing ' + 'no-result statements (eg. DDL)' + ) + finally: + con.close() + + def test_rowcount(self): + con = self._connect() + try: + cur = con.cursor() + self.executeDDL1(cur) + self.assertEqual(cur.rowcount,-1, + 'cursor.rowcount should be -1 after executing no-result ' + 'statements' + ) + cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( + self.table_prefix + )) + self.failUnless(cur.rowcount in (-1,1), + 'cursor.rowcount should == number or rows inserted, or ' + 'set to -1 after executing an insert statement' + ) + cur.execute("select name from %sbooze" % self.table_prefix) + self.failUnless(cur.rowcount in (-1,1), + 'cursor.rowcount should == number of rows returned, or ' + 'set to -1 after executing a select statement' + ) + self.executeDDL2(cur) + self.assertEqual(cur.rowcount,-1, + 'cursor.rowcount not being reset to -1 after executing ' + 'no-result statements' + ) + finally: + con.close() + + lower_func = 'lower' + def test_callproc(self): + con = self._connect() + try: + cur = con.cursor() + if self.lower_func and hasattr(cur,'callproc'): + r = cur.callproc(self.lower_func,('FOO',)) + self.assertEqual(len(r),1) + self.assertEqual(r[0],'FOO') + r = cur.fetchall() + self.assertEqual(len(r),1,'callproc produced no result set') + self.assertEqual(len(r[0]),1, + 'callproc produced invalid result set' + ) + self.assertEqual(r[0][0],'foo', + 'callproc produced invalid results' + ) + finally: + con.close() + + def test_close(self): + con = self._connect() + try: + cur = con.cursor() + finally: + con.close() + + # cursor.execute should raise an Error if called after connection + # closed + self.assertRaises(self.driver.Error,self.executeDDL1,cur) + + # connection.commit should raise an Error if called after connection' + # closed.' + self.assertRaises(self.driver.Error,con.commit) + + # connection.close should raise an Error if called more than once + # Issue discussed on DB-SIG: consensus seem that close() should not + # raised if called on closed objects. Issue reported back to Stuart. + # self.assertRaises(self.driver.Error,con.close) + + def test_execute(self): + con = self._connect() + try: + cur = con.cursor() + self._paraminsert(cur) + finally: + con.close() + + def _paraminsert(self,cur): + self.executeDDL1(cur) + cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( + self.table_prefix + )) + self.failUnless(cur.rowcount in (-1,1)) + + if self.driver.paramstyle == 'qmark': + cur.execute( + 'insert into %sbooze values (?)' % self.table_prefix, + ("Cooper's",) + ) + elif self.driver.paramstyle == 'numeric': + cur.execute( + 'insert into %sbooze values (:1)' % self.table_prefix, + ("Cooper's",) + ) + elif self.driver.paramstyle == 'named': + cur.execute( + 'insert into %sbooze values (:beer)' % self.table_prefix, + {'beer':"Cooper's"} + ) + elif self.driver.paramstyle == 'format': + cur.execute( + 'insert into %sbooze values (%%s)' % self.table_prefix, + ("Cooper's",) + ) + elif self.driver.paramstyle == 'pyformat': + cur.execute( + 'insert into %sbooze values (%%(beer)s)' % self.table_prefix, + {'beer':"Cooper's"} + ) + else: + self.fail('Invalid paramstyle') + self.failUnless(cur.rowcount in (-1,1)) + + cur.execute('select name from %sbooze' % self.table_prefix) + res = cur.fetchall() + self.assertEqual(len(res),2,'cursor.fetchall returned too few rows') + beers = [res[0][0],res[1][0]] + beers.sort() + self.assertEqual(beers[0],"Cooper's", + 'cursor.fetchall retrieved incorrect data, or data inserted ' + 'incorrectly' + ) + self.assertEqual(beers[1],"Victoria Bitter", + 'cursor.fetchall retrieved incorrect data, or data inserted ' + 'incorrectly' + ) + + def test_executemany(self): + con = self._connect() + try: + cur = con.cursor() + self.executeDDL1(cur) + largs = [ ("Cooper's",) , ("Boag's",) ] + margs = [ {'beer': "Cooper's"}, {'beer': "Boag's"} ] + if self.driver.paramstyle == 'qmark': + cur.executemany( + 'insert into %sbooze values (?)' % self.table_prefix, + largs + ) + elif self.driver.paramstyle == 'numeric': + cur.executemany( + 'insert into %sbooze values (:1)' % self.table_prefix, + largs + ) + elif self.driver.paramstyle == 'named': + cur.executemany( + 'insert into %sbooze values (:beer)' % self.table_prefix, + margs + ) + elif self.driver.paramstyle == 'format': + cur.executemany( + 'insert into %sbooze values (%%s)' % self.table_prefix, + largs + ) + elif self.driver.paramstyle == 'pyformat': + cur.executemany( + 'insert into %sbooze values (%%(beer)s)' % ( + self.table_prefix + ), + margs + ) + else: + self.fail('Unknown paramstyle') + self.failUnless(cur.rowcount in (-1,2), + 'insert using cursor.executemany set cursor.rowcount to ' + 'incorrect value %r' % cur.rowcount + ) + cur.execute('select name from %sbooze' % self.table_prefix) + res = cur.fetchall() + self.assertEqual(len(res),2, + 'cursor.fetchall retrieved incorrect number of rows' + ) + beers = [res[0][0],res[1][0]] + beers.sort() + self.assertEqual(beers[0],"Boag's",'incorrect data retrieved') + self.assertEqual(beers[1],"Cooper's",'incorrect data retrieved') + finally: + con.close() + + def test_fetchone(self): + con = self._connect() + try: + cur = con.cursor() + + # cursor.fetchone should raise an Error if called before + # executing a select-type query + self.assertRaises(self.driver.Error,cur.fetchone) + + # cursor.fetchone should raise an Error if called after + # executing a query that cannot return rows + self.executeDDL1(cur) + self.assertRaises(self.driver.Error,cur.fetchone) + + cur.execute('select name from %sbooze' % self.table_prefix) + self.assertEqual(cur.fetchone(),None, + 'cursor.fetchone should return None if a query retrieves ' + 'no rows' + ) + self.failUnless(cur.rowcount in (-1,0)) + + # cursor.fetchone should raise an Error if called after + # executing a query that cannot return rows + cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( + self.table_prefix + )) + self.assertRaises(self.driver.Error,cur.fetchone) + + cur.execute('select name from %sbooze' % self.table_prefix) + r = cur.fetchone() + self.assertEqual(len(r),1, + 'cursor.fetchone should have retrieved a single row' + ) + self.assertEqual(r[0],'Victoria Bitter', + 'cursor.fetchone retrieved incorrect data' + ) + self.assertEqual(cur.fetchone(),None, + 'cursor.fetchone should return None if no more rows available' + ) + self.failUnless(cur.rowcount in (-1,1)) + finally: + con.close() + + samples = [ + 'Carlton Cold', + 'Carlton Draft', + 'Mountain Goat', + 'Redback', + 'Victoria Bitter', + 'XXXX' + ] + + def _populate(self): + ''' Return a list of sql commands to setup the DB for the fetch + tests. + ''' + populate = [ + "insert into %sbooze values ('%s')" % (self.table_prefix,s) + for s in self.samples + ] + return populate + + def test_fetchmany(self): + con = self._connect() + try: + cur = con.cursor() + + # cursor.fetchmany should raise an Error if called without + #issuing a query + self.assertRaises(self.driver.Error,cur.fetchmany,4) + + self.executeDDL1(cur) + for sql in self._populate(): + cur.execute(sql) + + cur.execute('select name from %sbooze' % self.table_prefix) + r = cur.fetchmany() + self.assertEqual(len(r),1, + 'cursor.fetchmany retrieved incorrect number of rows, ' + 'default of arraysize is one.' + ) + cur.arraysize=10 + r = cur.fetchmany(3) # Should get 3 rows + self.assertEqual(len(r),3, + 'cursor.fetchmany retrieved incorrect number of rows' + ) + r = cur.fetchmany(4) # Should get 2 more + self.assertEqual(len(r),2, + 'cursor.fetchmany retrieved incorrect number of rows' + ) + r = cur.fetchmany(4) # Should be an empty sequence + self.assertEqual(len(r),0, + 'cursor.fetchmany should return an empty sequence after ' + 'results are exhausted' + ) + self.failUnless(cur.rowcount in (-1,6)) + + # Same as above, using cursor.arraysize + cur.arraysize=4 + cur.execute('select name from %sbooze' % self.table_prefix) + r = cur.fetchmany() # Should get 4 rows + self.assertEqual(len(r),4, + 'cursor.arraysize not being honoured by fetchmany' + ) + r = cur.fetchmany() # Should get 2 more + self.assertEqual(len(r),2) + r = cur.fetchmany() # Should be an empty sequence + self.assertEqual(len(r),0) + self.failUnless(cur.rowcount in (-1,6)) + + cur.arraysize=6 + cur.execute('select name from %sbooze' % self.table_prefix) + rows = cur.fetchmany() # Should get all rows + self.failUnless(cur.rowcount in (-1,6)) + self.assertEqual(len(rows),6) + self.assertEqual(len(rows),6) + rows = [r[0] for r in rows] + rows.sort() + + # Make sure we get the right data back out + for i in range(0,6): + self.assertEqual(rows[i],self.samples[i], + 'incorrect data retrieved by cursor.fetchmany' + ) + + rows = cur.fetchmany() # Should return an empty list + self.assertEqual(len(rows),0, + 'cursor.fetchmany should return an empty sequence if ' + 'called after the whole result set has been fetched' + ) + self.failUnless(cur.rowcount in (-1,6)) + + self.executeDDL2(cur) + cur.execute('select name from %sbarflys' % self.table_prefix) + r = cur.fetchmany() # Should get empty sequence + self.assertEqual(len(r),0, + 'cursor.fetchmany should return an empty sequence if ' + 'query retrieved no rows' + ) + self.failUnless(cur.rowcount in (-1,0)) + + finally: + con.close() + + def test_fetchall(self): + con = self._connect() + try: + cur = con.cursor() + # cursor.fetchall should raise an Error if called + # without executing a query that may return rows (such + # as a select) + self.assertRaises(self.driver.Error, cur.fetchall) + + self.executeDDL1(cur) + for sql in self._populate(): + cur.execute(sql) + + # cursor.fetchall should raise an Error if called + # after executing a a statement that cannot return rows + self.assertRaises(self.driver.Error,cur.fetchall) + + cur.execute('select name from %sbooze' % self.table_prefix) + rows = cur.fetchall() + self.failUnless(cur.rowcount in (-1,len(self.samples))) + self.assertEqual(len(rows),len(self.samples), + 'cursor.fetchall did not retrieve all rows' + ) + rows = [r[0] for r in rows] + rows.sort() + for i in range(0,len(self.samples)): + self.assertEqual(rows[i],self.samples[i], + 'cursor.fetchall retrieved incorrect rows' + ) + rows = cur.fetchall() + self.assertEqual( + len(rows),0, + 'cursor.fetchall should return an empty list if called ' + 'after the whole result set has been fetched' + ) + self.failUnless(cur.rowcount in (-1,len(self.samples))) + + self.executeDDL2(cur) + cur.execute('select name from %sbarflys' % self.table_prefix) + rows = cur.fetchall() + self.failUnless(cur.rowcount in (-1,0)) + self.assertEqual(len(rows),0, + 'cursor.fetchall should return an empty list if ' + 'a select query returns no rows' + ) + + finally: + con.close() + + def test_mixedfetch(self): + con = self._connect() + try: + cur = con.cursor() + self.executeDDL1(cur) + for sql in self._populate(): + cur.execute(sql) + + cur.execute('select name from %sbooze' % self.table_prefix) + rows1 = cur.fetchone() + rows23 = cur.fetchmany(2) + rows4 = cur.fetchone() + rows56 = cur.fetchall() + self.failUnless(cur.rowcount in (-1,6)) + self.assertEqual(len(rows23),2, + 'fetchmany returned incorrect number of rows' + ) + self.assertEqual(len(rows56),2, + 'fetchall returned incorrect number of rows' + ) + + rows = [rows1[0]] + rows.extend([rows23[0][0],rows23[1][0]]) + rows.append(rows4[0]) + rows.extend([rows56[0][0],rows56[1][0]]) + rows.sort() + for i in range(0,len(self.samples)): + self.assertEqual(rows[i],self.samples[i], + 'incorrect data retrieved or inserted' + ) + finally: + con.close() + + def help_nextset_setUp(self,cur): + ''' Should create a procedure called deleteme + that returns two result sets, first the + number of rows in booze then "name from booze" + ''' + raise NotImplementedError('Helper not implemented') + #sql=""" + # create procedure deleteme as + # begin + # select count(*) from booze + # select name from booze + # end + #""" + #cur.execute(sql) + + def help_nextset_tearDown(self,cur): + 'If cleaning up is needed after nextSetTest' + raise NotImplementedError('Helper not implemented') + #cur.execute("drop procedure deleteme") + + def test_nextset(self): + con = self._connect() + try: + cur = con.cursor() + if not hasattr(cur,'nextset'): + return + + try: + self.executeDDL1(cur) + sql=self._populate() + for sql in self._populate(): + cur.execute(sql) + + self.help_nextset_setUp(cur) + + cur.callproc('deleteme') + numberofrows=cur.fetchone() + assert numberofrows[0]== len(self.samples) + assert cur.nextset() + names=cur.fetchall() + assert len(names) == len(self.samples) + s=cur.nextset() + assert s is None, 'No more return sets, should return None' + finally: + self.help_nextset_tearDown(cur) + + finally: + con.close() + + def test_nextset(self): + raise NotImplementedError('Drivers need to override this test') + + def test_arraysize(self): + # Not much here - rest of the tests for this are in test_fetchmany + con = self._connect() + try: + cur = con.cursor() + self.failUnless(hasattr(cur,'arraysize'), + 'cursor.arraysize must be defined' + ) + finally: + con.close() + + def test_setinputsizes(self): + con = self._connect() + try: + cur = con.cursor() + cur.setinputsizes( (25,) ) + self._paraminsert(cur) # Make sure cursor still works + finally: + con.close() + + def test_setoutputsize_basic(self): + # Basic test is to make sure setoutputsize doesn't blow up + con = self._connect() + try: + cur = con.cursor() + cur.setoutputsize(1000) + cur.setoutputsize(2000,0) + self._paraminsert(cur) # Make sure the cursor still works + finally: + con.close() + + def test_setoutputsize(self): + # Real test for setoutputsize is driver dependent + raise NotImplementedError('Driver needed to override this test') + + def test_None(self): + con = self._connect() + try: + cur = con.cursor() + self.executeDDL1(cur) + cur.execute('insert into %sbooze values (NULL)' % self.table_prefix) + cur.execute('select name from %sbooze' % self.table_prefix) + r = cur.fetchall() + self.assertEqual(len(r),1) + self.assertEqual(len(r[0]),1) + self.assertEqual(r[0][0],None,'NULL value not returned as None') + finally: + con.close() + + def test_Date(self): + d1 = self.driver.Date(2002,12,25) + d2 = self.driver.DateFromTicks(time.mktime((2002,12,25,0,0,0,0,0,0))) + # Can we assume this? API doesn't specify, but it seems implied + # self.assertEqual(str(d1),str(d2)) + + def test_Time(self): + t1 = self.driver.Time(13,45,30) + t2 = self.driver.TimeFromTicks(time.mktime((2001,1,1,13,45,30,0,0,0))) + # Can we assume this? API doesn't specify, but it seems implied + # self.assertEqual(str(t1),str(t2)) + + def test_Timestamp(self): + t1 = self.driver.Timestamp(2002,12,25,13,45,30) + t2 = self.driver.TimestampFromTicks( + time.mktime((2002,12,25,13,45,30,0,0,0)) + ) + # Can we assume this? API doesn't specify, but it seems implied + # self.assertEqual(str(t1),str(t2)) + + def test_Binary(self): + b = self.driver.Binary(b'Something') + b = self.driver.Binary(b'') + + def test_STRING(self): + self.failUnless(hasattr(self.driver,'STRING'), + 'module.STRING must be defined' + ) + + def test_BINARY(self): + self.failUnless(hasattr(self.driver,'BINARY'), + 'module.BINARY must be defined.' + ) + + def test_NUMBER(self): + self.failUnless(hasattr(self.driver,'NUMBER'), + 'module.NUMBER must be defined.' + ) + + def test_DATETIME(self): + self.failUnless(hasattr(self.driver,'DATETIME'), + 'module.DATETIME must be defined.' + ) + + def test_ROWID(self): + self.failUnless(hasattr(self.driver,'ROWID'), + 'module.ROWID must be defined.' + ) +# fmt: on diff --git a/tests/test_psycopg3_dbapi20.py b/tests/test_psycopg3_dbapi20.py new file mode 100644 index 000000000..fffd5379d --- /dev/null +++ b/tests/test_psycopg3_dbapi20.py @@ -0,0 +1,103 @@ +import pytest +import datetime as dt + +import psycopg3 + +from . import dbapi20 + + +@pytest.fixture(scope="class") +def with_dsn(request, dsn): + request.cls.connect_args = (dsn,) + + +@pytest.mark.usefixtures("with_dsn") +class Psycopg3Tests(dbapi20.DatabaseAPI20Test): + driver = psycopg3 + # connect_args = () # set by the fixture + connect_kw_args = {} + + def test_nextset(self): + # tested elsewhere + pass + + def test_setoutputsize(self): + # no-op + pass + + +# Shut up warnings +Psycopg3Tests.failUnless = Psycopg3Tests.assertTrue + + +@pytest.mark.parametrize( + "typename, singleton", + [ + ("bytea", "BINARY"), + ("date", "DATETIME"), + ("timestamp without time zone", "DATETIME"), + ("timestamp with time zone", "DATETIME"), + ("time without time zone", "DATETIME"), + ("time with time zone", "DATETIME"), + ("interval", "DATETIME"), + ("integer", "NUMBER"), + ("smallint", "NUMBER"), + ("bigint", "NUMBER"), + ("real", "NUMBER"), + ("double precision", "NUMBER"), + ("numeric", "NUMBER"), + ("decimal", "NUMBER"), + ("oid", "ROWID"), + ("varchar", "STRING"), + ("char", "STRING"), + ("text", "STRING"), + ], +) +def test_singletons(conn, typename, singleton): + singleton = getattr(psycopg3, singleton) + cur = conn.cursor() + cur.execute(f"select null::{typename}") + oid = cur.description[0].type_code + assert singleton == oid + assert oid == singleton + assert singleton != oid + 10000 + assert oid + 10000 != singleton + + +@pytest.mark.parametrize( + "ticks, want", + [ + (0, "1970-01-01T00:00:00.000000+0000"), + (1273173119.99992, "2010-05-06T14:11:59.999920-0500"), + ], +) +def test_timestamp_from_ticks(ticks, want): + s = psycopg3.TimestampFromTicks(ticks) + want = dt.datetime.strptime(want, "%Y-%m-%dT%H:%M:%S.%f%z") + assert s == want + + +@pytest.mark.parametrize( + "ticks, want", + [ + (0, "1970-01-01"), + # Returned date is local + (1273173119.99992, ["2010-05-06", "2010-05-07"]), + ], +) +def test_date_from_ticks(ticks, want): + s = psycopg3.DateFromTicks(ticks) + if isinstance(want, str): + want = [want] + want = [dt.datetime.strptime(w, "%Y-%m-%d").date() for w in want] + assert s in want + + +@pytest.mark.parametrize( + "ticks, want", + [(0, "00:00:00.000000"), (1273173119.99992, "00:11:59.999920")], +) +def test_time_from_ticks(ticks, want): + s = psycopg3.TimeFromTicks(ticks) + want = dt.datetime.strptime(want, "%H:%M:%S.%f").time() + assert s.replace(hour=0) == want