]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Wrapped lipq method for async connection
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 13 Mar 2020 11:05:53 +0000 (00:05 +1300)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 13 Mar 2020 11:05:53 +0000 (00:05 +1300)
psycopg3/_pq_ctypes.py
psycopg3/pq_ctypes.py
psycopg3/pq_enums.py
tests/test_pq.py

index bfa0ad97454aaedfab7bfd31deddc404465da876..d6e748db7f1a9ac5accddc2e8aca548281c0a3e3 100644 (file)
@@ -12,16 +12,42 @@ from ctypes import c_char_p, c_int
 pq = ctypes.pydll.LoadLibrary(ctypes.util.find_library("pq"))
 
 
+# libpq data types
+
+
 class PGconn(Structure):
     _fields_ = []
 
 
 PGconn_ptr = POINTER(PGconn)
 
+
+# Function definitions as explained in PostgreSQL 12 documentation
+
+# 33.1. Database Connection Control Functions
+
+# PQconnectdbParams: doesn't seem useful, won't wrap for now
+
 PQconnectdb = pq.PQconnectdb
 PQconnectdb.argtypes = [c_char_p]
 PQconnectdb.restype = PGconn_ptr
 
+# PQsetdbLogin: not useful
+# PQsetdb: not useful
+
+# PQconnectStartParams: not useful
+
+PQconnectStart = pq.PQconnectStart
+PQconnectStart.argtypes = [c_char_p]
+PQconnectStart.restype = PGconn_ptr
+
+PQconnectPoll = pq.PQconnectPoll
+PQconnectPoll.argtypes = [PGconn_ptr]
+PQconnectPoll.restype = c_int
+
+
+# 33.2. Connection Status Functions
+
 PQstatus = pq.PQstatus
 PQstatus.argtypes = [PGconn_ptr]
 PQstatus.restype = c_int
@@ -29,3 +55,7 @@ PQstatus.restype = c_int
 PQerrorMessage = pq.PQerrorMessage
 PQerrorMessage.argtypes = [PGconn_ptr]
 PQerrorMessage.restype = c_char_p
+
+PQsocket = pq.PQsocket
+PQsocket.argtypes = [PGconn_ptr]
+PQsocket.restype = c_int
index 9322b8ae66a87b2904f2be4112bc1a32d72d9358..db1abbfac841fc66094a09c7cacc15f4b571ea5b 100644 (file)
@@ -8,7 +8,7 @@ implementation.
 
 # Copyright (C) 2020 The Psycopg Team
 
-from .pq_enums import ConnStatus
+from .pq_enums import ConnStatus, PostgresPollingStatus
 
 from . import _pq_ctypes as impl
 
@@ -30,6 +30,21 @@ class PGconn:
         pgconn_ptr = impl.PQconnectdb(conninfo)
         return cls(pgconn_ptr)
 
+    @classmethod
+    def connect_start(cls, conninfo):
+        if isinstance(conninfo, str):
+            conninfo = conninfo.encode("utf8")
+
+        if not isinstance(conninfo, bytes):
+            raise TypeError("bytes expected, got %r instead" % conninfo)
+
+        pgconn_ptr = impl.PQconnectStart(conninfo)
+        return cls(pgconn_ptr)
+
+    def connect_poll(self):
+        rv = impl.PQconnectPoll(self.pgconn_ptr)
+        return PostgresPollingStatus(rv)
+
     @property
     def status(self):
         rv = impl.PQstatus(self.pgconn_ptr)
@@ -39,3 +54,7 @@ class PGconn:
     def error_message(self):
         # TODO: decode
         return impl.PQerrorMessage(self.pgconn_ptr)
+
+    @property
+    def socket(self):
+        return impl.PQsocket(self.pgconn_ptr)
index a1b0dfa39797265202fe4522bd5b8e09c5397ddf..1eaa262ea7c3ccb4ec99dcf217711670d4d9ebba 100644 (file)
@@ -4,9 +4,29 @@ libpq enum definitions for psycopg3
 
 # Copyright (C) 2020 The Psycopg Team
 
-from enum import IntEnum
+from enum import IntEnum, auto
 
 
 class ConnStatus(IntEnum):
     CONNECTION_OK = 0
-    CONNECTION_BAD = 1
+    CONNECTION_BAD = auto()
+
+    CONNECTION_STARTED = auto()
+    CONNECTION_MADE = auto()
+    CONNECTION_AWAITING_RESPONSE = auto()
+    CONNECTION_AUTH_OK = auto()
+    CONNECTION_SETENV = auto()
+    CONNECTION_SSL_STARTUP = auto()
+    CONNECTION_NEEDED = auto()
+    CONNECTION_CHECK_WRITABLE = auto()
+    CONNECTION_CONSUME = auto()
+    CONNECTION_GSS_STARTUP = auto()
+    CONNECTION_CHECK_TARGET = auto()
+
+
+class PostgresPollingStatus(IntEnum):
+    PGRES_POLLING_FAILED = 0
+    PGRES_POLLING_READING = auto()
+    PGRES_POLLING_WRITING = auto()
+    PGRES_POLLING_OK = auto()
+    PGRES_POLLING_ACTIVE = auto()
index 16a8babf2db04d928536f352db1f92f757202249..e6f3169a636ea0cd579a332d2140e0ae538fb0fe 100644 (file)
@@ -1,22 +1,60 @@
+from select import select
+
 import pytest
 
+from psycopg3.pq_enums import ConnStatus, PostgresPollingStatus
+
 
-def test_PQconnectdb(pq, dsn):
+def test_connectdb(pq, dsn):
     conn = pq.PGconn.connectdb(dsn)
-    assert conn.status == pq.ConnStatus.CONNECTION_OK, conn.error_message
+    assert conn.status == ConnStatus.CONNECTION_OK, conn.error_message
 
 
-def test_PQconnectdb_bytes(pq, dsn):
+def test_connectdb_bytes(pq, dsn):
     conn = pq.PGconn.connectdb(dsn.encode("utf8"))
-    assert conn.status == pq.ConnStatus.CONNECTION_OK, conn.error_message
+    assert conn.status == ConnStatus.CONNECTION_OK, conn.error_message
 
 
-def test_PQconnectdb_error(pq):
+def test_connectdb_error(pq):
     conn = pq.PGconn.connectdb("dbname=psycopg3_test_not_for_real")
-    assert conn.status == pq.ConnStatus.CONNECTION_BAD
+    assert conn.status == ConnStatus.CONNECTION_BAD
 
 
 @pytest.mark.parametrize("baddsn", [None, 42])
-def test_PQconnectdb_badtype(pq, baddsn):
+def test_connectdb_badtype(pq, baddsn):
     with pytest.raises(TypeError):
         pq.PGconn.connectdb(baddsn)
+
+
+def test_connect_async(pq, dsn):
+    conn = pq.PGconn.connect_start(dsn)
+    while 1:
+        assert conn.status != ConnStatus.CONNECTION_BAD
+        rv = conn.connect_poll()
+        if rv == PostgresPollingStatus.PGRES_POLLING_OK:
+            break
+        elif rv == PostgresPollingStatus.PGRES_POLLING_READING:
+            select([conn.socket], [], [])
+        elif rv == PostgresPollingStatus.PGRES_POLLING_WRITING:
+            select([], [conn.socket], [])
+        else:
+            assert False, rv
+
+    assert conn.status == ConnStatus.CONNECTION_OK
+
+
+def test_connect_async_bad(pq, dsn):
+    conn = pq.PGconn.connect_start("dbname=psycopg3_test_not_for_real")
+    while 1:
+        assert conn.status != ConnStatus.CONNECTION_BAD
+        rv = conn.connect_poll()
+        if rv == PostgresPollingStatus.PGRES_POLLING_FAILED:
+            break
+        elif rv == PostgresPollingStatus.PGRES_POLLING_READING:
+            select([conn.socket], [], [])
+        elif rv == PostgresPollingStatus.PGRES_POLLING_WRITING:
+            select([], [conn.socket], [])
+        else:
+            assert False, rv
+
+    assert conn.status == ConnStatus.CONNECTION_BAD