]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add compat module for python versions compatibility
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 6 Mar 2021 04:02:29 +0000 (05:02 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 12 Mar 2021 04:07:25 +0000 (05:07 +0100)
psycopg3/psycopg3/connection.py
psycopg3/psycopg3/copy.py
psycopg3/psycopg3/cursor.py
psycopg3/psycopg3/pool/async_pool.py
psycopg3/psycopg3/utils/compat.py [new file with mode: 0644]
tests/pool/test_pool_async.py
tests/test_concurrency_async.py

index c0cd4d555dd69ec80f423a43236280eb8123e573..7cdb90f81d17f9d7b633e9ecf3b1bafd263a829c 100644 (file)
@@ -4,7 +4,6 @@ psycopg3 connection objects
 
 # Copyright (C) 2020-2021 The Psycopg Team
 
-import sys
 import asyncio
 import logging
 import warnings
@@ -16,11 +15,6 @@ from weakref import ref, ReferenceType
 from functools import partial
 from contextlib import contextmanager
 
-if sys.version_info >= (3, 7):
-    from contextlib import asynccontextmanager
-else:
-    from .utils.context import asynccontextmanager
-
 from . import pq
 from . import adapt
 from . import errors as e
@@ -34,9 +28,10 @@ from .proto import AdaptContext, ConnectionType
 from .cursor import Cursor, AsyncCursor
 from .conninfo import make_conninfo
 from .generators import notifies
+from ._preparing import PrepareManager
 from .transaction import Transaction, AsyncTransaction
+from .utils.compat import asynccontextmanager
 from .server_cursor import ServerCursor, AsyncServerCursor
-from ._preparing import PrepareManager
 
 logger = logging.getLogger(__name__)
 package_logger = logging.getLogger("psycopg3")
index 9fe7846a394dc4df0129a0a8dd9fa11a61df25fa..28e4be484af1f0a964bc51901f8c7d63eb98328c 100644 (file)
@@ -20,6 +20,7 @@ from .pq import ExecStatus
 from .adapt import Format
 from .proto import ConnectionType, PQGen, Transformer
 from .generators import copy_from, copy_to, copy_end
+from .utils.compat import create_task
 
 if TYPE_CHECKING:
     from .pq.proto import PGresult
@@ -359,8 +360,7 @@ class AsyncCopy(BaseCopy["AsyncConnection"]):
             return
 
         if not self._worker:
-            # TODO: can be asyncio.create_task once Python 3.6 is dropped
-            self._worker = asyncio.ensure_future(self.worker())
+            self._worker = create_task(self.worker())
 
         await self._queue.put(data)
 
index e33640829a4cd97fbdcfc8ecb49ac9e622926d39..a6897fc9850cd9a6c9863757fbb488fd2a1478b8 100644 (file)
@@ -23,11 +23,7 @@ from .proto import Row, RowFactory
 from ._column import Column
 from ._queries import PostgresQuery
 from ._preparing import Prepare
-
-if sys.version_info >= (3, 7):
-    from contextlib import asynccontextmanager
-else:
-    from .utils.context import asynccontextmanager
+from .utils.compat import asynccontextmanager
 
 if TYPE_CHECKING:
     from .proto import Transformer
index c3494dd4154230985208b432fd2ce5f32063141f..7adcd888d4eed43d76c02e6a8bc879c32a7ee1d0 100644 (file)
@@ -4,7 +4,6 @@ psycopg3 synchronous connection pool
 
 # Copyright (C) 2021 The Psycopg Team
 
-import sys
 import asyncio
 import logging
 from time import monotonic
@@ -15,21 +14,12 @@ from collections import deque
 
 from ..pq import TransactionStatus
 from ..connection import AsyncConnection
+from ..utils.compat import asynccontextmanager, get_running_loop
 
 from . import tasks
 from .base import ConnectionAttempt, BasePool
 from .errors import PoolClosed, PoolTimeout
 
-if sys.version_info >= (3, 7):
-    from contextlib import asynccontextmanager
-
-    get_running_loop = asyncio.get_running_loop
-
-else:
-    from ..utils.context import asynccontextmanager
-
-    get_running_loop = asyncio.get_event_loop
-
 logger = logging.getLogger(__name__)
 
 
diff --git a/psycopg3/psycopg3/utils/compat.py b/psycopg3/psycopg3/utils/compat.py
new file mode 100644 (file)
index 0000000..57ae8d2
--- /dev/null
@@ -0,0 +1,44 @@
+"""
+compatibility functions for different Python versions
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+import sys
+import asyncio
+from typing import Any, Awaitable, Generator, Optional, Union, TypeVar
+
+T = TypeVar("T")
+FutureT = Union["asyncio.Future[T]", Generator[Any, None, T], Awaitable[T]]
+
+if sys.version_info >= (3, 7):
+    from contextlib import asynccontextmanager
+
+    get_running_loop = asyncio.get_running_loop
+
+else:
+    from .context import asynccontextmanager
+
+    get_running_loop = asyncio.get_event_loop
+
+
+if sys.version_info >= (3, 8):
+    create_task = asyncio.create_task
+
+elif sys.version_info >= (3, 7):
+
+    def create_task(
+        coro: FutureT[T], name: Optional[str] = None
+    ) -> "asyncio.Future[T]":
+        return asyncio.create_task(coro)
+
+
+else:
+
+    def create_task(
+        coro: FutureT[T], name: Optional[str] = None
+    ) -> "asyncio.Future[T]":
+        return asyncio.ensure_future(coro)
+
+
+__all__ = ["asynccontextmanager", "get_running_loop", "create_task"]
index b0c59db7fcab7a0b6fc71222e2d907025c16c4b9..3fad68414f22b67d5788862cd3b572fe4460abd7 100644 (file)
@@ -1,4 +1,3 @@
-import sys
 import asyncio
 import logging
 import weakref
@@ -10,12 +9,7 @@ import pytest
 import psycopg3
 from psycopg3 import pool
 from psycopg3.pq import TransactionStatus
-
-create_task = (
-    asyncio.create_task
-    if sys.version_info >= (3, 7)
-    else asyncio.ensure_future
-)
+from psycopg3.utils.compat import create_task
 
 pytestmark = pytest.mark.asyncio
 
index c6fbe4c9010a905688c70165a7be123b386bd151..c7fc960bb29d118dc873d2ccf1735ffda27b42b5 100644 (file)
@@ -4,6 +4,7 @@ import asyncio
 from asyncio.queues import Queue
 
 import psycopg3
+from psycopg3.utils.compat import create_task
 
 pytestmark = pytest.mark.asyncio
 
@@ -144,7 +145,7 @@ async def test_identify_closure(aconn, dsn):
     ev = asyncio.Event()
     loop = asyncio.get_event_loop()
     loop.add_reader(aconn.fileno(), ev.set)
-    asyncio.ensure_future(closer())
+    create_task(closer())
 
     await asyncio.wait_for(ev.wait(), 1.0)
     with pytest.raises(psycopg3.OperationalError):