]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Drop `request.timer` attribute. (#1249)
authorTom Christie <tom@tomchristie.com>
Mon, 7 Sep 2020 08:06:14 +0000 (09:06 +0100)
committerGitHub <noreply@github.com>
Mon, 7 Sep 2020 08:06:14 +0000 (09:06 +0100)
* Drop request.timer attribute
* Response(..., elapsed_func=...)

httpx/_client.py
httpx/_models.py
httpx/_utils.py
tests/models/test_responses.py
tests/test_utils.py

index d6a0caf0855b5ab097ca0498234947122cc54b55..0b67a78dddf3f0847abdad05d724757efdcfacf6 100644 (file)
@@ -47,6 +47,7 @@ from ._types import (
 )
 from ._utils import (
     NetRCInfo,
+    Timer,
     URLPattern,
     get_environment_proxies,
     get_logger,
@@ -811,6 +812,8 @@ class Client(BaseClient):
         Sends a single request, without handling any redirections.
         """
         transport = self._transport_for_url(request.url)
+        timer = Timer()
+        timer.sync_start()
 
         with map_exceptions(HTTPCORE_EXC_MAP, request=request):
             (
@@ -832,6 +835,7 @@ class Client(BaseClient):
             headers=headers,
             stream=stream,  # type: ignore
             request=request,
+            elapsed_func=timer.sync_elapsed,
         )
 
         self.cookies.extract_cookies(response)
@@ -1434,6 +1438,8 @@ class AsyncClient(BaseClient):
         Sends a single request, without handling any redirections.
         """
         transport = self._transport_for_url(request.url)
+        timer = Timer()
+        await timer.async_start()
 
         with map_exceptions(HTTPCORE_EXC_MAP, request=request):
             (
@@ -1455,6 +1461,7 @@ class AsyncClient(BaseClient):
             headers=headers,
             stream=stream,  # type: ignore
             request=request,
+            elapsed_func=timer.async_elapsed,
         )
 
         self.cookies.extract_cookies(response)
index 5b6a9b65710d8e3692fdbf2eceedd07987357dee..713281e662de0682b5eddc5a5400598a0a7556d8 100644 (file)
@@ -47,7 +47,6 @@ from ._types import (
     URLTypes,
 )
 from ._utils import (
-    ElapsedTimer,
     flatten_queryparams,
     guess_json_utf,
     is_known_encoding,
@@ -606,7 +605,6 @@ class Request:
         else:
             self.stream = encode(data, files, json)
 
-        self.timer = ElapsedTimer()
         self.prepare()
 
     def prepare(self) -> None:
@@ -678,6 +676,7 @@ class Response:
         stream: ContentStream = None,
         content: bytes = None,
         history: typing.List["Response"] = None,
+        elapsed_func: typing.Callable = None,
     ):
         self.status_code = status_code
         self.http_version = http_version
@@ -688,6 +687,7 @@ class Response:
         self.call_next: typing.Optional[typing.Callable] = None
 
         self.history = [] if history is None else list(history)
+        self._elapsed_func = elapsed_func
 
         self.is_closed = False
         self.is_stream_consumed = False
@@ -708,7 +708,7 @@ class Response:
                 "'.elapsed' may only be accessed after the response "
                 "has been read or closed."
             )
-        return self._elapsed
+        return datetime.timedelta(seconds=self._elapsed)
 
     @property
     def request(self) -> Request:
@@ -976,8 +976,8 @@ class Response:
         """
         if not self.is_closed:
             self.is_closed = True
-            if self._request is not None:
-                self._elapsed = self.request.timer.elapsed
+            if self._elapsed_func is not None:
+                self._elapsed = self._elapsed_func()
             self._raw_stream.close()
 
     async def aread(self) -> bytes:
@@ -1056,8 +1056,8 @@ class Response:
         """
         if not self.is_closed:
             self.is_closed = True
-            if self._request is not None:
-                self._elapsed = self.request.timer.elapsed
+            if self._elapsed_func is not None:
+                self._elapsed = await self._elapsed_func()
             await self._raw_stream.aclose()
 
 
index 8080f63a4663587cfd24b219f9f82a08dcd3f562..aa670724cbadfffb166212816990e252b74ba7dc 100644 (file)
@@ -6,14 +6,14 @@ import netrc
 import os
 import re
 import sys
+import time
 import typing
 import warnings
-from datetime import timedelta
 from pathlib import Path
-from time import perf_counter
-from types import TracebackType
 from urllib.request import getproxies
 
+import sniffio
+
 from ._types import PrimitiveData
 
 if typing.TYPE_CHECKING:  # pragma: no cover
@@ -392,28 +392,35 @@ def flatten_queryparams(
     return items
 
 
-class ElapsedTimer:
-    def __init__(self) -> None:
-        self.start: float = perf_counter()
-        self.end: typing.Optional[float] = None
+class Timer:
+    async def _get_time(self) -> float:
+        library = sniffio.current_async_library()
+        if library == "trio":
+            import trio
 
-    def __enter__(self) -> "ElapsedTimer":
-        self.start = perf_counter()
-        return self
+            return trio.current_time()
+        elif library == "curio":  # pragma: nocover
+            import curio
 
-    def __exit__(
-        self,
-        exc_type: typing.Type[BaseException] = None,
-        exc_value: BaseException = None,
-        traceback: TracebackType = None,
-    ) -> None:
-        self.end = perf_counter()
+            return await curio.clock()
 
-    @property
-    def elapsed(self) -> timedelta:
-        if self.end is None:
-            return timedelta(seconds=perf_counter() - self.start)
-        return timedelta(seconds=self.end - self.start)
+        import asyncio
+
+        return asyncio.get_event_loop().time()
+
+    def sync_start(self) -> None:
+        self.started = time.perf_counter()
+
+    async def async_start(self) -> None:
+        self.started = await self._get_time()
+
+    def sync_elapsed(self) -> float:
+        now = time.perf_counter()
+        return now - self.started
+
+    async def async_elapsed(self) -> float:
+        now = await self._get_time()
+        return now - self.started
 
 
 class URLPattern:
index 32163a6fc830e99eec608ce5006e01cb16f60540..2b07a2704025ce70679cb409fb048a98b7856607 100644 (file)
@@ -1,4 +1,3 @@
-import datetime
 import json
 from unittest import mock
 
@@ -31,7 +30,6 @@ def test_response():
     assert response.text == "Hello, world!"
     assert response.request.method == "GET"
     assert response.request.url == "https://example.org"
-    assert response.elapsed >= datetime.timedelta(0)
     assert not response.is_error
 
 
index d5dfb5819b6017e6d1b13ac5822aa28460b121b8..ae4b3aa96c060fb3656f8845e4d0dae90c72512f 100644 (file)
@@ -1,4 +1,3 @@
-import asyncio
 import os
 import random
 
@@ -6,7 +5,6 @@ import pytest
 
 import httpx
 from httpx._utils import (
-    ElapsedTimer,
     NetRCInfo,
     URLPattern,
     get_ca_bundle_from_env,
@@ -177,17 +175,6 @@ def test_get_ssl_cert_file():
     assert get_ca_bundle_from_env() is None
 
 
-@pytest.mark.asyncio
-async def test_elapsed_timer():
-    with ElapsedTimer() as timer:
-        assert timer.elapsed.total_seconds() == pytest.approx(0, abs=0.05)
-        await asyncio.sleep(0.1)
-    await asyncio.sleep(
-        0.1
-    )  # test to ensure time spent after timer exits isn't accounted for.
-    assert timer.elapsed.total_seconds() == pytest.approx(0.1, abs=0.05)
-
-
 @pytest.mark.parametrize(
     ["environment", "proxies"],
     [