git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added barebone implementation of dbapi 2.0
author    Daniele Varrazzo <daniele.varrazzo@gmail.com>
          Sat, 11 Apr 2020 04:54:47 +0000 (16:54 +1200)
committer Daniele Varrazzo <daniele.varrazzo@gmail.com>
          Sat, 11 Apr 2020 04:57:30 +0000 (16:57 +1200)
Several methods are not tested yet; the fetch methods have not yet been added
to the async cursor.

psycopg3/__init__.py
psycopg3/connection.py
psycopg3/cursor.py
psycopg3/dbapi20.py [new file with mode: 0644]
tests/dbapi20.py [new file with mode: 0644]
tests/test_psycopg3_dbapi20.py [new file with mode: 0644]
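
For orientation, a minimal sketch (not part of the commit) of how the DBAPI 2.0 surface introduced here might be exercised against a running server; the "dbname=test" DSN is a hypothetical placeholder and the expected outputs assume the synchronous Cursor added below.

    import psycopg3

    # module-level DBAPI attributes re-exported by __init__.py
    assert psycopg3.apilevel == "2.0"
    assert psycopg3.paramstyle == "pyformat"

    conn = psycopg3.connect("dbname=test")   # hypothetical DSN
    cur = conn.cursor()
    cur.execute("select 42 as answer, 'hello' as greeting")

    # description and rowcount come from the new BaseCursor properties
    print([col.name for col in cur.description])   # e.g. ['answer', 'greeting']
    print(cur.rowcount)                             # 1

    print(cur.fetchone())   # e.g. (42, 'hello')
    print(cur.fetchall())   # [] once the result set is exhausted

    conn.close()            # BaseConnection.close() added in this commit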

diff --git a/psycopg3/__init__.py b/psycopg3/__init__.py
index 291c32519f414c04bb5ac6202dcf7ecb7fbe5278..23488315c62ec9c257de7c49c64f0a93ea36ce27 100644 (file)
@@ -23,6 +23,9 @@ from .errors import (
 # register default adapters
 from . import types  # noqa
 
+from .dbapi20 import BINARY, DATETIME, NUMBER, ROWID, STRING
+from .dbapi20 import Binary, Date, DateFromTicks, Time, TimeFromTicks
+from .dbapi20 import Timestamp, TimestampFromTicks
 
 # DBAPI compliancy
 connect = Connection.connect
@@ -30,18 +33,12 @@ apilevel = "2.0"
 threadsafety = 2
 paramstyle = "pyformat"
 
-__all__ = [
-    "Warning",
-    "Error",
-    "InterfaceError",
-    "DatabaseError",
-    "DataError",
-    "OperationalError",
-    "IntegrityError",
-    "InternalError",
-    "ProgrammingError",
-    "NotSupportedError",
-    "AsyncConnection",
-    "Connection",
-    "connect",
-]
+__all__ = (
+    ["Warning", "Error", "InterfaceError", "DatabaseError", "DataError"]
+    + ["OperationalError", "IntegrityError", "InternalError"]
+    + ["ProgrammingError", "NotSupportedError"]
+    + ["AsyncConnection", "Connection", "connect"]
+    + ["BINARY", "DATETIME", "NUMBER", "ROWID", "STRING"]
+    + ["Binary", "Date", "DateFromTicks", "Time", "TimeFromTicks"]
+    + ["Timestamp", "TimestampFromTicks"]
+)
diff --git a/psycopg3/connection.py b/psycopg3/connection.py
index 9976fe014b2bd31a46e38026ea65277cfd75a2b8..1ce4f62b2fb42a8938928cbd8d34b5adcb1d3307 100644 (file)
@@ -35,6 +35,18 @@ class BaseConnection:
     allow different interfaces (sync/async).
     """
 
+    # DBAPI2 exposed exceptions
+    Warning = e.Warning
+    Error = e.Error
+    InterfaceError = e.InterfaceError
+    DatabaseError = e.DatabaseError
+    DataError = e.DataError
+    OperationalError = e.OperationalError
+    IntegrityError = e.IntegrityError
+    InternalError = e.InternalError
+    ProgrammingError = e.ProgrammingError
+    NotSupportedError = e.NotSupportedError
+
     def __init__(self, pgconn: pq.PGconn):
         self.pgconn = pgconn
         self.cursor_factory = cursor.BaseCursor
@@ -43,6 +55,9 @@ class BaseConnection:
         # name of the postgres encoding (in bytes)
         self._pgenc = b""
 
+    def close(self) -> None:
+        self.pgconn.finish()
+
     def cursor(
         self, name: Optional[str] = None, binary: bool = False
     ) -> cursor.BaseCursor:
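
A hedged sketch of what the connection.py changes make possible: the DBAPI exception classes become reachable from a connection instance (the optional DBAPI extension exercised by the test suite below), and close() finishes the underlying PGconn. The DSN is again a placeholder.

    import psycopg3

    conn = psycopg3.connect("dbname=test")   # placeholder DSN

    # optional DBAPI extension: exceptions exposed on the connection object
    assert conn.ProgrammingError is psycopg3.ProgrammingError
    assert issubclass(conn.IntegrityError, conn.DatabaseError)

    cur = conn.cursor()
    conn.close()   # calls pgconn.finish()

    # executing on a closed connection is rejected (see the cursor.py hunk below)
    try:
        cur.execute("select 1")
    except psycopg3.InterfaceError as ex:
        print(ex)   # cannot execute operations: the connection is closed
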
diff --git a/psycopg3/cursor.py b/psycopg3/cursor.py
index c6ae1fea2305298ea232505b475a210d09753d7f..d4309b7c2ea41136e25ea290ac66601b25c9108e 100644 (file)
@@ -4,23 +4,59 @@ psycopg3 cursor objects
 
 # Copyright (C) 2020 The Psycopg Team
 
+import codecs
+from operator import attrgetter
 from typing import Any, List, Mapping, Optional, Sequence, Tuple, TYPE_CHECKING
 
 from . import errors as e
-from .pq import ExecStatus, PGresult, Format
+from .pq import ConnStatus, ExecStatus, PGresult, Format
 from .utils.queries import query2pg, reorder_params
 from .utils.typing import Query, Params
 
 if TYPE_CHECKING:
-    from .connection import (
-        BaseConnection,
-        Connection,
-        AsyncConnection,
-        QueryGen,
-    )
+    from .connection import BaseConnection, Connection, AsyncConnection
+    from .connection import QueryGen
     from .adapt import DumpersMap, LoadersMap
 
 
+class Column(Sequence[Any]):
+    def __init__(
+        self, pgresult: PGresult, index: int, codec: codecs.CodecInfo
+    ):
+        self._pgresult = pgresult
+        self._index = index
+        self._codec = codec
+
+    _attrs = tuple(
+        map(
+            attrgetter,
+            """
+            name type_code display_size internal_size precision scale null_ok
+            """.split(),
+        )
+    )
+
+    def __len__(self) -> int:
+        return 7
+
+    def __getitem__(self, index: Any) -> Any:
+        return self._attrs[index](self)
+
+    @property
+    def name(self) -> str:
+        rv = self._pgresult.fname(self._index)
+        if rv is not None:
+            return self._codec.decode(rv)[0]
+        else:
+            raise e.InterfaceError(
+                f"no name available for column {self._index}"
+            )
+
+    @property
+    def type_code(self) -> int:
+        return self._pgresult.ftype(self._index)
+
+
 class BaseCursor:
     def __init__(self, conn: "BaseConnection", binary: bool = False):
         self.conn = conn
@@ -28,6 +64,7 @@ class BaseCursor:
         self.dumpers: DumpersMap = {}
         self.loaders: LoadersMap = {}
         self._reset()
+        self.arraysize = 1
 
     def _reset(self) -> None:
         from .adapt import Transformer
@@ -51,10 +88,44 @@ class BaseCursor:
                 for i in range(result.nfields)
             )
 
+    @property
+    def description(self) -> Optional[List[Column]]:
+        res = self.pgresult
+        if res is None or res.status != ExecStatus.TUPLES_OK:
+            return None
+        return [Column(res, i, self.conn.codec) for i in range(res.nfields)]
+
+    @property
+    def rowcount(self) -> int:
+        res = self.pgresult
+        if res is None or res.status != ExecStatus.TUPLES_OK:
+            return -1
+        else:
+            return res.ntuples
+
+    def setinputsizes(self, sizes: Sequence[Any]) -> None:
+        # no-op
+        pass
+
+    def setoutputsize(self, size: Any, column: Optional[int] = None) -> None:
+        # no-op
+        pass
+
     def _execute_send(
         self, query: Query, vars: Optional[Params]
     ) -> "QueryGen":
         # Implement part of execute() before waiting common to sync and async
+        if self.conn.pgconn.status != ConnStatus.OK:
+            if self.conn.pgconn.status == ConnStatus.BAD:
+                raise e.InterfaceError(
+                    "cannot execute operations: the connection is closed"
+                )
+            else:
+                raise e.InterfaceError(
+                    f"cannot execute operations: the connection is"
+                    f" in status {self.conn.pgconn.status}"
+                )
+
         self._reset()
 
         codec = self.conn.codec
@@ -134,7 +205,12 @@ class BaseCursor:
     def _load_row(self, n: int) -> Optional[Tuple[Any, ...]]:
         res = self.pgresult
         if res is None:
-            return None
+            raise e.ProgrammingError("no result available")
+        elif res.status != ExecStatus.TUPLES_OK:
+            raise e.ProgrammingError(
+                "the last operation didn't produce a result"
+            )
+
         if n >= res.ntuples:
             return None
 
@@ -158,12 +234,48 @@ class Cursor(BaseCursor):
             self._execute_results(results)
         return self
 
+    def executemany(
+        self, query: Query, vars_seq: Sequence[Params]
+    ) -> "Cursor":
+        with self.conn.lock:
+            # TODO: trivial implementation; use prepare
+            for vars in vars_seq:
+                gen = self._execute_send(query, vars)
+                results = self.conn.wait(gen)
+                self._execute_results(results)
+        return self
+
     def fetchone(self) -> Optional[Sequence[Any]]:
         rv = self._load_row(self._pos)
         if rv is not None:
             self._pos += 1
         return rv
 
+    def fetchmany(self, size: Optional[int] = None) -> List[Sequence[Any]]:
+        if size is None:
+            size = self.arraysize
+
+        rv: List[Sequence[Any]] = []
+        while len(rv) < size:
+            row = self._load_row(self._pos)
+            if row is None:
+                break
+            self._pos += 1
+            rv.append(row)
+
+        return rv
+
+    def fetchall(self) -> List[Sequence[Any]]:
+        rv: List[Sequence[Any]] = []
+        while 1:
+            row = self._load_row(self._pos)
+            if row is None:
+                break
+            self._pos += 1
+            rv.append(row)
+
+        return rv
+
 
 class AsyncCursor(BaseCursor):
     conn: "AsyncConnection"
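
A short sketch of how the new Column objects are meant to behave: each entry of cursor.description acts as the 7-item sequence DBAPI 2.0 mandates while also exposing name and type_code as attributes (only those two attributes exist so far); the connection details are placeholders.

    import psycopg3

    conn = psycopg3.connect("dbname=test")   # placeholder DSN
    cur = conn.cursor()
    cur.execute("select 1 as num, 'x' as label")

    col = cur.description[0]
    assert len(col) == 7                 # DBAPI-mandated 7-element shape
    assert col[0] == col.name == "num"   # index 0 maps to the name attribute
    assert col[1] == col.type_code       # index 1 maps to type_code
    print(col.type_code)                 # 23, the int4 OID

    conn.close()
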
diff --git a/psycopg3/dbapi20.py b/psycopg3/dbapi20.py
new file mode 100644 (file)
index 0000000..527a021
--- /dev/null
@@ -0,0 +1,90 @@
+"""
+Compatibility objects with DBAPI 2.0
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import time
+import datetime as dt
+from math import floor
+from typing import Any, Sequence, Tuple
+
+from .types.oids import builtins
+from .adapt import Dumper
+
+
+class DBAPITypeObject:
+    def __init__(self, name: str, type_names: Sequence[str]):
+        self.name = name
+        self.values = tuple(builtins[n].oid for n in type_names)
+
+    def __repr__(self) -> str:
+        return f"psycopg3.{self.name}"
+
+    def __eq__(self, other: Any) -> bool:
+        if isinstance(other, int):
+            return other in self.values
+        else:
+            return NotImplemented
+
+    def __ne__(self, other: Any) -> bool:
+        if isinstance(other, int):
+            return other not in self.values
+        else:
+            return NotImplemented
+
+
+BINARY = DBAPITypeObject("BINARY", ("bytea",))
+DATETIME = DBAPITypeObject(
+    "DATETIME", "timestamp timestamptz date time timetz interval".split()
+)
+NUMBER = DBAPITypeObject(
+    "NUMBER", "int2 int4 int8 float4 float8 numeric".split()
+)
+ROWID = DBAPITypeObject("ROWID", ("oid",))
+STRING = DBAPITypeObject("STRING", "text varchar bpchar".split())
+
+
+class Binary:
+    def __init__(self, obj: Any):
+        self.obj = obj
+
+
+@Dumper.text(Binary)
+def dump_Binary(obj: Binary) -> Tuple[bytes, int]:
+    rv = obj.obj
+    if not isinstance(rv, bytes):
+        rv = bytes(rv)
+
+    return rv, builtins["bytea"].oid
+
+
+def Date(year: int, month: int, day: int) -> dt.date:
+    return dt.date(year, month, day)
+
+
+def DateFromTicks(ticks: float) -> dt.date:
+    return TimestampFromTicks(ticks).date()
+
+
+def Time(hour: int, minute: int, second: int) -> dt.time:
+    return dt.time(hour, minute, second)
+
+
+def TimeFromTicks(ticks: float) -> dt.time:
+    return TimestampFromTicks(ticks).time()
+
+
+def Timestamp(
+    year: int, month: int, day: int, hour: int, minute: int, second: int
+) -> dt.datetime:
+    return dt.datetime(year, month, day, hour, minute, second)
+
+
+def TimestampFromTicks(ticks: float) -> dt.datetime:
+    secs = floor(ticks)
+    frac = ticks - secs
+    t = time.localtime(ticks)
+    tzinfo = dt.timezone(dt.timedelta(seconds=t.tm_gmtoff))
+    rv = dt.datetime(*t[:6], round(frac * 1_000_000), tzinfo=tzinfo)
+    return rv
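
An illustration of the comparison semantics implemented above: each DBAPITypeObject compares equal to any OID of its group, in either direction, and the constructor helpers return plain datetime objects (the ticks-based ones attach the local timezone). 23 and 25 are the standard PostgreSQL int4 and text OIDs.

    import psycopg3

    assert psycopg3.NUMBER == 23    # int4 is one of the NUMBER OIDs
    assert 25 == psycopg3.STRING    # reflected comparison works too
    assert psycopg3.BINARY != 25    # the text OID is not in the BINARY group

    print(psycopg3.Date(2020, 4, 11))                   # 2020-04-11
    print(psycopg3.Timestamp(2020, 4, 11, 4, 54, 47))   # 2020-04-11 04:54:47
    print(psycopg3.TimestampFromTicks(0))               # 1970-01-01 00:00:00, local tz
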
diff --git a/tests/dbapi20.py b/tests/dbapi20.py
new file mode 100644 (file)
index 0000000..86e42b5
--- /dev/null
@@ -0,0 +1,870 @@
+#!/usr/bin/env python
+# flake8: noqa
+# fmt: off
+''' Python DB API 2.0 driver compliance unit test suite.
+
+    This software is Public Domain and may be used without restrictions.
+
+ "Now we have booze and barflies entering the discussion, plus rumours of
+  DBAs on drugs... and I won't tell you what flashes through my mind each
+  time I read the subject line with 'Anal Compliance' in it.  All around
+  this is turning out to be a thoroughly unwholesome unit test."
+
+    -- Ian Bicking
+'''
+
+__rcs_id__  = '$Id: dbapi20.py,v 1.11 2005/01/02 02:41:01 zenzen Exp $'
+__version__ = '$Revision: 1.12 $'[11:-2]
+__author__ = 'Stuart Bishop <stuart@stuartbishop.net>'
+
+import unittest
+import time
+import sys
+
+
+# Revision 1.12  2009/02/06 03:35:11  kf7xm
+# Tested okay with Python 3.0, includes last minute patches from Mark H.
+#
+# Revision 1.1.1.1.2.1  2008/09/20 19:54:59  rupole
+# Include latest changes from main branch
+# Updates for py3k
+#
+# Revision 1.11  2005/01/02 02:41:01  zenzen
+# Update author email address
+#
+# Revision 1.10  2003/10/09 03:14:14  zenzen
+# Add test for DB API 2.0 optional extension, where database exceptions
+# are exposed as attributes on the Connection object.
+#
+# Revision 1.9  2003/08/13 01:16:36  zenzen
+# Minor tweak from Stefan Fleiter
+#
+# Revision 1.8  2003/04/10 00:13:25  zenzen
+# Changes, as per suggestions by M.-A. Lemburg
+# - Add a table prefix, to ensure namespace collisions can always be avoided
+#
+# Revision 1.7  2003/02/26 23:33:37  zenzen
+# Break out DDL into helper functions, as per request by David Rushby
+#
+# Revision 1.6  2003/02/21 03:04:33  zenzen
+# Stuff from Henrik Ekelund:
+#     added test_None
+#     added test_nextset & hooks
+#
+# Revision 1.5  2003/02/17 22:08:43  zenzen
+# Implement suggestions and code from Henrik Eklund - test that cursor.arraysize
+# defaults to 1 & generic cursor.callproc test added
+#
+# Revision 1.4  2003/02/15 00:16:33  zenzen
+# Changes, as per suggestions and bug reports by M.-A. Lemburg,
+# Matthew T. Kromer, Federico Di Gregorio and Daniel Dittmar
+# - Class renamed
+# - Now a subclass of TestCase, to avoid requiring the driver stub
+#   to use multiple inheritance
+# - Reversed the polarity of buggy test in test_description
+# - Test exception hierarchy correctly
+# - self.populate is now self._populate(), so if a driver stub
+#   overrides self.ddl1 this change propagates
+# - VARCHAR columns now have a width, which will hopefully make the
+#   DDL even more portable (this will be reversed if it causes more problems)
+# - cursor.rowcount being checked after various execute and fetchXXX methods
+# - Check for fetchall and fetchmany returning empty lists after results
+#   are exhausted (already checking for empty lists if select retrieved
+#   nothing)
+# - Fix bugs in test_setoutputsize_basic and test_setinputsizes
+#
+
+class DatabaseAPI20Test(unittest.TestCase):
+    ''' Test a database self.driver for DB API 2.0 compatibility.
+        This implementation tests Gadfly, but the TestCase
+        is structured so that other self.drivers can subclass this
+        test case to ensure compliance with the DB-API. It is
+        expected that this TestCase may be expanded in the future
+        if ambiguities or edge conditions are discovered.
+
+        The 'Optional Extensions' are not yet being tested.
+
+        self.drivers should subclass this test, overriding setUp, tearDown,
+        self.driver, connect_args and connect_kw_args. Class specification
+        should be as follows:
+
+        from . import dbapi20
+        class mytest(dbapi20.DatabaseAPI20Test):
+           [...]
+
+        Don't 'from .dbapi20 import DatabaseAPI20Test', or you will
+        confuse the unit tester - just 'from . import dbapi20'.
+    '''
+
+    # The self.driver module. This should be the module where the 'connect'
+    # method is to be found
+    driver = None
+    connect_args = () # List of arguments to pass to connect
+    connect_kw_args = {} # Keyword arguments for connect
+    table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables
+
+    ddl1 = 'create table %sbooze (name varchar(20))' % table_prefix
+    ddl2 = 'create table %sbarflys (name varchar(20))' % table_prefix
+    xddl1 = 'drop table %sbooze' % table_prefix
+    xddl2 = 'drop table %sbarflys' % table_prefix
+
+    lowerfunc = 'lower' # Name of stored procedure to convert string->lowercase
+
+    # Some drivers may need to override these helpers, for example adding
+    # a 'commit' after the execute.
+    def executeDDL1(self,cursor):
+        cursor.execute(self.ddl1)
+
+    def executeDDL2(self,cursor):
+        cursor.execute(self.ddl2)
+
+    def setUp(self):
+        ''' self.drivers should override this method to perform required setup
+            if any is necessary, such as creating the database.
+        '''
+        pass
+
+    def tearDown(self):
+        ''' self.drivers should override this method to perform required cleanup
+            if any is necessary, such as deleting the test database.
+            The default drops the tables that may be created.
+        '''
+        con = self._connect()
+        try:
+            cur = con.cursor()
+            for ddl in (self.xddl1,self.xddl2):
+                try:
+                    cur.execute(ddl)
+                    con.commit()
+                except self.driver.Error:
+                    # Assume table didn't exist. Other tests will check if
+                    # execute is busted.
+                    pass
+        finally:
+            con.close()
+
+    def _connect(self):
+        try:
+            return self.driver.connect(
+                *self.connect_args,**self.connect_kw_args
+                )
+        except AttributeError:
+            self.fail("No connect method found in self.driver module")
+
+    def test_connect(self):
+        con = self._connect()
+        con.close()
+
+    def test_apilevel(self):
+        try:
+            # Must exist
+            apilevel = self.driver.apilevel
+            # Must equal 2.0
+            self.assertEqual(apilevel,'2.0')
+        except AttributeError:
+            self.fail("Driver doesn't define apilevel")
+
+    def test_threadsafety(self):
+        try:
+            # Must exist
+            threadsafety = self.driver.threadsafety
+            # Must be a valid value
+            self.failUnless(threadsafety in (0,1,2,3))
+        except AttributeError:
+            self.fail("Driver doesn't define threadsafety")
+
+    def test_paramstyle(self):
+        try:
+            # Must exist
+            paramstyle = self.driver.paramstyle
+            # Must be a valid value
+            self.failUnless(paramstyle in (
+                'qmark','numeric','named','format','pyformat'
+                ))
+        except AttributeError:
+            self.fail("Driver doesn't define paramstyle")
+
+    def test_Exceptions(self):
+        # Make sure required exceptions exist, and are in the
+        # defined hierarchy.
+        if sys.version[0] == '3': # under Python 3 StandardError no longer exists
+            self.failUnless(issubclass(self.driver.Warning,Exception))
+            self.failUnless(issubclass(self.driver.Error,Exception))
+        else:
+            self.failUnless(issubclass(self.driver.Warning,StandardError))
+            self.failUnless(issubclass(self.driver.Error,StandardError))
+
+        self.failUnless(
+            issubclass(self.driver.InterfaceError,self.driver.Error)
+            )
+        self.failUnless(
+            issubclass(self.driver.DatabaseError,self.driver.Error)
+            )
+        self.failUnless(
+            issubclass(self.driver.OperationalError,self.driver.Error)
+            )
+        self.failUnless(
+            issubclass(self.driver.IntegrityError,self.driver.Error)
+            )
+        self.failUnless(
+            issubclass(self.driver.InternalError,self.driver.Error)
+            )
+        self.failUnless(
+            issubclass(self.driver.ProgrammingError,self.driver.Error)
+            )
+        self.failUnless(
+            issubclass(self.driver.NotSupportedError,self.driver.Error)
+            )
+
+    def test_ExceptionsAsConnectionAttributes(self):
+        # OPTIONAL EXTENSION
+        # Test for the optional DB API 2.0 extension, where the exceptions
+        # are exposed as attributes on the Connection object
+        # I figure this optional extension will be implemented by any
+        # driver author who is using this test suite, so it is enabled
+        # by default.
+        con = self._connect()
+        drv = self.driver
+        self.failUnless(con.Warning is drv.Warning)
+        self.failUnless(con.Error is drv.Error)
+        self.failUnless(con.InterfaceError is drv.InterfaceError)
+        self.failUnless(con.DatabaseError is drv.DatabaseError)
+        self.failUnless(con.OperationalError is drv.OperationalError)
+        self.failUnless(con.IntegrityError is drv.IntegrityError)
+        self.failUnless(con.InternalError is drv.InternalError)
+        self.failUnless(con.ProgrammingError is drv.ProgrammingError)
+        self.failUnless(con.NotSupportedError is drv.NotSupportedError)
+
+
+    def test_commit(self):
+        con = self._connect()
+        try:
+            # Commit must work, even if it doesn't do anything
+            con.commit()
+        finally:
+            con.close()
+
+    def test_rollback(self):
+        con = self._connect()
+        # If rollback is defined, it should either work or throw
+        # the documented exception
+        if hasattr(con,'rollback'):
+            try:
+                con.rollback()
+            except self.driver.NotSupportedError:
+                pass
+
+    def test_cursor(self):
+        con = self._connect()
+        try:
+            cur = con.cursor()
+        finally:
+            con.close()
+
+    def test_cursor_isolation(self):
+        con = self._connect()
+        try:
+            # Make sure cursors created from the same connection have
+            # the documented transaction isolation level
+            cur1 = con.cursor()
+            cur2 = con.cursor()
+            self.executeDDL1(cur1)
+            cur1.execute("insert into %sbooze values ('Victoria Bitter')" % (
+                self.table_prefix
+                ))
+            cur2.execute("select name from %sbooze" % self.table_prefix)
+            booze = cur2.fetchall()
+            self.assertEqual(len(booze),1)
+            self.assertEqual(len(booze[0]),1)
+            self.assertEqual(booze[0][0],'Victoria Bitter')
+        finally:
+            con.close()
+
+    def test_description(self):
+        con = self._connect()
+        try:
+            cur = con.cursor()
+            self.executeDDL1(cur)
+            self.assertEqual(cur.description,None,
+                'cursor.description should be none after executing a '
+                'statement that can return no rows (such as DDL)'
+                )
+            cur.execute('select name from %sbooze' % self.table_prefix)
+            self.assertEqual(len(cur.description),1,
+                'cursor.description describes too many columns'
+                )
+            self.assertEqual(len(cur.description[0]),7,
+                'cursor.description[x] tuples must have 7 elements'
+                )
+            self.assertEqual(cur.description[0][0].lower(),'name',
+                'cursor.description[x][0] must return column name'
+                )
+            self.assertEqual(cur.description[0][1],self.driver.STRING,
+                'cursor.description[x][1] must return column type. Got %r'
+                    % cur.description[0][1]
+                )
+
+            # Make sure self.description gets reset
+            self.executeDDL2(cur)
+            self.assertEqual(cur.description,None,
+                'cursor.description not being set to None when executing '
+                'no-result statements (eg. DDL)'
+                )
+        finally:
+            con.close()
+
+    def test_rowcount(self):
+        con = self._connect()
+        try:
+            cur = con.cursor()
+            self.executeDDL1(cur)
+            self.assertEqual(cur.rowcount,-1,
+                'cursor.rowcount should be -1 after executing no-result '
+                'statements'
+                )
+            cur.execute("insert into %sbooze values ('Victoria Bitter')" % (
+                self.table_prefix
+                ))
+            self.failUnless(cur.rowcount in (-1,1),
+                'cursor.rowcount should == number of rows inserted, or '
+                'set to -1 after executing an insert statement'
+                )
+            cur.execute("select name from %sbooze" % self.table_prefix)
+            self.failUnless(cur.rowcount in (-1,1),
+                'cursor.rowcount should == number of rows returned, or '
+                'set to -1 after executing a select statement'
+                )
+            self.executeDDL2(cur)
+            self.assertEqual(cur.rowcount,-1,
+                'cursor.rowcount not being reset to -1 after executing '
+                'no-result statements'
+                )
+        finally:
+            con.close()
+
+    lower_func = 'lower'
+    def test_callproc(self):
+        con = self._connect()
+        try:
+            cur = con.cursor()
+            if self.lower_func and hasattr(cur,'callproc'):
+                r = cur.callproc(self.lower_func,('FOO',))
+                self.assertEqual(len(r),1)
+                self.assertEqual(r[0],'FOO')
+                r = cur.fetchall()
+                self.assertEqual(len(r),1,'callproc produced no result set')
+                self.assertEqual(len(r[0]),1,
+                    'callproc produced invalid result set'
+                    )
+                self.assertEqual(r[0][0],'foo',
+                    'callproc produced invalid results'
+                    )
+        finally:
+            con.close()
+
+    def test_close(self):
+        con = self._connect()
+        try:
+            cur = con.cursor()
+        finally:
+            con.close()
+
+        # cursor.execute should raise an Error if called after connection
+        # closed
+        self.assertRaises(self.driver.Error,self.executeDDL1,cur)
+
+        # connection.commit should raise an Error if called after connection
+        # closed.
+        self.assertRaises(self.driver.Error,con.commit)
+
+        # connection.close should raise an Error if called more than once
+        # Issue discussed on DB-SIG: consensus seems that close() should not
+        # raise if called on closed objects. Issue reported back to Stuart.
+        # self.assertRaises(self.driver.Error,con.close)
+
+    def test_execute(self):
+        con = self._connect()
+        try:
+            cur = con.cursor()
+            self._paraminsert(cur)
+        finally:
+            con.close()
+
+    def _paraminsert(self,cur):
+        self.executeDDL1(cur)
+        cur.execute("insert into %sbooze values ('Victoria Bitter')" % (
+            self.table_prefix
+            ))
+        self.failUnless(cur.rowcount in (-1,1))
+
+        if self.driver.paramstyle == 'qmark':
+            cur.execute(
+                'insert into %sbooze values (?)' % self.table_prefix,
+                ("Cooper's",)
+                )
+        elif self.driver.paramstyle == 'numeric':
+            cur.execute(
+                'insert into %sbooze values (:1)' % self.table_prefix,
+                ("Cooper's",)
+                )
+        elif self.driver.paramstyle == 'named':
+            cur.execute(
+                'insert into %sbooze values (:beer)' % self.table_prefix,
+                {'beer':"Cooper's"}
+                )
+        elif self.driver.paramstyle == 'format':
+            cur.execute(
+                'insert into %sbooze values (%%s)' % self.table_prefix,
+                ("Cooper's",)
+                )
+        elif self.driver.paramstyle == 'pyformat':
+            cur.execute(
+                'insert into %sbooze values (%%(beer)s)' % self.table_prefix,
+                {'beer':"Cooper's"}
+                )
+        else:
+            self.fail('Invalid paramstyle')
+        self.failUnless(cur.rowcount in (-1,1))
+
+        cur.execute('select name from %sbooze' % self.table_prefix)
+        res = cur.fetchall()
+        self.assertEqual(len(res),2,'cursor.fetchall returned too few rows')
+        beers = [res[0][0],res[1][0]]
+        beers.sort()
+        self.assertEqual(beers[0],"Cooper's",
+            'cursor.fetchall retrieved incorrect data, or data inserted '
+            'incorrectly'
+            )
+        self.assertEqual(beers[1],"Victoria Bitter",
+            'cursor.fetchall retrieved incorrect data, or data inserted '
+            'incorrectly'
+            )
+
+    def test_executemany(self):
+        con = self._connect()
+        try:
+            cur = con.cursor()
+            self.executeDDL1(cur)
+            largs = [ ("Cooper's",) , ("Boag's",) ]
+            margs = [ {'beer': "Cooper's"}, {'beer': "Boag's"} ]
+            if self.driver.paramstyle == 'qmark':
+                cur.executemany(
+                    'insert into %sbooze values (?)' % self.table_prefix,
+                    largs
+                    )
+            elif self.driver.paramstyle == 'numeric':
+                cur.executemany(
+                    'insert into %sbooze values (:1)' % self.table_prefix,
+                    largs
+                    )
+            elif self.driver.paramstyle == 'named':
+                cur.executemany(
+                    'insert into %sbooze values (:beer)' % self.table_prefix,
+                    margs
+                    )
+            elif self.driver.paramstyle == 'format':
+                cur.executemany(
+                    'insert into %sbooze values (%%s)' % self.table_prefix,
+                    largs
+                    )
+            elif self.driver.paramstyle == 'pyformat':
+                cur.executemany(
+                    'insert into %sbooze values (%%(beer)s)' % (
+                        self.table_prefix
+                        ),
+                    margs
+                    )
+            else:
+                self.fail('Unknown paramstyle')
+            self.failUnless(cur.rowcount in (-1,2),
+                'insert using cursor.executemany set cursor.rowcount to '
+                'incorrect value %r' % cur.rowcount
+                )
+            cur.execute('select name from %sbooze' % self.table_prefix)
+            res = cur.fetchall()
+            self.assertEqual(len(res),2,
+                'cursor.fetchall retrieved incorrect number of rows'
+                )
+            beers = [res[0][0],res[1][0]]
+            beers.sort()
+            self.assertEqual(beers[0],"Boag's",'incorrect data retrieved')
+            self.assertEqual(beers[1],"Cooper's",'incorrect data retrieved')
+        finally:
+            con.close()
+
+    def test_fetchone(self):
+        con = self._connect()
+        try:
+            cur = con.cursor()
+
+            # cursor.fetchone should raise an Error if called before
+            # executing a select-type query
+            self.assertRaises(self.driver.Error,cur.fetchone)
+
+            # cursor.fetchone should raise an Error if called after
+            # executing a query that cannot return rows
+            self.executeDDL1(cur)
+            self.assertRaises(self.driver.Error,cur.fetchone)
+
+            cur.execute('select name from %sbooze' % self.table_prefix)
+            self.assertEqual(cur.fetchone(),None,
+                'cursor.fetchone should return None if a query retrieves '
+                'no rows'
+                )
+            self.failUnless(cur.rowcount in (-1,0))
+
+            # cursor.fetchone should raise an Error if called after
+            # executing a query that cannot return rows
+            cur.execute("insert into %sbooze values ('Victoria Bitter')" % (
+                self.table_prefix
+                ))
+            self.assertRaises(self.driver.Error,cur.fetchone)
+
+            cur.execute('select name from %sbooze' % self.table_prefix)
+            r = cur.fetchone()
+            self.assertEqual(len(r),1,
+                'cursor.fetchone should have retrieved a single row'
+                )
+            self.assertEqual(r[0],'Victoria Bitter',
+                'cursor.fetchone retrieved incorrect data'
+                )
+            self.assertEqual(cur.fetchone(),None,
+                'cursor.fetchone should return None if no more rows available'
+                )
+            self.failUnless(cur.rowcount in (-1,1))
+        finally:
+            con.close()
+
+    samples = [
+        'Carlton Cold',
+        'Carlton Draft',
+        'Mountain Goat',
+        'Redback',
+        'Victoria Bitter',
+        'XXXX'
+        ]
+
+    def _populate(self):
+        ''' Return a list of sql commands to setup the DB for the fetch
+            tests.
+        '''
+        populate = [
+            "insert into %sbooze values ('%s')" % (self.table_prefix,s)
+                for s in self.samples
+            ]
+        return populate
+
+    def test_fetchmany(self):
+        con = self._connect()
+        try:
+            cur = con.cursor()
+
+            # cursor.fetchmany should raise an Error if called without
+            # issuing a query
+            self.assertRaises(self.driver.Error,cur.fetchmany,4)
+
+            self.executeDDL1(cur)
+            for sql in self._populate():
+                cur.execute(sql)
+
+            cur.execute('select name from %sbooze' % self.table_prefix)
+            r = cur.fetchmany()
+            self.assertEqual(len(r),1,
+                'cursor.fetchmany retrieved incorrect number of rows, '
+                'default of arraysize is one.'
+                )
+            cur.arraysize=10
+            r = cur.fetchmany(3) # Should get 3 rows
+            self.assertEqual(len(r),3,
+                'cursor.fetchmany retrieved incorrect number of rows'
+                )
+            r = cur.fetchmany(4) # Should get 2 more
+            self.assertEqual(len(r),2,
+                'cursor.fetchmany retrieved incorrect number of rows'
+                )
+            r = cur.fetchmany(4) # Should be an empty sequence
+            self.assertEqual(len(r),0,
+                'cursor.fetchmany should return an empty sequence after '
+                'results are exhausted'
+            )
+            self.failUnless(cur.rowcount in (-1,6))
+
+            # Same as above, using cursor.arraysize
+            cur.arraysize=4
+            cur.execute('select name from %sbooze' % self.table_prefix)
+            r = cur.fetchmany() # Should get 4 rows
+            self.assertEqual(len(r),4,
+                'cursor.arraysize not being honoured by fetchmany'
+                )
+            r = cur.fetchmany() # Should get 2 more
+            self.assertEqual(len(r),2)
+            r = cur.fetchmany() # Should be an empty sequence
+            self.assertEqual(len(r),0)
+            self.failUnless(cur.rowcount in (-1,6))
+
+            cur.arraysize=6
+            cur.execute('select name from %sbooze' % self.table_prefix)
+            rows = cur.fetchmany() # Should get all rows
+            self.failUnless(cur.rowcount in (-1,6))
+            self.assertEqual(len(rows),6)
+            rows = [r[0] for r in rows]
+            rows.sort()
+
+            # Make sure we get the right data back out
+            for i in range(0,6):
+                self.assertEqual(rows[i],self.samples[i],
+                    'incorrect data retrieved by cursor.fetchmany'
+                    )
+
+            rows = cur.fetchmany() # Should return an empty list
+            self.assertEqual(len(rows),0,
+                'cursor.fetchmany should return an empty sequence if '
+                'called after the whole result set has been fetched'
+                )
+            self.failUnless(cur.rowcount in (-1,6))
+
+            self.executeDDL2(cur)
+            cur.execute('select name from %sbarflys' % self.table_prefix)
+            r = cur.fetchmany() # Should get empty sequence
+            self.assertEqual(len(r),0,
+                'cursor.fetchmany should return an empty sequence if '
+                'query retrieved no rows'
+                )
+            self.failUnless(cur.rowcount in (-1,0))
+
+        finally:
+            con.close()
+
+    def test_fetchall(self):
+        con = self._connect()
+        try:
+            cur = con.cursor()
+            # cursor.fetchall should raise an Error if called
+            # without executing a query that may return rows (such
+            # as a select)
+            self.assertRaises(self.driver.Error, cur.fetchall)
+
+            self.executeDDL1(cur)
+            for sql in self._populate():
+                cur.execute(sql)
+
+            # cursor.fetchall should raise an Error if called
+            # after executing a statement that cannot return rows
+            self.assertRaises(self.driver.Error,cur.fetchall)
+
+            cur.execute('select name from %sbooze' % self.table_prefix)
+            rows = cur.fetchall()
+            self.failUnless(cur.rowcount in (-1,len(self.samples)))
+            self.assertEqual(len(rows),len(self.samples),
+                'cursor.fetchall did not retrieve all rows'
+                )
+            rows = [r[0] for r in rows]
+            rows.sort()
+            for i in range(0,len(self.samples)):
+                self.assertEqual(rows[i],self.samples[i],
+                'cursor.fetchall retrieved incorrect rows'
+                )
+            rows = cur.fetchall()
+            self.assertEqual(
+                len(rows),0,
+                'cursor.fetchall should return an empty list if called '
+                'after the whole result set has been fetched'
+                )
+            self.failUnless(cur.rowcount in (-1,len(self.samples)))
+
+            self.executeDDL2(cur)
+            cur.execute('select name from %sbarflys' % self.table_prefix)
+            rows = cur.fetchall()
+            self.failUnless(cur.rowcount in (-1,0))
+            self.assertEqual(len(rows),0,
+                'cursor.fetchall should return an empty list if '
+                'a select query returns no rows'
+                )
+
+        finally:
+            con.close()
+
+    def test_mixedfetch(self):
+        con = self._connect()
+        try:
+            cur = con.cursor()
+            self.executeDDL1(cur)
+            for sql in self._populate():
+                cur.execute(sql)
+
+            cur.execute('select name from %sbooze' % self.table_prefix)
+            rows1  = cur.fetchone()
+            rows23 = cur.fetchmany(2)
+            rows4  = cur.fetchone()
+            rows56 = cur.fetchall()
+            self.failUnless(cur.rowcount in (-1,6))
+            self.assertEqual(len(rows23),2,
+                'fetchmany returned incorrect number of rows'
+                )
+            self.assertEqual(len(rows56),2,
+                'fetchall returned incorrect number of rows'
+                )
+
+            rows = [rows1[0]]
+            rows.extend([rows23[0][0],rows23[1][0]])
+            rows.append(rows4[0])
+            rows.extend([rows56[0][0],rows56[1][0]])
+            rows.sort()
+            for i in range(0,len(self.samples)):
+                self.assertEqual(rows[i],self.samples[i],
+                    'incorrect data retrieved or inserted'
+                    )
+        finally:
+            con.close()
+
+    def help_nextset_setUp(self,cur):
+        ''' Should create a procedure called deleteme
+            that returns two result sets, first the
+           number of rows in booze then "name from booze"
+        '''
+        raise NotImplementedError('Helper not implemented')
+        #sql="""
+        #    create procedure deleteme as
+        #    begin
+        #        select count(*) from booze
+        #        select name from booze
+        #    end
+        #"""
+        #cur.execute(sql)
+
+    def help_nextset_tearDown(self,cur):
+        'If cleaning up is needed after nextSetTest'
+        raise NotImplementedError('Helper not implemented')
+        #cur.execute("drop procedure deleteme")
+
+    def test_nextset(self):
+        con = self._connect()
+        try:
+            cur = con.cursor()
+            if not hasattr(cur,'nextset'):
+                return
+
+            try:
+                self.executeDDL1(cur)
+                sql=self._populate()
+                for sql in self._populate():
+                    cur.execute(sql)
+
+                self.help_nextset_setUp(cur)
+
+                cur.callproc('deleteme')
+                numberofrows=cur.fetchone()
+                assert numberofrows[0]== len(self.samples)
+                assert cur.nextset()
+                names=cur.fetchall()
+                assert len(names) == len(self.samples)
+                s=cur.nextset()
+                assert s is None, 'No more return sets, should return None'
+            finally:
+                self.help_nextset_tearDown(cur)
+
+        finally:
+            con.close()
+
+    def test_nextset(self):
+        raise NotImplementedError('Drivers need to override this test')
+
+    def test_arraysize(self):
+        # Not much here - rest of the tests for this are in test_fetchmany
+        con = self._connect()
+        try:
+            cur = con.cursor()
+            self.failUnless(hasattr(cur,'arraysize'),
+                'cursor.arraysize must be defined'
+                )
+        finally:
+            con.close()
+
+    def test_setinputsizes(self):
+        con = self._connect()
+        try:
+            cur = con.cursor()
+            cur.setinputsizes( (25,) )
+            self._paraminsert(cur) # Make sure cursor still works
+        finally:
+            con.close()
+
+    def test_setoutputsize_basic(self):
+        # Basic test is to make sure setoutputsize doesn't blow up
+        con = self._connect()
+        try:
+            cur = con.cursor()
+            cur.setoutputsize(1000)
+            cur.setoutputsize(2000,0)
+            self._paraminsert(cur) # Make sure the cursor still works
+        finally:
+            con.close()
+
+    def test_setoutputsize(self):
+        # Real test for setoutputsize is driver dependent
+        raise NotImplementedError('Driver needed to override this test')
+
+    def test_None(self):
+        con = self._connect()
+        try:
+            cur = con.cursor()
+            self.executeDDL1(cur)
+            cur.execute('insert into %sbooze values (NULL)' % self.table_prefix)
+            cur.execute('select name from %sbooze' % self.table_prefix)
+            r = cur.fetchall()
+            self.assertEqual(len(r),1)
+            self.assertEqual(len(r[0]),1)
+            self.assertEqual(r[0][0],None,'NULL value not returned as None')
+        finally:
+            con.close()
+
+    def test_Date(self):
+        d1 = self.driver.Date(2002,12,25)
+        d2 = self.driver.DateFromTicks(time.mktime((2002,12,25,0,0,0,0,0,0)))
+        # Can we assume this? API doesn't specify, but it seems implied
+        # self.assertEqual(str(d1),str(d2))
+
+    def test_Time(self):
+        t1 = self.driver.Time(13,45,30)
+        t2 = self.driver.TimeFromTicks(time.mktime((2001,1,1,13,45,30,0,0,0)))
+        # Can we assume this? API doesn't specify, but it seems implied
+        # self.assertEqual(str(t1),str(t2))
+
+    def test_Timestamp(self):
+        t1 = self.driver.Timestamp(2002,12,25,13,45,30)
+        t2 = self.driver.TimestampFromTicks(
+            time.mktime((2002,12,25,13,45,30,0,0,0))
+            )
+        # Can we assume this? API doesn't specify, but it seems implied
+        # self.assertEqual(str(t1),str(t2))
+
+    def test_Binary(self):
+        b = self.driver.Binary(b'Something')
+        b = self.driver.Binary(b'')
+
+    def test_STRING(self):
+        self.failUnless(hasattr(self.driver,'STRING'),
+            'module.STRING must be defined'
+            )
+
+    def test_BINARY(self):
+        self.failUnless(hasattr(self.driver,'BINARY'),
+            'module.BINARY must be defined.'
+            )
+
+    def test_NUMBER(self):
+        self.failUnless(hasattr(self.driver,'NUMBER'),
+            'module.NUMBER must be defined.'
+            )
+
+    def test_DATETIME(self):
+        self.failUnless(hasattr(self.driver,'DATETIME'),
+            'module.DATETIME must be defined.'
+            )
+
+    def test_ROWID(self):
+        self.failUnless(hasattr(self.driver,'ROWID'),
+            'module.ROWID must be defined.'
+            )
+# fmt: on
diff --git a/tests/test_psycopg3_dbapi20.py b/tests/test_psycopg3_dbapi20.py
new file mode 100644 (file)
index 0000000..fffd537
--- /dev/null
@@ -0,0 +1,103 @@
+import pytest
+import datetime as dt
+
+import psycopg3
+
+from . import dbapi20
+
+
+@pytest.fixture(scope="class")
+def with_dsn(request, dsn):
+    request.cls.connect_args = (dsn,)
+
+
+@pytest.mark.usefixtures("with_dsn")
+class Psycopg3Tests(dbapi20.DatabaseAPI20Test):
+    driver = psycopg3
+    # connect_args = () # set by the fixture
+    connect_kw_args = {}
+
+    def test_nextset(self):
+        # tested elsewhere
+        pass
+
+    def test_setoutputsize(self):
+        # no-op
+        pass
+
+
+# Shut up warnings
+Psycopg3Tests.failUnless = Psycopg3Tests.assertTrue
+
+
+@pytest.mark.parametrize(
+    "typename, singleton",
+    [
+        ("bytea", "BINARY"),
+        ("date", "DATETIME"),
+        ("timestamp without time zone", "DATETIME"),
+        ("timestamp with time zone", "DATETIME"),
+        ("time without time zone", "DATETIME"),
+        ("time with time zone", "DATETIME"),
+        ("interval", "DATETIME"),
+        ("integer", "NUMBER"),
+        ("smallint", "NUMBER"),
+        ("bigint", "NUMBER"),
+        ("real", "NUMBER"),
+        ("double precision", "NUMBER"),
+        ("numeric", "NUMBER"),
+        ("decimal", "NUMBER"),
+        ("oid", "ROWID"),
+        ("varchar", "STRING"),
+        ("char", "STRING"),
+        ("text", "STRING"),
+    ],
+)
+def test_singletons(conn, typename, singleton):
+    singleton = getattr(psycopg3, singleton)
+    cur = conn.cursor()
+    cur.execute(f"select null::{typename}")
+    oid = cur.description[0].type_code
+    assert singleton == oid
+    assert oid == singleton
+    assert singleton != oid + 10000
+    assert oid + 10000 != singleton
+
+
+@pytest.mark.parametrize(
+    "ticks, want",
+    [
+        (0, "1970-01-01T00:00:00.000000+0000"),
+        (1273173119.99992, "2010-05-06T14:11:59.999920-0500"),
+    ],
+)
+def test_timestamp_from_ticks(ticks, want):
+    s = psycopg3.TimestampFromTicks(ticks)
+    want = dt.datetime.strptime(want, "%Y-%m-%dT%H:%M:%S.%f%z")
+    assert s == want
+
+
+@pytest.mark.parametrize(
+    "ticks, want",
+    [
+        (0, "1970-01-01"),
+        # Returned date is local
+        (1273173119.99992, ["2010-05-06", "2010-05-07"]),
+    ],
+)
+def test_date_from_ticks(ticks, want):
+    s = psycopg3.DateFromTicks(ticks)
+    if isinstance(want, str):
+        want = [want]
+    want = [dt.datetime.strptime(w, "%Y-%m-%d").date() for w in want]
+    assert s in want
+
+
+@pytest.mark.parametrize(
+    "ticks, want",
+    [(0, "00:00:00.000000"), (1273173119.99992, "00:11:59.999920")],
+)
+def test_time_from_ticks(ticks, want):
+    s = psycopg3.TimeFromTicks(ticks)
+    want = dt.datetime.strptime(want, "%H:%M:%S.%f").time()
+    assert s.replace(hour=0) == want
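
For reference, a small worked check of the ticks arithmetic that TimestampFromTicks performs, which is where the ".999920" microsecond values expected above come from; unlike the date and time checks, it does not depend on the local timezone.

    from math import floor

    ticks = 1273173119.99992
    secs = floor(ticks)              # 1273173119
    frac = ticks - secs              # ~0.99992
    print(round(frac * 1_000_000))   # 999920 microseconds, as in the tests above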