]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Refactor netrc handling (#558)
authorTom Christie <tom@tomchristie.com>
Thu, 28 Nov 2019 12:31:15 +0000 (12:31 +0000)
committerGitHub <noreply@github.com>
Thu, 28 Nov 2019 12:31:15 +0000 (12:31 +0000)
* Refactor netrc handling

* Linting

* Import sorting

* Import sorting

httpx/client.py
httpx/utils.py
tests/test_utils.py

index cdcbc51f8e6c42ef5e4fd15667d9a1203e8cf327..30e2181e4dd41aa4268961c0177c2d4a4d3d12e3 100644 (file)
@@ -1,5 +1,4 @@
 import functools
-import netrc
 import typing
 from types import TracebackType
 
@@ -47,7 +46,7 @@ from .models import (
     URLTypes,
 )
 from .status_codes import codes
-from .utils import ElapsedTimer, get_environment_proxies, get_logger, get_netrc
+from .utils import ElapsedTimer, NetRCInfo, get_environment_proxies, get_logger
 
 logger = get_logger(__name__)
 
@@ -158,6 +157,7 @@ class Client:
         self.trust_env = trust_env
         self.dispatch = dispatch
         self.concurrency_backend = backend
+        self.netrc = NetRCInfo()
 
         if proxies is None and trust_env:
             proxies = typing.cast(ProxiesTypes, get_environment_proxies())
@@ -385,13 +385,10 @@ class Client:
             return auth(request)
 
         if trust_env:
-            netrc_info = self._get_netrc()
-            if netrc_info is not None:
-                netrc_login = netrc_info.authenticators(request.url.authority)
-                netrc_username, _, netrc_password = netrc_login or ("", None, None)
-                if netrc_password is not None:
-                    auth = BasicAuth(username=netrc_username, password=netrc_password)
-                    return auth(request)
+            credentials = self.netrc.get_credentials(request.url.authority)
+            if credentials is not None:
+                auth = BasicAuth(username=credentials[0], password=credentials[1])
+                return auth(request)
 
         return request
 
@@ -563,10 +560,6 @@ class Client:
 
         return response
 
-    @functools.lru_cache(1)
-    def _get_netrc(self) -> typing.Optional[netrc.netrc]:
-        return get_netrc()
-
     def _dispatcher_for_request(
         self, request: Request, proxies: typing.Dict[str, Dispatcher]
     ) -> Dispatcher:
index 6516af6b4ec5fc5b2a52fb742f462ef884a351eb..07277a34b532a44db6ec84d7acbf3746d04432be 100644 (file)
@@ -96,22 +96,33 @@ def guess_json_utf(data: bytes) -> typing.Optional[str]:
     return None
 
 
-NETRC_STATIC_FILES = (Path("~/.netrc"), Path("~/_netrc"))
+class NetRCInfo:
+    def __init__(self, files: typing.Optional[typing.List[str]] = None) -> None:
+        if files is None:
+            files = [os.getenv("NETRC", ""), "~/.netrc", "~/_netrc"]
+        self.netrc_files = files
 
-
-def get_netrc() -> typing.Optional[netrc.netrc]:
-    NETRC_FILES = (Path(os.getenv("NETRC", "")),) + NETRC_STATIC_FILES
-    netrc_path = None
-
-    for file_path in NETRC_FILES:
-        expanded_path = file_path.expanduser()
-        if expanded_path.is_file():
-            netrc_path = expanded_path
-            break
-
-    if netrc_path is None:
-        return None
-    return netrc.netrc(str(netrc_path))
+    @property
+    def netrc_info(self) -> typing.Optional[netrc.netrc]:
+        if not hasattr(self, "_netrc_info"):
+            self._netrc_info = None
+            for file_path in self.netrc_files:
+                expanded_path = Path(file_path).expanduser()
+                if expanded_path.is_file():
+                    self._netrc_info = netrc.netrc(str(expanded_path))
+                    break
+        return self._netrc_info
+
+    def get_credentials(
+        self, authority: str
+    ) -> typing.Optional[typing.Tuple[str, str]]:
+        if self.netrc_info is None:
+            return None
+
+        auth_info = self.netrc_info.authenticators(authority)
+        if auth_info is None or auth_info[2] is None:
+            return None
+        return (auth_info[0], auth_info[2])
 
 
 def get_ca_bundle_from_env() -> typing.Optional[str]:
@@ -182,7 +193,7 @@ TRACE_LOG_LEVEL = 5
 class Logger(logging.Logger):
     # Stub for type checkers.
     def trace(self, message: str, *args: typing.Any, **kwargs: typing.Any) -> None:
-        ...
+        ...  # pragma: nocover
 
 
 def get_logger(name: str) -> Logger:
index e1f4ff785b00f4879af644d900ea6243a8bd21b2..c1cf93fa8585bd7f37659320bf8c4d2c6e0f9538 100644 (file)
@@ -6,9 +6,9 @@ import pytest
 import httpx
 from httpx.utils import (
     ElapsedTimer,
+    NetRCInfo,
     get_ca_bundle_from_env,
     get_environment_proxies,
-    get_netrc,
     guess_json_utf,
     obfuscate_sensitive_headers,
     parse_header_links,
@@ -54,28 +54,17 @@ def test_guess_by_bom(encoding, expected):
 
 
 def test_bad_get_netrc_login():
-    os.environ["NETRC"] = "tests/.netrc"
-    assert str(get_netrc()) is not None
-
-    from httpx import utils
-
-    utils.NETRC_STATIC_FILES = ()
-
-    os.environ["NETRC"] = "wrongpath"
-    assert utils.get_netrc() is None
-
-    os.environ["NETRC"] = ""
-    assert utils.get_netrc() is None
+    netrc_info = NetRCInfo(["tests/does-not-exist"])
+    assert netrc_info.get_credentials("netrcexample.org") is None
 
 
 def test_get_netrc_login():
-    os.environ["NETRC"] = "tests/.netrc"
-    netrc = get_netrc()
-    assert netrc.authenticators("netrcexample.org") == (
+    netrc_info = NetRCInfo(["tests/.netrc"])
+    expected_credentials = (
         "example-username",
-        None,
         "example-password",
     )
+    assert netrc_info.get_credentials("netrcexample.org") == expected_credentials
 
 
 @pytest.mark.parametrize(