]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Tighten multipart implementation types (#975)
authorFlorimond Manca <florimond.manca@gmail.com>
Thu, 21 May 2020 15:41:36 +0000 (17:41 +0200)
committerGitHub <noreply@github.com>
Thu, 21 May 2020 15:41:36 +0000 (17:41 +0200)
httpx/_content_streams.py
setup.cfg

index 2484aa44a6a9c16b962f1aaa43bbc0acfbc33134..6436be563a0b26c00c8135e814b9cd8235bc7f29 100644 (file)
@@ -8,7 +8,7 @@ from urllib.parse import urlencode
 import httpcore
 
 from ._exceptions import StreamConsumed
-from ._types import RequestData, RequestFiles
+from ._types import FileContent, FileTypes, RequestData, RequestFiles
 from ._utils import (
     format_form_param,
     guess_content_type,
@@ -227,22 +227,25 @@ class MultipartStream(ContentStream):
         A single file field item, within a multipart form field.
         """
 
-        def __init__(
-            self,
-            name: str,
-            value: typing.Union[typing.IO[str], typing.IO[bytes], tuple],
-        ) -> None:
+        def __init__(self, name: str, value: FileTypes) -> None:
             self.name = name
-            if not isinstance(value, tuple):
-                self.filename = Path(str(getattr(value, "name", "upload"))).name
-                self.file: typing.Union[typing.IO[str], typing.IO[bytes]] = value
-                self.content_type = guess_content_type(self.filename)
+
+            fileobj: FileContent
+
+            if isinstance(value, tuple):
+                try:
+                    filename, fileobj, content_type = value  # type: ignore
+                except ValueError:
+                    filename, fileobj = value  # type: ignore
+                    content_type = guess_content_type(filename)
             else:
-                self.filename = value[0]
-                self.file = value[1]
-                self.content_type = (
-                    value[2] if len(value) > 2 else guess_content_type(self.filename)
-                )
+                filename = Path(str(getattr(value, "name", "upload"))).name
+                fileobj = value
+                content_type = guess_content_type(filename)
+
+            self.filename = filename
+            self.file = fileobj
+            self.content_type = content_type
 
         def get_length(self) -> int:
             headers = self.render_headers()
@@ -304,7 +307,7 @@ class MultipartStream(ContentStream):
             yield from self.render_data()
 
     def __init__(
-        self, data: typing.Mapping, files: typing.Mapping, boundary: bytes = None
+        self, data: typing.Mapping, files: RequestFiles, boundary: bytes = None
     ) -> None:
         if boundary is None:
             boundary = binascii.hexlify(os.urandom(16))
@@ -316,7 +319,7 @@ class MultipartStream(ContentStream):
         self.fields = list(self._iter_fields(data, files))
 
     def _iter_fields(
-        self, data: typing.Mapping, files: typing.Mapping
+        self, data: typing.Mapping, files: RequestFiles
     ) -> typing.Iterator[typing.Union["FileField", "DataField"]]:
         for name, value in data.items():
             if isinstance(value, list):
index 61e27214d86271809aed225d08f707503af211ec..cdc17a44f0aeccdcd8bb6df4cacd7a4df9d9dedb 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -11,7 +11,7 @@ combine_as_imports = True
 force_grid_wrap = 0
 include_trailing_comma = True
 known_first_party = httpx,tests
-known_third_party = brotli,certifi,chardet,cryptography,hstspreload,httpcore,pytest,rfc3986,setuptools,sniffio,trio,trustme,urllib3,uvicorn
+known_third_party = brotli,certifi,chardet,cryptography,hstspreload,httpcore,pytest,rfc3986,setuptools,sniffio,trio,trustme,uvicorn
 line_length = 88
 multi_line_output = 3