From: Tom Christie Date: Tue, 16 Sep 2025 17:59:11 +0000 (+0100) Subject: Initial commit X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=77589cc6f7b0b395f29df49c767cda6d144812bc;p=thirdparty%2Fhttpx.git Initial commit --- 77589cc6f7b0b395f29df49c767cda6d144812bc diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 00000000..3ba13e0c --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1 @@ +blank_issues_enabled: false diff --git a/.github/ISSUE_TEMPLATE/read-only-issues.md b/.github/ISSUE_TEMPLATE/read-only-issues.md new file mode 100644 index 00000000..2ea56183 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/read-only-issues.md @@ -0,0 +1,10 @@ +--- +name: Read-only issues +about: Restricted Zone ⛔️ +title: '' +labels: '' +assignees: '' + +--- + +Issues on this repository are considered read-only, and currently reserved for the maintenance team. diff --git a/.github/workflows/test-suite.yml b/.github/workflows/test-suite.yml new file mode 100644 index 00000000..7fed9368 --- /dev/null +++ b/.github/workflows/test-suite.yml @@ -0,0 +1,28 @@ +--- +name: Test Suite + +on: + push: + branches: ["dev"] + pull_request: + branches: ["dev", "version-*"] + +jobs: + tests: + name: "Python ${{ matrix.python-version }}" + runs-on: "ubuntu-latest" + + strategy: + matrix: + python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] + + steps: + - uses: "actions/checkout@v4" + - uses: "actions/setup-python@v5" + with: + python-version: "${{ matrix.python-version }}" + allow-prereleases: true + - name: "Install dependencies" + run: "scripts/install" + - name: "Run tests" + run: "scripts/test" diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..f9d43a11 --- /dev/null +++ b/.gitignore @@ -0,0 +1,8 @@ +*.pyc +.coverage +.mypy_cache/ +.pytest_cache/ +__pycache__/ +dist/ +venv/ +build/ diff --git a/README.md b/README.md new file mode 100644 index 00000000..f6f99268 --- /dev/null 
+++ b/README.md @@ -0,0 +1,72 @@ +

+ HTTPX +

+ +

HTTPX 1.0 — Design proposal.

+ +--- + +A complete HTTP framework for Python. + +*Installation...* + +```shell +$ pip install --pre httpx +``` + +*Making requests as a client...* + +```python +>>> r = httpx.get('https://www.example.org/') +>>> r + +>>> r.status_code +200 +>>> r.headers['content-type'] +'text/html; charset=UTF-8' +>>> r.text +'\n\n\nExample Domain...' +``` + +*Serving responses as the server...* + +```python +>>> def app(request): +... content = httpx.HTML('hello, world.') +... return httpx.Response(200, content=content) + +>>> httpx.run(app) +Serving on http://127.0.0.1:8080/ (Press CTRL+C to quit) +``` + +--- + +# Documentation + +The [HTTPX 1.0 design proposal](https://www.encode.io/httpnext/) is now available. + +* [Quickstart](https://www.encode.io/httpnext/quickstart) +* [Clients](https://www.encode.io/httpnext/clients) +* [Servers](https://www.encode.io/httpnext/servers) +* [Requests](https://www.encode.io/httpnext/requests) +* [Responses](https://www.encode.io/httpnext/responses) +* [URLs](https://www.encode.io/httpnext/urls) +* [Headers](https://www.encode.io/httpnext/headers) +* [Content Types](https://www.encode.io/httpnext/content-types) +* [Connections](https://www.encode.io/httpnext/connections) +* [Parsers](https://www.encode.io/httpnext/parsers) +* [Network Backends](https://www.encode.io/httpnext/networking) + +--- + +# Collaboration + +The repository for this project is currently private. + +We’re looking at creating paid opportunities for working on open source software *which are properly compensated, flexible & well balanced.* + +If you're interested in a position working on this project, please send an intro. + +--- + +

This provisional design work is not currently licensed for reuse.
Designed & crafted with care.

— 🦋 —

diff --git a/docs/about.md b/docs/about.md new file mode 100644 index 00000000..46d75648 --- /dev/null +++ b/docs/about.md @@ -0,0 +1,19 @@ +# About + +This work is a design proposal for an `httpx` 1.0 release. + +--- + +## Sponsorship + +We are currently seeking forward-looking investment that recognises the value of the infrastructure development on its own merit. Sponsorships may be [made through GitHub](https://github.com/encode). + +We do not offer equity, placements, or endorsements. + +## License + +The rights of the author have been asserted. + +--- + +

home

diff --git a/docs/clients.md b/docs/clients.md new file mode 100644 index 00000000..7de41615 --- /dev/null +++ b/docs/clients.md @@ -0,0 +1,311 @@ +# Clients + +HTTP requests are sent by using a `Client` instance. Client instances are thread safe interfaces that maintain a pool of HTTP connections. + + + +```{ .python .httpx } +>>> cli = httpx.Client() +>>> cli + +``` + +```{ .python .ahttpx .hidden } +>>> cli = ahttpx.Client() +>>> cli + +``` + +The client representation provides an indication of how many connections are currently in the pool. + + + +```{ .python .httpx } +>>> r = cli.get("https://www.example.com") +>>> r = cli.get("https://www.wikipedia.com") +>>> r = cli.get("https://www.theguardian.com/uk") +>>> cli + +``` + +```{ .python .ahttpx .hidden } +>>> r = await cli.get("https://www.example.com") +>>> r = await cli.get("https://www.wikipedia.com") +>>> r = await cli.get("https://www.theguardian.com/uk") +>>> cli + +``` + +The connections in the pool can be explicitly closed, using the `close()` method... + + + +```{ .python .httpx } +>>> cli.close() +>>> cli + +``` + +```{ .python .ahttpx .hidden } +>>> await cli.close() +>>> cli + +``` + +Client instances support being used in a context managed scope. You can use this style to enforce properly scoped resources, ensuring that the connection pool is cleanly closed when no longer required. + + + +```{ .python .httpx } +>>> with httpx.Client() as cli: +... r = cli.get("https://www.example.com") +``` + +```{ .python .ahttpx .hidden } +>>> async with ahttpx.Client() as cli: +... r = await cli.get("https://www.example.com") +``` + +It is important to scope the use of client instances as widely as possible. + +Typically you should have a single client instance that is used throughout the lifespan of your application. This ensures that connection pooling is maximised, and minimises unnecessary reloading of SSL certificate stores. 
+ +The recommended usage is to *either* have a single global instance created at import time, *or* a single context scoped instance that is passed around wherever it is required. + +## Setting a base URL + +Client instances can be configured with a base URL that is used when constructing requests... + + + +```{ .python .httpx } +>>> with httpx.Client(url="https://www.httpbin.org") as cli: +>>> r = cli.get("/json") +>>> print(r) + +``` + +```{ .python .ahttpx .hidden } +>>> async with ahttpx.Client(url="https://www.httpbin.org") as cli: +>>> r = await cli.get("/json") +>>> print(r) + +``` + +## Setting client headers + +Client instances include a set of headers that are used on every outgoing request. + +The default headers are: + +* `Accept: */*` - Indicates to servers that any media type may be returned. +* `Accept-Encoding: gzip` - Indicates to servers that gzip compression may be used on responses. +* `Connection: keep-alive` - Indicates that HTTP/1.1 connections should be reused over multiple requests. +* `User-Agent: python-httpx/1.0` - Identify the client as `httpx`. + +You can override this behavior by explicitly specifying the default headers... + + + +```{ .python .httpx } +>>> headers = {"User-Agent": "dev", "Accept-Encoding": "gzip"} +>>> with httpx.Client(headers=headers) as cli: +>>> r = cli.get("https://www.example.com/") +``` + +```{ .python .ahttpx .hidden } +>>> headers = {"User-Agent": "dev", "Accept-Encoding": "gzip"} +>>> async with ahttpx.Client(headers=headers) as cli: +>>> r = await cli.get("https://www.example.com/") +``` + +## Configuring the connection pool + +The connection pool used by the client can be configured in order to customise the SSL context, the maximum number of concurrent connections, or the network backend. + + + +```{ .python .httpx } +>>> # Setup an SSL context to allow connecting to improperly configured SSL. 
+>>> no_verify = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) +>>> no_verify.check_hostname = False +>>> no_verify.verify_mode = ssl.CERT_NONE +>>> # Instantiate a client with our custom SSL context. +>>> pool = httpx.ConnectionPool(ssl_context=no_verify) +>>> with httpx.Client(transport=pool) as cli: +>>> ... +``` + +```{ .python .ahttpx .hidden } +>>> # Setup an SSL context to allow connecting to improperly configured SSL. +>>> no_verify = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) +>>> no_verify.check_hostname = False +>>> no_verify.verify_mode = ssl.CERT_NONE +>>> # Instantiate a client with our custom SSL context. +>>> pool = ahttpx.ConnectionPool(ssl_context=no_verify) +>>> async with ahttpx.Client(transport=pool) as cli: +>>> ... +``` + +## Sending requests + +* `.request()` - Send an HTTP request, reading the response to completion. +* `.stream()` - Send an HTTP request, streaming the response. + +Shortcut methods... + +* `.get()` - Send an HTTP `GET` request. +* `.post()` - Send an HTTP `POST` request. +* `.put()` - Send an HTTP `PUT` request. +* `.delete()` - Send an HTTP `DELETE` request. + +--- + +## Transports + +By default requests are sent using the `ConnectionPool` class. Alternative implementations for sending requests can be created by subclassing the `Transport` interface. + +For example, a mock transport class that doesn't make any network requests and instead always returns a fixed response. 
+ + + +```{ .python .httpx } +class MockTransport(httpx.Transport): + def __init__(self, response): + self._response = response + + @contextlib.contextmanager + def send(self, request): + yield response + + def close(self): + pass + +response = httpx.Response(200, content=httpx.Text('Hello, world')) +transport = MockTransport(response=response) +with httpx.Client(transport=transport) as cli: + r = cli.get('https://www.example.com') + print(r) +``` + +```{ .python .ahttpx .hidden } +class MockTransport(ahttpx.Transport): + def __init__(self, response): + self._response = response + + @contextlib.contextmanager + def send(self, request): + yield response + + def close(self): + pass + +response = ahttpx.Response(200, content=httpx.Text('Hello, world')) +transport = MockTransport(response=response) +async with ahttpx.Client(transport=transport) as cli: + r = await cli.get('https://www.example.com') + print(r) +``` + +--- + +## Middleware + +In addition to maintaining an HTTP connection pool, client instances are responsible for two other pieces of functionality... + +* Dealing with HTTP redirects. +* Maintaining an HTTP cookie store. + +### `RedirectMiddleware` + +Wraps a transport class, adding support for HTTP redirect handling. + +### `CookieMiddleware` + +Wraps a transport class, adding support for HTTP cookie persistence. + +--- + +## Custom client implementations + +The `Client` implementation in `httpx` is intentionally lightweight. + +If you're working with a large codebase you might want to create a custom client implementation in order to constrain the types of request that are sent. + +The following example demonstrates a custom API client that only exposes `GET` and `POST` requests, and always uses JSON payloads. 
+ + + +```{ .python .httpx } +class APIClient: + def __init__(self): + self.url = httpx.URL('https://www.example.com') + self.headers = httpx.Headers({ + 'Accept-Encoding': 'gzip', + 'Connection': 'keep-alive', + 'User-Agent': 'dev' + }) + self.via = httpx.RedirectMiddleware(httpx.ConnectionPool()) + + def get(self, path: str) -> Response: + request = httpx.Request( + method="GET", + url=self.url.join(path), + headers=self.headers, + ) + with self.via.send(request) as response: + response.read() + return response + + def post(self, path: str, payload: Any) -> httpx.Response: + request = httpx.Request( + method="POST", + url=self.url.join(path), + headers=self.headers, + content=httpx.JSON(payload), + ) + with self.via.send(request) as response: + response.read() + return response +``` + +```{ .python .ahttpx .hidden } +class APIClient: + def __init__(self): + self.url = ahttpx.URL('https://www.example.com') + self.headers = ahttpx.Headers({ + 'Accept-Encoding': 'gzip', + 'Connection': 'keep-alive', + 'User-Agent': 'dev' + }) + self.via = ahttpx.RedirectMiddleware(ahttpx.ConnectionPool()) + + async def get(self, path: str) -> Response: + request = ahttpx.Request( + method="GET", + url=self.url.join(path), + headers=self.headers, + ) + async with self.via.send(request) as response: + await response.read() + return response + + async def post(self, path: str, payload: Any) -> ahttpx.Response: + request = ahttpx.Request( + method="POST", + url=self.url.join(path), + headers=self.headers, + content=httpx.JSON(payload), + ) + async with self.via.send(request) as response: + await response.read() + return response +``` + +You can expand on this pattern to provide behavior such as request or response schema validation, consistent timeouts, or standardised logging and exception handling. 
+ +--- + +← [Quickstart](quickstart.md) +[Servers](servers.md) → +  diff --git a/docs/connections.md b/docs/connections.md new file mode 100644 index 00000000..602641a1 --- /dev/null +++ b/docs/connections.md @@ -0,0 +1,245 @@ +# Connections + +The mechanics of sending HTTP requests is dealt with by the `ConnectionPool` and `Connection` classes. + +We can introspect a `Client` instance to get some visibility onto the state of the connection pool. + + + +```{ .python .httpx } +>>> with httpx.Client() as cli +>>> urls = [ +... "https://www.wikipedia.org/", +... "https://www.theguardian.com/", +... "https://news.ycombinator.com/", +... ] +... for url in urls: +... cli.get(url) +... print(cli.transport) +... # +... print(cli.transport.connections) +... # [ +... # , +... # , +... # , +... # ] +``` + +```{ .python .ahttpx .hidden } +>>> async with ahttpx.Client() as cli +>>> urls = [ +... "https://www.wikipedia.org/", +... "https://www.theguardian.com/", +... "https://news.ycombinator.com/", +... ] +... for url in urls: +... await cli.get(url) +... print(cli.transport) +... # +... print(cli.transport.connections) +... # [ +... # , +... # , +... # , +... # ] +``` + +--- + +## Understanding the stack + +The `Client` class is responsible for handling redirects and cookies. + +It also ensures that outgoing requests include a default set of headers such as `User-Agent` and `Accept-Encoding`. + + + +```{ .python .httpx } +>>> with httpx.Client() as cli: +>>> r = cli.request("GET", "https://www.example.com/") +``` + +```{ .python .ahttpx .hidden } +>>> async with ahttpx.Client() as cli: +>>> r = await cli.request("GET", "https://www.example.com/") +``` + +The `Client` class sends requests using a `ConnectionPool`, which is responsible for managing a pool of HTTP connections. This ensures quicker and more efficient use of resources than opening and closing a TCP connection with each request. The connection pool also handles HTTP proxying if required. 
+ +A single connection pool is able to handle multiple concurrent requests, with locking in place to ensure that the pool does not become over-saturated. + + + +```{ .python .httpx } +>>> with httpx.ConnectionPool() as pool: +>>> r = pool.request("GET", "https://www.example.com/") +``` + +```{ .python .ahttpx .hidden } +>>> async with ahttpx.ConnectionPool() as pool: +>>> r = await pool.request("GET", "https://www.example.com/") +``` + +Individual HTTP connections can be managed directly with the `Connection` class. A single connection can only handle requests sequentially. Locking is provided to ensure that requests are strictly queued sequentially. + + + +```{ .python .httpx } +>>> with httpx.open_connection("https://www.example.com/") as conn: +>>> r = conn.request("GET", "/") +``` + +```{ .python .ahttpx .hidden } +>>> async with ahttpx.open_connection("https://www.example.com/") as conn: +>>> r = await conn.request("GET", "/") +``` + +The `NetworkBackend` is responsible for managing the TCP stream, providing a raw byte-wise interface onto the underlying socket. 
+ +--- + +## ConnectionPool + + + +```{ .python .httpx } +>>> pool = httpx.ConnectionPool() +>>> pool + +``` + +```{ .python .ahttpx .hidden } +>>> pool = ahttpx.ConnectionPool() +>>> pool + +``` + +### `.request(method, url, headers=None, content=None)` + + + +```{ .python .httpx } +>>> with httpx.ConnectionPool() as pool: +>>> res = pool.request("GET", "https://www.example.com") +>>> res, pool +, +``` + +```{ .python .ahttpx .hidden } +>>> async with ahttpx.ConnectionPool() as pool: +>>> res = await pool.request("GET", "https://www.example.com") +>>> res, pool +, +``` + +### `.stream(method, url, headers=None, content=None)` + + + +```{ .python .httpx } +>>> with httpx.ConnectionPool() as pool: +>>> with pool.stream("GET", "https://www.example.com") as res: +>>> res, pool +, +``` + +```{ .python .ahttpx .hidden } +>>> async with ahttpx.ConnectionPool() as pool: +>>> async with await pool.stream("GET", "https://www.example.com") as res: +>>> res, pool +, +``` + +### `.send(request)` + + + +```{ .python .httpx } +>>> with httpx.ConnectionPool() as pool: +>>> req = httpx.Request("GET", "https://www.example.com") +>>> with pool.send(req) as res: +>>> res.read() +>>> res, pool +, +``` + +```{ .python .ahttpx .hidden } +>>> async with ahttpx.ConnectionPool() as pool: +>>> req = ahttpx.Request("GET", "https://www.example.com") +>>> async with await pool.send(req) as res: +>>> await res.read() +>>> res, pool +, +``` + +### `.close()` + + + +```{ .python .httpx } +>>> with httpx.ConnectionPool() as pool: +>>> pool.close() + +``` + +```{ .python .ahttpx .hidden } +>>> async with ahttpx.ConnectionPool() as pool: +>>> await pool.close() + +``` + +--- + +## Connection + +*TODO* + +--- + +## Protocol upgrades + + + +```{ .python .httpx } +with httpx.open_connection("https://www.example.com/") as conn: + with conn.upgrade("GET", "/feed", {"Upgrade": "WebSocket"}) as stream: + ... 
+``` + +```{ .python .ahttpx .hidden } +async with await ahttpx.open_connection("https://www.example.com/") as conn: + async with await conn.upgrade("GET", "/feed", {"Upgrade": "WebSocket"}) as stream: + ... +``` + +`` + +## Proxy `CONNECT` requests + + + +```{ .python .httpx } +with httpx.open_connection("http://127.0.0.1:8080") as conn: + with conn.upgrade("CONNECT", "www.encode.io:443") as stream: + stream.start_tls(ctx, hostname="www.encode.io") + ... +``` + +```{ .python .ahttpx .hidden } +async with await ahttpx.open_connection("http://127.0.0.1:8080") as conn: + async with await conn.upgrade("CONNECT", "www.encode.io:443") as stream: + await stream.start_tls(ctx, hostname="www.encode.io") + ... +``` + +`` + +--- + +*Describe the `Transport` interface.* + +--- + +← [Streams](streams.md) +[Parsers](parsers.md) → +  diff --git a/docs/content-types.md b/docs/content-types.md new file mode 100644 index 00000000..091aa8b8 --- /dev/null +++ b/docs/content-types.md @@ -0,0 +1,174 @@ +# Content Types + +Some HTTP requests including `POST`, `PUT` and `PATCH` can include content in the body of the request. + +The most common content types for upload data are... + +* HTML form submissions use the `application/x-www-form-urlencoded` content type. +* HTML form submissions including file uploads use the `multipart/form-data` content type. +* JSON data uses the `application/json` content type. + +Content can be included directly in a request by using bytes or a byte iterator and setting the appropriate `Content-Type` header. 
+ + + +```{ .python .httpx } +>>> headers = {'Content-Type': 'application/json'} +>>> content = json.dumps({"number": 123.5, "bool": [True, False], "text": "hello"}) +>>> response = cli.put(url, headers=headers, content=content) +``` + +```{ .python .ahttpx .hidden } +>>> headers = {'Content-Type': 'application/json'} +>>> content = json.dumps({"number": 123.5, "bool": [True, False], "text": "hello"}) +>>> response = await cli.put(url, headers=headers, content=content) +``` + +There are also several classes provided for setting the request content. These implement either the `Content` or `StreamingContent` API, and handle constructing the content and setting the relevant headers. + +* `
` +* `` +* `` +* `` +* `` + +For example, sending a JSON request... + + + +```{ .python .httpx } +>>> data = httpx.JSON({"number": 123.5, "bool": [True, False], "text": "hello"}) +>>> cli.post(url, content=data) +``` + +```{ .python .ahttpx .hidden } +>>> data = httpx.JSON({"number": 123.5, "bool": [True, False], "text": "hello"}) +>>> await cli.post(url, content=data) +``` + +--- + +## Form + +The `Form` class provides an immutable multi-dict for accessing HTML form data. This class implements the `Content` interface, allowing for HTML form uploads. + + + +```{ .python .httpx } +>>> form = httpx.Form({'name': '...'}) +>>> form +... +>>> form['name'] +... +>>> res = cli.post(url, content=form) +... +``` + +```{ .python .ahttpx .hidden } +>>> form = httpx.Form({'name': '...'}) +>>> form +... +>>> form['name'] +... +>>> res = await cli.post(url, content=form) +... +``` + +## Files + +The `Files` class provides an immutable multi-dict for accessing HTML form file uploads. This class implements the `StreamingContent` interface, allowing for HTML form file uploads. + + + +```{ .python .httpx } +>>> files = httpx.Files({'upload': httpx.File('data.json')}) +>>> files +... +>>> files['upload'] +... +>>> res = cli.post(url, content=files) +... +``` + +```{ .python .ahttpx .hidden } +>>> files = httpx.Files({'upload': httpx.File('data.json')}) +>>> files +... +>>> files['upload'] +... +>>> res = await cli.post(url, content=files) +... +``` + +## MultiPart + +The `MultiPart` class provides a wrapper for HTML form and files uploads. This class implements the `StreamingContent` interface, allowing for allowing for HTML form uploads including both data and file uploads. + + + +```{ .python .httpx } +>>> multipart = httpx.MultiPart(form={'name': '...'}, files={'avatar': httpx.File('image.png')}) +>>> multipart.form['name'] +... +>>> multipart.files['avatar'] +... 
+>>> res = cli.post(url, content=multipart) +``` + +```{ .python .ahttpx .hidden } +>>> multipart = httpx.MultiPart(form={'name': '...'}, files={'avatar': httpx.File('image.png')}) +>>> multipart.form['name'] +... +>>> multipart.files['avatar'] +... +>>> res = await cli.post(url, content=multipart) +``` + +## File + +The `File` class provides a wrapper for file uploads, and is used for uploads instead of passing a file object directly. + + + +```{ .python .httpx } +>>> file = httpx.File('upload.json') +>>> cli.post(url, content=file) +``` + +```{ .python .ahttpx .hidden } +>>> file = httpx.File('upload.json') +>>> await cli.post(url, content=file) +``` + +## JSON + +The `JSON` class provides a wrapper for JSON uploads. This class implements the `Content` interface, allowing for HTTP JSON uploads. + + + +```{ .python .httpx } +>>> data = httpx.JSON({...}) +>>> cli.put(url, content=data) +``` + +```{ .python .ahttpx .hidden } +>>> data = httpx.JSON({...}) +>>> await cli.put(url, content=data) +``` + +--- + +## Content + +An interface for constructing HTTP content, along with relevant headers. + +The following methods must be implemented... + +* `.encode()` - Returns an `httpx.Stream` representing the encoded data. +* `.content_type()` - Returns a `str` indicating the content type. + +--- + +← [Headers](headers.md) +[Streams](streams.md) → +  diff --git a/docs/headers.md b/docs/headers.md new file mode 100644 index 00000000..3b84e270 --- /dev/null +++ b/docs/headers.md @@ -0,0 +1,54 @@ +# Headers + +The `Headers` class provides an immutable case-insensitive multidict interface for accessing HTTP headers. + + + +```{ .python .httpx } +>>> headers = httpx.Headers({"Accept": "*/*"}) +>>> headers + +>>> headers['accept'] +'*/*' +``` + +```{ .python .ahttpx .hidden } +>>> headers = ahttpx.Headers({"Accept": "*/*"}) +>>> headers + +>>> headers['accept'] +'*/*' +``` + +Header values should always be printable ASCII strings. 
Attempting to set invalid header name or value strings will raise a `ValueError`. + +### Accessing headers + +Headers are accessed using a standard dictionary style interface... + +* `.get(key, default=None)` - *Return the value for a given key, or a default value. If multiple values for the key are present, only the first will be returned.* +* `.keys()` - *Return the unique keys of the headers. Each key will be a `str`.* +* `.values()` - *Return the values of the headers. Each value will be a `str`. If multiple values for a key are present, only the first will be returned.* +* `.items()` - *Return the key value pairs of the headers. Each item will be a two-tuple `(str, str)`. If multiple values for a key are present, only the first will be returned.* + +The following methods are also available for accessing headers as a multidict... + +* `.get_all(key, comma_delimited=False)` - *Return all the values for a given key. Returned as a list of zero or more `str` instances. If `comma_delimited` is set to `True` then any comma separated header values are split into a list of strings.* +* `.multi_items()` - *Return the key value pairs of the headers. Each item will be a two-tuple `(str, str)`. Repeated keys may occur.* +* `.multi_dict()` - *Return the headers as a dictionary, with each value being a list of one or more `str` instances.* + +### Modifying headers + +The following methods can be used to create modified header instances... + +* `.copy_set(key, value)` - *Return a new `Headers` instance, setting a header. Eg. `headers = headers.copy_set("Connection", "close")`*. +* `.copy_setdefault(key, value)` - *Return a new `Headers` instance, setting a header if it does not yet exist. Eg. `headers = headers.copy_setdefault("Content-Type", "text/html")`*. +* `.copy_append(key, value, comma_delimited=False)` - *Return a new `Headers` instance, setting or appending a header. 
If `comma_delimited` is set to `True`, then the append will be handled using comma delimiting instead of creating a new header. Eg. `headers = headers.copy_append("Accept-Encoding", "gzip", comma_delimited=True)`*. +* `.copy_remove(key)` - *Return a new `Headers` instance, removing a header. Eg. `headers = headers.copy_remove("User-Agent")`*. +* `.copy_update(headers)` - *Return a new `Headers` instance, updating multiple headers. Eg. `headers = headers.copy_update({"Authorization": "top secret"})`*. + +--- + +← [URLs](urls.md) +[Content Types](content-types.md) → +  \ No newline at end of file diff --git a/docs/img/butterfly.png b/docs/img/butterfly.png new file mode 100644 index 00000000..5e5f6b68 Binary files /dev/null and b/docs/img/butterfly.png differ diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 00000000..ded29f7b --- /dev/null +++ b/docs/index.md @@ -0,0 +1,112 @@ +

+ HTTPX +

+ +

HTTPX 1.0 — Prerelease.

+ +--- + +A complete HTTP toolkit for Python. Supporting both client & server, and available in either sync or async flavors. + +--- + +*Installation...* + + + +```{ .shell .httpx } +$ pip install --pre httpx +``` + +```{ .shell .ahttpx .hidden } +$ pip install --pre ahttpx +``` + +*Making requests as a client...* + + + +```{ .python .httpx } +>>> import httpx + +>>> r = httpx.get('https://www.example.org/') +>>> r + +>>> r.status_code +200 +>>> r.headers['content-type'] +'text/html; charset=UTF-8' +>>> r.text +'\n\n\nExample Domain...' +``` + +```{ .python .ahttpx .hidden } +>>> import ahttpx + +>>> r = await ahttpx.get('https://www.example.org/') +>>> r + +>>> r.status_code +200 +>>> r.headers['content-type'] +'text/html; charset=UTF-8' +>>> r.text +'\n\n\nExample Domain...' +``` + +*Serving responses as the server...* + + + +```{ .python .httpx } +>>> import httpx + +>>> def app(request): +... content = httpx.HTML('hello, world.') +... return httpx.Response(200, content=content) + +>>> httpx.run(app) +Serving on http://127.0.0.1:8080/ (Press CTRL+C to quit) +``` + +```{ .python .ahttpx .hidden } +>>> import ahttpx + +>>> async def app(request): +... content = httpx.HTML('hello, world.') +... return httpx.Response(200, content=content) + +>>> await httpx.run(app) +Serving on http://127.0.0.1:8080/ (Press CTRL+C to quit) +``` + +--- + +# Documentation + +* [Quickstart](quickstart.md) +* [Clients](clients.md) +* [Servers](servers.md) +* [Requests](requests.md) +* [Responses](responses.md) +* [URLs](urls.md) +* [Headers](headers.md) +* [Content Types](content-types.md) +* [Streams](streams.md) +* [Connections](connections.md) +* [Parsers](parsers.md) +* [Network Backends](networking.md) + +--- + +# Collaboration + +The repository for this project is currently private. 
+ +We’re looking at creating paid opportunities for working on open source software *which are properly compensated, flexible & well balanced.* + +If you're interested in a position working on this project, please send an intro: *kim@encode.io* + +--- + +

This design work is not yet licensed for reuse.
— 🦋 —

diff --git a/docs/networking.md b/docs/networking.md new file mode 100644 index 00000000..6375fdf2 --- /dev/null +++ b/docs/networking.md @@ -0,0 +1,381 @@ +# Network Backends + +The lowest level network abstractions in `httpx` are the `NetworkBackend` and `NetworkStream` classes. These provide a consistent interface onto the operations for working with a network stream, typically over a TCP connection. Different runtimes (threaded, trio & asyncio) are supported via alternative implementations of the core interface. + +--- + +## `NetworkBackend()` + +The default backend is instantiated via the `NetworkBackend` class... + + + +```{ .python .httpx } +>>> net = httpx.NetworkBackend() +>>> net + +``` + +```{ .python .ahttpx .hidden } +>>> net = ahttpx.NetworkBackend() +>>> net + +``` + +### `.connect(host, port)` + +A TCP stream is created using the `connect` method... + + + +```{ .python .httpx } +>>> net = httpx.NetworkBackend() +>>> stream = net.connect("www.encode.io", 80) +>>> stream + +``` + +```{ .python .ahttpx .hidden } +>>> net = ahttpx.NetworkBackend() +>>> stream = await net.connect("www.encode.io", 80) +>>> stream + +``` + +Streams support being used in a context managed style. The cleanest approach to resource management is to use `.connect(...)` in the context of a `with` block. + + + +```{ .python .httpx } +>>> net = httpx.NetworkBackend() +>>> with net.connect("dev.encode.io", 80) as stream: +>>> ... +>>> stream + +``` + +```{ .python .ahttpx .hidden } +>>> net = ahttpx.NetworkBackend() +>>> async with await net.connect("dev.encode.io", 80) as stream: +>>> ... +>>> stream + +``` + +## `NetworkStream(sock)` + +The `NetworkStream` class provides a TCP stream abstraction, by providing a thin wrapper around a socket instance. + +Network streams do not provide any built-in thread or task locking. +Within `httpx` thread and task safety is handled at the `Connection` layer. 
+ +### `.read(max_bytes=None)` + +Read up to `max_bytes` bytes of data from the network stream. +If no limit is provided a default value of 64KB will be used. + +### `.write(data)` + +Write the given bytes of `data` to the network stream. + +### `.start_tls(ctx, hostname)` + +Upgrade a stream to a TLS (SSL) connection for sending secure `https://` requests. + +`` + +### `.get_extra_info(key)` + +Return information about the underlying resource. May include... + +* `"client_addr"` - Return the client IP and port. +* `"server_addr"` - Return the server IP and port. +* `"ssl_object"` - Return an `ssl.SSLObject` instance. +* `"socket"` - Access the raw socket instance. + +### `.close()` + +Close the network stream. For TLS streams this will attempt to send a closing handshake before terminating the connection. + + + +```{ .python .httpx } +>>> net = httpx.NetworkBackend() +>>> stream = net.connect("dev.encode.io", 80) +>>> try: +>>> ... +>>> finally: +>>> stream.close() +>>> stream + +``` + +```{ .python .ahttpx .hidden } +>>> net = ahttpx.NetworkBackend() +>>> stream = await net.connect("dev.encode.io", 80) +>>> try: +>>> ... +>>> finally: +>>> await stream.close() +>>> stream + +``` + +--- + +## Timeouts + +Network timeouts are handled using a context block API. + +This [design approach](https://vorpus.org/blog/timeouts-and-cancellation-for-humans) avoids timeouts needing to be passed around throughout the stack, and provides an obvious and natural API for dealing with timeout contexts. + +### timeout(duration) + +The timeout context manager can be used to wrap socket operations anywhere in the stack. + +Here's an example of enforcing an overall 3 second timeout on a request. 
+ + +```{ .python .httpx } +>>> with httpx.Client() as cli: +>>> with httpx.timeout(3.0): +>>> res = cli.get('https://www.example.com') +>>> print(res) +``` + +```{ .python .ahttpx .hidden } +>>> async with ahttpx.Client() as cli: +>>> async with ahttpx.timeout(3.0): +>>> res = await cli.get('https://www.example.com') +>>> print(res) +``` + +Timeout contexts provide an API allowing for deadlines to be cancelled. + +### .cancel() + +In this example we enforce a 3 second timeout on *receiving the start of* a streaming HTTP response... + + + +```{ .python .httpx } +>>> with httpx.Client() as cli: +>>> with httpx.timeout(3.0) as t: +>>> with cli.stream('https://www.example.com') as r: +>>> t.cancel() +>>> print(">>>", res) +>>> for chunk in r.stream: +>>> print("...", chunk) +``` + +```{ .python .ahttpx .hidden } +>>> async with ahttpx.Client() as cli: +>>> async with ahttpx.timeout(3.0) as t: +>>> async with await cli.stream('https://www.example.com') as r: +>>> t.cancel() +>>> print(">>>", res) +>>> async for chunk in r.stream: +>>> print("...", chunk) +``` + +--- + +## Sending HTTP requests + +Let's take a look at how we can work directly with a network backend to send an HTTP request, and receive an HTTP response. + + + +```{ .python .httpx } +import httpx +import ssl +import truststore + +net = httpx.NetworkBackend() +ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT) +req = b'\r\n'.join([ + b'GET / HTTP/1.1', + b'Host: www.example.com', + b'User-Agent: python/dev', + b'Connection: close', + b'', + b'', +]) + +# Use a 10 second overall timeout for the entire request/response. +with httpx.timeout(10.0): + # Use a 3 second timeout for the initial connection. + with httpx.timeout(3.0) as t: + # Open the connection & establish SSL. + with net.connect("www.example.com", 443) as stream: + stream.start_tls(ctx, hostname="www.example.com") + t.cancel() + # Send the request & read the response. 
+ stream.write(req) + buffer = [] + while part := stream.read(): + buffer.append(part) + resp = b''.join(buffer) +``` + +```{ .python .ahttpx .hidden } +import ahttpx +import ssl +import truststore + +net = ahttpx.NetworkBackend() +ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT) +req = b'\r\n'.join([ + b'GET / HTTP/1.1', + b'Host: www.example.com', + b'User-Agent: python/dev', + b'Connection: close', + b'', + b'', +]) + +# Use a 10 second overall timeout for the entire request/response. +async with ahttpx.timeout(10.0): + # Use a 3 second timeout for the initial connection. + async with ahttpx.timeout(3.0) as t: + # Open the connection & establish SSL. + async with await net.connect("www.example.com", 443) as stream: + await stream.start_tls(ctx, hostname="www.example.com") + t.cancel() + # Send the request & read the response. + await stream.write(req) + buffer = [] + while part := await stream.read(): + buffer.append(part) + resp = b''.join(buffer) +``` + +The example above is somewhat contrived, there's no HTTP parsing implemented so we can't actually determine when the response is complete. We're using a `Connection: close` header to request that the server close the connection once the response is complete. + +A more complete example would require proper HTTP parsing. The `Connection` class implements an HTTP request/response interface, layered over a `NetworkStream`. + +--- + +## Custom network backends + +The interface for implementing custom network backends is provided by two classes... + +### `NetworkBackendInterface` + +The abstract interface implemented by `NetworkBackend`. See above for details. + +### `NetworkStreamInterface` + +The abstract interface implemented by `NetworkStream`. See above for details. + +### An example backend + +We can use these interfaces to implement custom functionality. For example, here we're providing a network backend that logs all the ingoing and outgoing bytes. 
+ + + +```{ .python .httpx } +class RecordingBackend(httpx.NetworkBackendInterface): + def __init__(self): + self._backend = NetworkBackend() + + def connect(self, host, port): + # Delegate creating connections to the default + # network backend, and return a wrapped stream. + stream = self._backend.connect(host, port) + return RecordingStream(stream) + + +class RecordingStream(httpx.NetworkStreamInterface): + def __init__(self, stream): + self._stream = stream + + def read(self, max_bytes: int = None): + # Print all incoming data to the terminal. + data = self._stream.read(max_bytes) + lines = data.decode('ascii', errors='replace').splitlines() + for line in lines: + print("<<< ", line) + return data + + def write(self, data): + # Print all outgoing data to the terminal. + lines = data.decode('ascii', errors='replace').splitlines() + for line in lines: + print(">>> ", line) + self._stream.write(data) + + def start_tls(ctx, hostname): + self._stream.start_tls(ctx, hostname) + + def get_extra_info(key): + return self._stream.get_extra_info(key) + + def close(): + self._stream.close() +``` + +```{ .python .ahttpx .hidden } +class RecordingBackend(ahhtpx.NetworkBackendInterface): + def __init__(self): + self._backend = NetworkBackend() + + async def connect(self, host, port): + # Delegate creating connections to the default + # network backend, and return a wrapped stream. + stream = await self._backend.connect(host, port) + return RecordingStream(stream) + + +class RecordingStream(ahttpx.NetworkStreamInterface): + def __init__(self, stream): + self._stream = stream + + async def read(self, max_bytes: int = None): + # Print all incoming data to the terminal. + data = await self._stream.read(max_bytes) + lines = data.decode('ascii', errors='replace').splitlines() + for line in lines: + print("<<< ", line) + return data + + async def write(self, data): + # Print all outgoing data to the terminal. 
+ lines = data.decode('ascii', errors='replace').splitlines() + for line in lines: + print(">>> ", line) + await self._stream.write(data) + + async def start_tls(ctx, hostname): + await self._stream.start_tls(ctx, hostname) + + def get_extra_info(key): + return self._stream.get_extra_info(key) + + async def close(): + await self._stream.close() +``` + +We can now instantiate a client using this network backend. + + + +```{ .python .httpx } +>>> transport = httpx.ConnectionPool(backend=RecordingBackend()) +>>> cli = httpx.Client(transport=transport) +>>> cli.get('https://www.example.com') +``` + +```{ .python .ahttpx .hidden } +>>> transport = ahttpx.ConnectionPool(backend=RecordingBackend()) +>>> cli = ahttpx.Client(transport=transport) +>>> await cli.get('https://www.example.com') +``` + +Custom network backends can also be used to provide functionality such as handling DNS caching for name lookups, or connecting via a UNIX domain socket instead of a TCP connection. + +--- + +← [Parsers](parsers.md) +  diff --git a/docs/parsers.md b/docs/parsers.md new file mode 100644 index 00000000..3416c923 --- /dev/null +++ b/docs/parsers.md @@ -0,0 +1,110 @@ +# Parsers + +### Client + + + +```{ .python .httpx } +stream = httpx.DuplexStream( + b'HTTP/1.1 200 OK\r\n' + b'Content-Length: 23\r\n' + b'Content-Type: application/json\r\n' + b'\r\n' + b'{"msg": "hello, world"}' +) +p = ahttpx.HTTPParser(stream, mode='CLIENT') + +# Send the request... +p.send_method_line(b'GET', b'/', b'HTTP/1.1') +p.send_headers([(b'Host', b'www.example.com')]) +p.send_body(b'') + +# Receive the response... 
+protocol, code, reason_phase = p.recv_status_line() +headers = p.recv_headers() +body = b'' +while buffer := p.recv_body(): + body += buffer +``` + +```{ .python .ahttpx .hidden } +stream = ahttpx.DuplexStream( + b'HTTP/1.1 200 OK\r\n' + b'Content-Length: 23\r\n' + b'Content-Type: application/json\r\n' + b'\r\n' + b'{"msg": "hello, world"}' +) +p = ahttpx.HTTPParser(stream, mode='CLIENT') + +# Send the request... +await p.send_method_line(b'GET', b'/', b'HTTP/1.1') +await p.send_headers([(b'Host', b'www.example.com')]) +await p.send_body(b'') + +# Receive the response... +protocol, code, reason_phase = await p.recv_status_line() +headers = await p.recv_headers() +body = b'' +while buffer := await p.recv_body(): + body += buffer +``` + +### Server + + + +```{ .python .httpx } +stream = httpx.DuplexStream( + b'GET / HTTP/1.1\r\n' + b'Host: www.example.com\r\n' + b'\r\n' +) +p = httpx.HTTPParser(stream, mode='SERVER') + +# Receive the request... +method, target, protocol = p.recv_method_line() +headers = p.recv_headers() +body = b'' +while buffer := p.recv_body(): + body += buffer + +# Send the response... +p.send_status_line(b'HTTP/1.1', 200, b'OK') +p.send_headers([ + (b'Content-Length', b'23'), + (b'Content-Type', b'application/json') +]) +p.send_body(b'{"msg": "hello, world"}') +p.send_body(b'') +``` + +```{ .python .ahttpx .hidden } +stream = ahttpx.DuplexStream( + b'GET / HTTP/1.1\r\n' + b'Host: www.example.com\r\n' + b'\r\n' +) +p = ahttpx.HTTPParser(stream, mode='SERVER') + +# Receive the request... +method, target, protocol = await p.recv_method_line() +headers = await p.recv_headers() +body = b'' +while buffer := await p.recv_body(): + body += buffer + +# Send the response... 
+await p.send_status_line(b'HTTP/1.1', 200, b'OK') +await p.send_headers([ + (b'Content-Length', b'23'), + (b'Content-Type', b'application/json') +]) +await p.send_body(b'{"msg": "hello, world"}') +await p.send_body(b'') +``` + +--- + +← [Connections](connections.md) +[Low Level Networking](networking.md) → diff --git a/docs/quickstart.md b/docs/quickstart.md new file mode 100644 index 00000000..c3a60682 --- /dev/null +++ b/docs/quickstart.md @@ -0,0 +1,484 @@ +# QuickStart + +Install using ... + + + +```{ .shell .httpx } +$ pip install --pre httpx +``` + +```{ .shell .ahttpx .hidden } +$ pip install --pre ahttpx +``` + +First, start by importing `httpx`... + + + +```{ .python .httpx } +>>> import httpx +``` + +```{ .python .ahttpx .hidden } +>>> import ahttpx +``` + +Now, let’s try to get a webpage. + + + +```{ .python .httpx } +>>> r = httpx.get('https://httpbin.org/get') +>>> r + +``` + +```{ .python .ahttpx .hidden } +>>> r = await ahttpx.get('https://httpbin.org/get') +>>> r + +``` + +To make an HTTP `POST` request, including some content... + + + +```{ .python .httpx } +>>> form = httpx.Form({'key': 'value'}) +>>> r = httpx.post('https://httpbin.org/post', content=form) +``` + +```{ .python .ahttpx .hidden } +>>> form = httpx.Form({'key': 'value'}) +>>> r = await ahttpx.post('https://httpbin.org/post', content=form) +``` + +Shortcut methods for `PUT`, `PATCH`, and `DELETE` requests follow the same style... 
+ + + +```{ .python .httpx } +>>> r = httpx.put('https://httpbin.org/put', content=form) +>>> r = httpx.patch('https://httpbin.org/patch', content=form) +>>> r = httpx.delete('https://httpbin.org/delete') +``` + +```{ .python .ahttpx .hidden } +>>> r = await ahttpx.put('https://httpbin.org/put', content=form) +>>> r = await ahttpx.patch('https://httpbin.org/patch', content=form) +>>> r = await ahttpx.delete('https://httpbin.org/delete') +``` + +## Passing Parameters in URLs + +To include URL query parameters in the request, construct a URL using the `params` keyword... + + + +```{ .python .httpx } +>>> params = {'key1': 'value1', 'key2': 'value2'} +>>> url = httpx.URL('https://httpbin.org/get', params=params) +>>> r = httpx.get(url) +``` + +```{ .python .ahttpx .hidden } +>>> params = {'key1': 'value1', 'key2': 'value2'} +>>> url = ahttpx.URL('https://httpbin.org/get', params=params) +>>> r = await ahttpx.get(url) +``` + +You can also pass a list of items as a value... + + + +```{ .python .httpx } +>>> params = {'key1': 'value1', 'key2': ['value2', 'value3']} +>>> url = httpx.URL('https://httpbin.org/get', params=params) +>>> r = httpx.get(url) +``` + +```{ .python .ahttpx .hidden } +>>> params = {'key1': 'value1', 'key2': ['value2', 'value3']} +>>> url = ahttpx.URL('https://httpbin.org/get', params=params) +>>> r = await ahttpx.get(url) +``` + +## Custom Headers + +To include additional headers in the outgoing request, use the `headers` keyword argument... + + + +```{ .python .httpx } +>>> url = 'https://httpbin.org/headers' +>>> headers = {'User-Agent': 'my-app/0.0.1'} +>>> r = httpx.get(url, headers=headers) +``` + +```{ .python .ahttpx .hidden } +>>> url = 'https://httpbin.org/headers' +>>> headers = {'User-Agent': 'my-app/0.0.1'} +>>> r = await ahttpx.get(url, headers=headers) +``` + +--- + +## Response Content + +HTTPX will automatically handle decoding the response content into unicode text. 
+ + + +```{ .python .httpx } +>>> r = httpx.get('https://www.example.org/') +>>> r.text +'\n\n\nExample Domain...' +``` + +```{ .python .ahttpx .hidden } +>>> r = await ahttpx.get('https://www.example.org/') +>>> r.text +'\n\n\nExample Domain...' +``` + +## Binary Response Content + +The response content can also be accessed as bytes, for non-text responses. + + + +```{ .python .httpx } +>>> r.body +b'\n\n\nExample Domain...' +``` + +```{ .python .ahttpx .hidden } +>>> r.body +b'\n\n\nExample Domain...' +``` + +## JSON Response Content + +Often Web API responses will be encoded as JSON. + + + +```{ .python .httpx } +>>> r = httpx.get('https://httpbin.org/get') +>>> r.json() +{'args': {}, 'headers': {'Host': 'httpbin.org', 'User-Agent': 'dev', 'X-Amzn-Trace-Id': 'Root=1-679814d5-0f3d46b26686f5013e117085'}, 'origin': '21.35.60.128', 'url': 'https://httpbin.org/get'} +``` + +```{ .python .ahttpx .hidden } +>>> r = await ahttpx.get('https://httpbin.org/get') +>>> await r.json() +{'args': {}, 'headers': {'Host': 'httpbin.org', 'User-Agent': 'dev', 'X-Amzn-Trace-Id': 'Root=1-679814d5-0f3d46b26686f5013e117085'}, 'origin': '21.35.60.128', 'url': 'https://httpbin.org/get'} +``` + +--- + +## Sending Form Encoded Data + +Some types of HTTP requests, such as `POST` and `PUT` requests, can include data in the request body. One common way of including that is as form-encoded data, which is used for HTML forms. + + + +```{ .python .httpx } +>>> form = httpx.Form({'key1': 'value1', 'key2': 'value2'}) +>>> r = httpx.post("https://httpbin.org/post", content=form) +>>> r.json() +{ + ... + "form": { + "key2": "value2", + "key1": "value1" + }, + ... +} +``` + +```{ .python .ahttpx .hidden } +>>> form = ahttpx.Form({'key1': 'value1', 'key2': 'value2'}) +>>> r = await ahttpx.post("https://httpbin.org/post", content=form) +>>> await r.json() +{ + ... + "form": { + "key2": "value2", + "key1": "value1" + }, + ... 
+} +``` + +Form encoded data can also include multiple values from a given key. + + + +```{ .python .httpx } +>>> form = httpx.Form({'key1': ['value1', 'value2']}) +>>> r = httpx.post("https://httpbin.org/post", content=form) +>>> r.json() +{ + ... + "form": { + "key1": [ + "value1", + "value2" + ] + }, + ... +} +``` + +```{ .python .ahttpx .hidden } +>>> form = ahttpx.Form({'key1': ['value1', 'value2']}) +>>> r = await ahttpx.post("https://httpbin.org/post", content=form) +>>> await r.json() +{ + ... + "form": { + "key1": [ + "value1", + "value2" + ] + }, + ... +} +``` + +## Sending Multipart File Uploads + +You can also upload files, using HTTP multipart encoding. + + + +```{ .python .httpx } +>>> files = httpx.Files({'upload': httpx.File('uploads/report.xls')}) +>>> r = httpx.post("https://httpbin.org/post", content=files) +>>> r.json() +{ + ... + "files": { + "upload": "<... binary content ...>" + }, + ... +} +``` + +```{ .python .ahttpx .hidden } +>>> files = ahttpx.Files({'upload': httpx.File('uploads/report.xls')}) +>>> r = await ahttpx.post("https://httpbin.org/post", content=files) +>>> await r.json() +{ + ... + "files": { + "upload": "<... binary content ...>" + }, + ... +} +``` + +If you need to include non-file data fields in the multipart form, use the `data=...` parameter: + + + +```{ .python .httpx } +>>> form = {'message': 'Hello, world!'} +>>> files = {'upload': httpx.File('uploads/report.xls')} +>>> data = httpx.MultiPart(form=form, files=files) +>>> r = httpx.post("https://httpbin.org/post", content=data) +>>> r.json() +{ + ... + "files": { + "upload": "<... binary content ...>" + }, + "form": { + "message": "Hello, world!", + }, + ... +} +``` + +```{ .python .ahttpx .hidden } +>>> form = {'message': 'Hello, world!'} +>>> files = {'upload': httpx.File('uploads/report.xls')} +>>> data = ahttpx.MultiPart(form=form, files=files) +>>> r = await ahttpx.post("https://httpbin.org/post", content=data) +>>> await r.json() +{ + ... 
+ "files": { + "upload": "<... binary content ...>" + }, + "form": { + "message": "Hello, world!", + }, + ... +} +``` + +## Sending JSON Encoded Data + +Form encoded data is okay if all you need is a simple key-value data structure. +For more complicated data structures you'll often want to use JSON encoding instead. + + + +```{ .python .httpx } +>>> data = {'integer': 123, 'boolean': True, 'list': ['a', 'b', 'c']} +>>> r = httpx.post("https://httpbin.org/post", content=httpx.JSON(data)) +>>> r.json() +{ + ... + "json": { + "boolean": true, + "integer": 123, + "list": [ + "a", + "b", + "c" + ] + }, + ... +} +``` + +```{ .python .ahttpx .hidden } +>>> data = {'integer': 123, 'boolean': True, 'list': ['a', 'b', 'c']} +>>> r = await ahttpx.post("https://httpbin.org/post", content=httpx.JSON(data)) +>>> await r.json() +{ + ... + "json": { + "boolean": true, + "integer": 123, + "list": [ + "a", + "b", + "c" + ] + }, + ... +} +``` + +## Sending Binary Request Data + +For other encodings, you should use the `content=...` parameter, passing +either a `bytes` type or a generator that yields `bytes`. + + + +```{ .python .httpx } +>>> content = b'Hello, world' +>>> r = httpx.post("https://httpbin.org/post", content=content) +``` + +```{ .python .ahttpx .hidden } +>>> content = b'Hello, world' +>>> r = await ahttpx.post("https://httpbin.org/post", content=content) +``` + +You may also want to set a custom `Content-Type` header when uploading +binary data. + +--- + +## Response Status Codes + +We can inspect the HTTP status code of the response: + + + +```{ .python .httpx } +>>> r = httpx.get('https://httpbin.org/get') +>>> r.status_code +200 +``` + +```{ .python .ahttpx .hidden } +>>> r = await ahttpx.get('https://httpbin.org/get') +>>> r.status_code +200 +``` + +## Response Headers + +The response headers are available as a dictionary-like interface. 
+ + + +```{ .python .httpx } +>>> r.headers + +``` + +```{ .python .ahttpx .hidden } +>>> r.headers + +``` + +The `Headers` data type is case-insensitive, so you can use any capitalization. + + + +```{ .python .httpx } +>>> r.headers.get('Content-Type') +'application/json' + +>>> r.headers.get('content-type') +'application/json' +``` + +```{ .python .ahttpx .hidden } +>>> r.headers.get('Content-Type') +'application/json' + +>>> r.headers.get('content-type') +'application/json' +``` + +--- + +## Streaming Responses + +For large downloads you may want to use streaming responses that do not load the entire response body into memory at once. + +You can stream the binary content of the response... + + + +```{ .python .httpx } +>>> with httpx.stream("GET", "https://www.example.com") as r: +... for data in r.stream: +... print(data) +``` + +```{ .python .ahttpx .hidden } +>>> async with ahttpx.stream("GET", "https://www.example.com") as r: +... async for data in r.stream: +... print(data) +``` + +--- + +← [Home](index.md) +[Clients](clients.md) → +  \ No newline at end of file diff --git a/docs/requests.md b/docs/requests.md new file mode 100644 index 00000000..7f271251 --- /dev/null +++ b/docs/requests.md @@ -0,0 +1,178 @@ +# Requests + +The core elements of an HTTP request are the `method`, `url`, `headers` and `body`. + + + +```{ .python .httpx } +>>> req = httpx.Request('GET', 'https://www.example.com/') +>>> req + +>>> req.method +'GET' +>>> req.url + +>>> req.headers + +>>> req.body +b'' +``` + +```{ .python .ahttpx .hidden } +>>> req = ahttpx.Request('GET', 'https://www.example.com/') +>>> req + +>>> req.method +'GET' +>>> req.url + +>>> req.headers + +>>> req.body +b'' +``` + +## Working with the request headers + +The following headers have automatic behavior with `Requests` instances... + +* `Host` - A `Host` header must always be included on a request. This header is automatically populated from the `url`, using the `url.netloc` property. 
+* `Content-Length` - Requests including a request body must always include either a `Content-Length` header or a `Transfer-Encoding: chunked` header. This header is automatically populated if `content` is not `None` and the content is a known size. +* `Transfer-Encoding` - Requests automatically include a `Transfer-Encoding: chunked` header if `content` is not `None` and the content is an unknown size. +* `Content-Type` - Requests automatically include a `Content-Type` header if `content` is set using the [Content Type] API. + +## Working with the request body + +Including binary data directly... + + + +```{ .python .httpx } +>>> headers = {'Content-Type': 'application/json'} +>>> content = json.dumps(...) +>>> httpx.Request('POST', 'https://echo.encode.io/', content=content) +``` + +```{ .python .ahttpx .hidden } +>>> headers = {'Content-Type': 'application/json'} +>>> content = json.dumps(...) +>>> ahttpx.Request('POST', 'https://echo.encode.io/', content=content) +``` + +## Working with content types + +Including JSON request content... + + + +```{ .python .httpx } +>>> data = httpx.JSON(...) +>>> httpx.Request('POST', 'https://echo.encode.io/', content=data) +``` + +```{ .python .ahttpx .hidden } +>>> data = ahttpx.JSON(...) +>>> ahttpx.Request('POST', 'https://echo.encode.io/', content=data) +``` + +Including form encoded request content... + + + +```{ .python .httpx } +>>> data = httpx.Form(...) +>>> httpx.Request('PUT', 'https://echo.encode.io/', content=data) +``` + +```{ .python .ahttpx .hidden } +>>> data = ahttpx.Form(...) +>>> ahttpx.Request('PUT', 'https://echo.encode.io/', content=data) +``` + +Including multipart file uploads... 
+ + + +```{ .python .httpx } +>>> form = httpx.MultiPart(form={...}, files={...}) +>>> with httpx.Request('POST', 'https://echo.encode.io/', content=form) as req: +>>> req.headers +{...} +>>> req.stream + +``` + +```{ .python .ahttpx .hidden } +>>> form = ahttpx.MultiPart(form={...}, files={...}) +>>> async with ahttpx.Request('POST', 'https://echo.encode.io/', content=form) as req: +>>> req.headers +{...} +>>> req.stream + +``` + +Including direct file uploads... + + + +```{ .python .httpx } +>>> file = httpx.File('upload.json') +>>> with httpx.Request('POST', 'https://echo.encode.io/', content=file) as req: +>>> req.headers +{...} +>>> req.stream + +``` + +```{ .python .ahttpx .hidden } +>>> file = ahttpx.File('upload.json') +>>> async with ahttpx.Request('POST', 'https://echo.encode.io/', content=file) as req: +>>> req.headers +{...} +>>> req.stream + +``` + +## Accessing request content + +*In progress...* + + + +```{ .python .httpx } +>>> data = request.json() +``` + +```{ .python .ahttpx .hidden } +>>> data = await request.json() +``` + +... + + + +```{ .python .httpx } +>>> form = request.form() +``` + +```{ .python .ahttpx .hidden } +>>> form = await request.form() +``` + +... + + + +```{ .python .httpx } +>>> files = request.files() +``` + +```{ .python .ahttpx .hidden } +>>> files = await request.files() +``` + +--- + +← [Servers](servers.md) +[Responses](responses.md) → +  diff --git a/docs/responses.md b/docs/responses.md new file mode 100644 index 00000000..58ef2e49 --- /dev/null +++ b/docs/responses.md @@ -0,0 +1,131 @@ +# Responses + +The core elements of an HTTP response are the `status_code`, `headers` and `body`. 
+ + +```{ .python .httpx } +>>> resp = httpx.Response(200, headers={'Content-Type': 'text/plain'}, content=b'hello, world') +>>> resp + +>>> resp.status_code +200 +>>> resp.headers + +>>> resp.body +b'hello, world' +``` + +```{ .python .ahttpx .hidden } +>>> resp = ahttpx.Response(200, headers={'Content-Type': 'text/plain'}, content=b'hello, world') +>>> resp + +>>> resp.status_code +200 +>>> resp.headers + +>>> resp.body +b'hello, world' +``` + +## Working with the response headers + +The following headers have automatic behavior with `Response` instances... + +* `Content-Length` - Responses including a response body must always include either a `Content-Length` header or a `Transfer-Encoding: chunked` header. This header is automatically populated if `content` is not `None` and the content is a known size. +* `Transfer-Encoding` - Responses automatically include a `Transfer-Encoding: chunked` header if `content` is not `None` and the content is an unknown size. +* `Content-Type` - Responses automatically include a `Content-Type` header if `content` is set using the [Content Type] API. + +## Working with content types + +Including HTML content... + + + +```{ .python .httpx } +>>> content = httpx.HTML('......') +>>> response = httpx.Response(200, content=content) +``` + +```{ .python .ahttpx .hidden } +>>> content = ahttpx.HTML('......') +>>> response = ahttpx.Response(200, content=content) +``` + +Including plain text content... + + + +```{ .python .httpx } +>>> content = httpx.Text('hello, world') +>>> response = httpx.Response(200, content=content) +``` + +```{ .python .ahttpx .hidden } +>>> content = ahttpx.Text('hello, world') +>>> response = ahttpx.Response(200, content=content) +``` + +Including JSON data... 
+ + + +```{ .python .httpx } +>>> content = httpx.JSON({'message': 'hello, world'}) +>>> response = httpx.Response(200, content=content) +``` + +```{ .python .ahttpx .hidden } +>>> content = ahttpx.JSON({'message': 'hello, world'}) +>>> response = ahttpx.Response(200, content=content) +``` + +Including content from a file... + + + +```{ .python .httpx } +>>> content = httpx.File('index.html') +>>> with httpx.Response(200, content=content) as response: +... pass +``` + +```{ .python .ahttpx .hidden } +>>> content = ahttpx.File('index.html') +>>> async with ahttpx.Response(200, content=content) as response: +... pass +``` + +## Accessing response content + +... + + + +```{ .python .httpx } +>>> response.body +``` + +```{ .python .ahttpx .hidden } +>>> response.body +``` + +... + + + +```{ .python .httpx } +>>> response.text +... +``` + +```{ .python .ahttpx .hidden } +>>> response.text +... +``` + +--- + +← [Requests](requests.md) +[URLs](urls.md) → +  diff --git a/docs/servers.md b/docs/servers.md new file mode 100644 index 00000000..57e79c33 --- /dev/null +++ b/docs/servers.md @@ -0,0 +1,85 @@ +# Servers + +The HTTP server provides a simple request/response API. +This gives you a lightweight way to build web applications or APIs. + +### `serve_http(endpoint)` + + + +```{ .python .httpx } +>>> website = """ +... +... +... +... +... +...
hello, world
+... +... +... """ + +>>> def hello_world(request): +... content = httpx.HTML(website) +... return httpx.Response(200, content=content) + +>>> with httpx.serve_http(hello_world) as server: +... print(f"Serving on {server.url} (Press CTRL+C to quit)") +... server.wait() +Serving on http://127.0.0.1:8080/ (Press CTRL+C to quit) +``` + +```{ .python .ahttpx .hidden } +>>> import httpx + +>>> website = """ +... +... +... +... +... +...
hello, world
+... +... +... """ + +>>> async def hello_world(request): +... if request.path != '/': +... content = httpx.Text("Not found") +... return httpx.Response(404, content=content) +... content = httpx.HTML(website) +... return httpx.Response(200, content=content) + +>>> async with httpx.serve_http(hello_world) as server: +... print(f"Serving on {server.url} (Press CTRL+C to quit)") +... await server.wait() +Serving on http://127.0.0.1:8080/ (Press CTRL+C to quit) +``` + +--- + +*Docs in progress...* + +--- + +← [Clients](clients.md) +[Requests](requests.md) → +  diff --git a/docs/streams.md b/docs/streams.md new file mode 100644 index 00000000..53c32d68 --- /dev/null +++ b/docs/streams.md @@ -0,0 +1,88 @@ +# Streams + +Streams provide a minimal file-like interface for reading bytes from a data source. They are used as the abstraction for reading the body of a request or response. + +The interfaces here are simplified versions of Python's standard I/O operations. + +## Stream + +The base `Stream` class. The core of the interface is a subset of Python's `io.IOBase`... + +* `.read(size=-1)` - *(bytes)* Return the bytes from the data stream. If the `size` argument is omitted or negative then the entire stream will be read. If `size` is a positive integer then the call returns at most `size` bytes. A return value of `b''` indicates the end of the stream has been reached. +* `.write(self, data: bytes)` - *None* Write the given bytes to the data stream. May raise `NotImplementedError` if this is not a writeable stream. +* `.close()` - Close the stream. Any further operations will raise a `ValueError`. + +Additionally, the following property is also defined... + +* `.size` - *(int or None)* Return an integer indicating the size of the stream, or `None` if the size is unknown. When working with HTTP this is used to either set a `Content-Length: ` header, or a `Transfer-Encoding: chunked` header. 
+ +The `Stream` interface and `ContentType` interface are related, with streams being used as the abstraction for the bytewise representation, and content types being used to encapsulate the parsed data structure. + +For example, encoding some `JSON` data... + +```python +>>> data = httpx.JSON({'name': 'zelda', 'score': '478'}) +>>> stream = data.encode() +>>> stream.read() +b'{"name":"zelda","score":"478"}' +>>> stream.content_type +'application/json' +``` + +--- + +## ByteStream + +A byte stream returning fixed byte content. Similar to Python's `io.BytesIO` class. + +```python +>>> s = httpx.ByteStream(b'{"msg": "Hello, world!"}') +>>> s.read() +b'{"msg": "Hello, world!"}' +``` + +## FileStream + +A byte stream returning content from a file. + +The standard pattern for instantiating a `FileStream` is to use `File` as a context manager: + +```python +>>> with httpx.File('upload.json') as s: +... s.read() +b'{"msg": "Hello, world!"}' +``` + +## MultiPartStream + +A byte stream returning multipart upload data. + +The standard pattern for instantiating a `MultiPartStream` is to use `MultiPart` as a context manager: + +```python +>>> files = {'avatar-upload': 'image.png'} +>>> with httpx.MultiPart(files=files) as s: +... s.read() +# ... +``` + +## HTTPStream + +A byte stream returning unparsed content from an HTTP request or response. + +```python +>>> with httpx.Client() as cli: +... r = cli.get('https://www.example.com/') +... r.stream.read() +# ... +``` + +## GZipStream + +... + +--- + +← [Content Types](content-types.md) +[Connections](connections.md) → +  diff --git a/docs/templates/base.html b/docs/templates/base.html new file mode 100644 index 00000000..22fe4d37 --- /dev/null +++ b/docs/templates/base.html @@ -0,0 +1,186 @@ + + + + + + httpx + + + + + + + + + + + + + +
+ {{ content }} +
+ + \ No newline at end of file diff --git a/docs/urls.md b/docs/urls.md new file mode 100644 index 00000000..ef56b184 --- /dev/null +++ b/docs/urls.md @@ -0,0 +1,240 @@ +# URLs + +The `URL` class handles URL validation and parsing. + + + +```{ .python .httpx } +>>> url = httpx.URL('https://www.example.com/') +>>> url + +``` + +```{ .python .ahttpx .hidden } +>>> url = ahttpx.URL('https://www.example.com/') +>>> url + +``` + +URL components are normalised, following the same rules as internet browsers. + + + +```{ .python .httpx } +>>> url = httpx.URL('https://www.EXAMPLE.com:443/path/../main') +>>> url + +``` + +```{ .python .ahttpx .hidden } +>>> url = ahttpx.URL('https://www.EXAMPLE.com:443/path/../main') +>>> url + +``` + +Both absolute and relative URLs are valid. + + + +```{ .python .httpx } +>>> url = httpx.URL('/README.md') +>>> url + +``` + +```{ .python .ahttpx .hidden } +>>> url = ahttpx.URL('/README.md') +>>> url + +``` + +Coercing a URL to a `str` will always result in a printable ASCII string. + + + +```{ .python .httpx } +>>> url = httpx.URL('https://example.com/path to here?search=🦋') +>>> str(url) +'https://example.com/path%20to%20here?search=%F0%9F%A6%8B' +``` + +```{ .python .ahttpx .hidden } +>>> url = ahttpx.URL('https://example.com/path to here?search=🦋') +>>> str(url) +'https://example.com/path%20to%20here?search=%F0%9F%A6%8B' +``` + +### URL components + +The following properties are available for accessing the component parts of a URL. + +* `.scheme` - *str. ASCII. Normalised to lowercase.* +* `.userinfo` - *str. ASCII. URL encoded.* +* `.username` - *str. Unicode.* +* `.password` - *str. Unicode.* +* `.host` - *str. ASCII. IDNA encoded.* +* `.port` - *int or None. Scheme default ports are normalised to None.* +* `.authority` - *str. ASCII. IDNA encoded. Eg. "example.com", "example.com:1337", "xn--p1ai".* +* `.path` - *str. Unicode.* +* `.query` - *str. ASCII. URL encoded.* +* `.target` - *str. ASCII. URL encoded.* +* `.fragment` - *str. 
ASCII. URL encoded.* + +A parsed representation of the query parameters is accessible with the `.params` property. + +* `.params` - [`QueryParams`](#query-parameters) + +URLs can be instantiated from their components... + + + +```{ .python .httpx } +>>> httpx.URL(scheme="https", host="example.com", path="/") + +``` + +```{ .python .ahttpx .hidden } +>>> ahttpx.URL(scheme="https", host="example.com", path="/") + +``` + +Or using both the string form and query parameters... + + + +```{ .python .httpx } +>>> httpx.URL("https://example.com/", params={"search": "some text"}) + +``` + +```{ .python .ahttpx .hidden } +>>> ahttpx.URL("https://example.com/", params={"search": "some text"}) + +``` + +### Modifying URLs + +Instances of `URL` are immutable, meaning their value cannot be changed. Instead new modified instances may be created. + +* `.copy_with(**components)` - *Return a new URL, updating one or more components. Eg. `url = url.copy_with(scheme="https")`*. +* `.copy_set_param(key, value)` - *Return a new URL, setting a query parameter. Eg. `url = url.copy_set_param("sort_by", "price")`*. +* `.copy_append_param(key, value)` - *Return a new URL, setting or appending a query parameter. Eg. `url = url.copy_append_param("tag", "sale")`*. +* `.copy_remove_param(key)` - *Return a new URL, removing a query parameter. Eg. `url = url.copy_remove_param("max_price")`*. +* `.copy_update_params(params)` - *Return a new URL, updating the query parameters. Eg. `url = url.copy_update_params({"color_scheme": "dark"})`*. +* `.join(url)` - *Return a new URL, given this URL as the base and another URL as the target. Eg. `url = url.join("../navigation")`*. + +--- + +## Query Parameters + +The `QueryParams` class provides an immutable multi-dict for accessing URL query parameters. + +They can be instantiated from a dictionary. 
+ + + +```{ .python .httpx } +>>> params = httpx.QueryParams({"color": "black", "size": "medium"}) +>>> params +<QueryParams 'color=black&size=medium'> +``` + +```{ .python .ahttpx .hidden } +>>> params = ahttpx.QueryParams({"color": "black", "size": "medium"}) +>>> params +<QueryParams 'color=black&size=medium'> +``` + +Multiple values for a single key are valid. + + + +```{ .python .httpx } +>>> params = httpx.QueryParams({"filter": ["60GHz", "75GHz", "100GHz"]}) +>>> params +<QueryParams 'filter=60GHz&filter=75GHz&filter=100GHz'> +``` + +```{ .python .ahttpx .hidden } +>>> params = ahttpx.QueryParams({"filter": ["60GHz", "75GHz", "100GHz"]}) +>>> params +<QueryParams 'filter=60GHz&filter=75GHz&filter=100GHz'> +``` + +They can also be instantiated directly from a query string. + + + +```{ .python .httpx } +>>> params = httpx.QueryParams("color=black&size=medium") +>>> params +<QueryParams 'color=black&size=medium'> +``` + +```{ .python .ahttpx .hidden } +>>> params = ahttpx.QueryParams("color=black&size=medium") +>>> params +<QueryParams 'color=black&size=medium'> +``` + +Keys and values are always represented as strings. + + + +```{ .python .httpx } +>>> params = httpx.QueryParams("sort_by=published&author=natalie") +>>> params["sort_by"] +'published' +``` + +```{ .python .ahttpx .hidden } +>>> params = ahttpx.QueryParams("sort_by=published&author=natalie") +>>> params["sort_by"] +'published' +``` + +When coercing query parameters to strings you'll see the same escaping behavior as HTML form submissions. The result will always be a printable ASCII string. + + + +```{ .python .httpx } +>>> params = httpx.QueryParams({"email": "user@example.com", "search": "How HTTP works!"}) +>>> str(params) +'email=user%40example.com&search=How+HTTP+works%21' +``` + +```{ .python .ahttpx .hidden } +>>> params = ahttpx.QueryParams({"email": "user@example.com", "search": "How HTTP works!"}) +>>> str(params) +'email=user%40example.com&search=How+HTTP+works%21' +``` + +### Accessing query parameters + +Query parameters are accessed using a standard dictionary style interface... + +* `.get(key, default=None)` - *Return the value for a given key, or a default value. 
If multiple values for the key are present, only the first will be returned.* +* `.keys()` - *Return the unique keys of the query parameters. Each key will be a `str` instance.* +* `.values()` - *Return the values of the query parameters. Each value will be a list of one or more `str` instances.* +* `.items()` - *Return the key value pairs of the query parameters. Each item will be a two-tuple including a `str` instance as the key, and a list of one or more `str` instances as the value.* + +The following methods are also available for accessing query parameters as a multidict... + +* `.get_all(key)` - *Return all the values for a given key. Returned as a list of zero or more `str` instances.* +* `.multi_items()` - *Return the key value pairs of the query parameters. Each item will be a two-tuple `(str, str)`. Repeated keys may occur.* +* `.multi_dict()` - *Return the query parameters as a dictionary, with each value being a list of one or more `str` instances.* + +### Modifying query parameters + +The following methods can be used to create modified query parameter instances... + +* `.copy_set(key, value)` +* `.copy_append(key, value)` +* `.copy_remove(key)` +* `.copy_update(params)` + +--- + +← [Responses](responses.md) +[Headers](headers.md) → +  \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..e708b5d9 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,30 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "httpx" +description = "HTTP, for Python." 
+requires-python = ">=3.10" +authors = [ + { name = "Tom Christie", email = "tom@tomchristie.com" }, +] +classifiers = [ + "Development Status :: 4 - Beta", + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Topic :: Internet :: WWW/HTTP", +] +dependencies = [ + "certifi", +] +dynamic = ["version"] + +[tool.hatch.version] +path = "src/httpx/__version__.py" diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000..f4d4bb38 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,17 @@ +-e . + +# Build... +build==1.2.2 + +# Test... +mypy==1.15.0 +pytest==8.3.5 +pytest-cov==6.1.1 + +# Sync & Async mirroring... +unasync==0.6.0 + +# Documentation... +click==8.2.1 +jinja2==3.1.6 +markdown==3.8 diff --git a/scripts/build b/scripts/build new file mode 100755 index 00000000..c7e14690 --- /dev/null +++ b/scripts/build @@ -0,0 +1,32 @@ +#!/bin/sh + +PKG=$1 + +if [ "$PKG" != "httpx" ] && [ "$PKG" != "ahttpx" ] ; then + echo "build [httpx|ahttpx]" + exit 1 +fi + +export PREFIX="" +if [ -d 'venv' ] ; then + export PREFIX="venv/bin/" +fi + +# Create pyproject-httpx.toml and pyproject-ahttpx.toml +cp pyproject.toml pyproject-httpx.toml +cat pyproject-httpx.toml | sed 's/name = "httpx"/name = "ahttpx"/' > pyproject-ahttpx.toml + +# Build the releases +if [ "$PKG" = "httpx" ]; then + ${PREFIX}python -m build +fi +if [ "$PKG" = "ahttpx" ]; then + cp pyproject-ahttpx.toml pyproject.toml + ${PREFIX}python -m build + cp pyproject-httpx.toml pyproject.toml +fi + +# Clean up +rm pyproject-httpx.toml pyproject-ahttpx.toml + +echo "$PKG" \ No newline at end of file diff --git a/scripts/docs b/scripts/docs new file mode 100755 index 00000000..8c53da47 --- /dev/null +++ b/scripts/docs @@ -0,0 
+1,153 @@ +#!venv/bin/python +import pathlib +import posixpath + +import click +import ghp_import +import logging +import httpx +import jinja2 +import markdown + +import xml.etree.ElementTree as etree + + +pages = { + '/': 'docs/index.md', + '/quickstart': 'docs/quickstart.md', + '/clients': 'docs/clients.md', + '/servers': 'docs/servers.md', + '/requests': 'docs/requests.md', + '/responses': 'docs/responses.md', + '/urls': 'docs/urls.md', + '/headers': 'docs/headers.md', + '/content-types': 'docs/content-types.md', + '/streams': 'docs/streams.md', + '/connections': 'docs/connections.md', + '/parsers': 'docs/parsers.md', + '/networking': 'docs/networking.md', + '/about': 'docs/about.md', +} + +def path_to_url(path): + if path == "index.md": + return "/" + return f"/{path[:-3]}" + + +class URLsProcessor(markdown.treeprocessors.Treeprocessor): + def __init__(self, state): + self.state = state + + def run(self, root: etree.Element) -> etree.Element: + for element in root.iter(): + if element.tag == 'a': + key = 'href' + elif element.tag == 'img': + key = 'src' + else: + continue + + url_or_path = element.get(key) + if url_or_path is not None: + output_url = self.rewrite_url(url_or_path) + element.set(key, output_url) + + return root + + def rewrite_url(self, href: str) -> str: + if not href.endswith('.md'): + return href + + current_url = path_to_url(self.state.file) + linked_url = path_to_url(href) + return posixpath.relpath(linked_url, start=current_url) + + +class BuildState: + def __init__(self): + self.file = '' + + +state = BuildState() +env = jinja2.Environment( + loader=jinja2.FileSystemLoader('docs/templates'), + autoescape=False +) +template = env.get_template('base.html') +md = markdown.Markdown(extensions=['fenced_code']) +md.treeprocessors.register( + item=URLsProcessor(state), + name='urls', + priority=10, +) + + +def not_found(): + text = httpx.Text('Not Found') + return httpx.Response(404, content=text) + + +def web_server(request): + if 
request.url.path not in pages: + return not_found() + + file = pages[request.url.path] + text = pathlib.Path(file).read_text(encoding="utf-8") + + state.file = file + content = md.convert(text) + html = template.render(content=content) + content = httpx.HTML(html) + return httpx.Response(200, content=content) + + +@click.group() +def main(): + pass + + +@main.command() +def build(): + pathlib.Path("build").mkdir(exist_ok=True) + + for url, path in pages.items(): + basename = url.lstrip("/") + output = f"build/{basename}.html" if basename else "build/index.html" + text = pathlib.Path(path).read_text(encoding="utf-8") + content = md.convert(text) + html = template.render(content=content) + pathlib.Path(output).write_text(html, encoding="utf-8") + print(f"Built {output}") + + +@main.command() +def serve(): + logging.basicConfig( + format="%(levelname)s [%(asctime)s] %(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO + ) + + with httpx.serve_http(web_server) as server: + server.wait() + + +@main.command() +def deploy(): + ghp_import.ghp_import( + "build", + mesg="Documentation deploy", + remote="origin", + branch="gh-pages", + push=True, + force=False, + use_shell=False, + no_history=False, + nojekyll=True, + ) + print("Deployed to GitHub") + + +if __name__ == "__main__": + main() diff --git a/scripts/install b/scripts/install new file mode 100755 index 00000000..1b531e57 --- /dev/null +++ b/scripts/install @@ -0,0 +1,13 @@ +#!/bin/sh + +set -x + +if [ -z "$GITHUB_ACTIONS" ]; then + python3 -m venv venv + PIP="venv/bin/pip" +else + PIP="pip" +fi + +"$PIP" install -U pip +"$PIP" install -r requirements.txt diff --git a/scripts/publish b/scripts/publish new file mode 100755 index 00000000..6e6955f5 --- /dev/null +++ b/scripts/publish @@ -0,0 +1,15 @@ +#!/bin/sh + +PKG=$1 + +if [ "$PKG" != "httpx" ] && [ "$PKG" != "ahttpx" ] ; then + echo "publish [httpx|ahttpx]" + exit 1 +fi + +export PREFIX="" +if [ -d 'venv' ] ; then + export PREFIX="venv/bin/" +fi +${PREFIX}pip install -q twine 
+${PREFIX}twine upload dist/$PKG-* diff --git a/scripts/test b/scripts/test new file mode 100755 index 00000000..1e0812cd --- /dev/null +++ b/scripts/test @@ -0,0 +1,10 @@ +#!/bin/sh + +export PREFIX="" +if [ -d 'venv' ] ; then + export PREFIX="venv/bin/" +fi + +${PREFIX}mypy src/httpx +${PREFIX}mypy src/ahttpx +${PREFIX}pytest --cov src/httpx tests diff --git a/scripts/unasync b/scripts/unasync new file mode 100755 index 00000000..67d66b5c --- /dev/null +++ b/scripts/unasync @@ -0,0 +1,29 @@ +#!venv/bin/python +import unasync + +unasync.unasync_files( + fpath_list = [ + "src/ahttpx/__init__.py", + "src/ahttpx/__version__.py", + "src/ahttpx/_client.py", + "src/ahttpx/_content.py", + "src/ahttpx/_headers.py", + "src/ahttpx/_parsers.py", + "src/ahttpx/_pool.py", + "src/ahttpx/_quickstart.py", + "src/ahttpx/_response.py", + "src/ahttpx/_request.py", + "src/ahttpx/_server.py", + "src/ahttpx/_streams.py", + "src/ahttpx/_urlencode.py", + "src/ahttpx/_urlparse.py", + "src/ahttpx/_urls.py", + ], + rules = [ + unasync.Rule( + "src/ahttpx/", + "src/httpx/", + additional_replacements={"ahttpx": "httpx"} + ), + ] +) diff --git a/src/ahttpx/__init__.py b/src/ahttpx/__init__.py new file mode 100644 index 00000000..9e589ab6 --- /dev/null +++ b/src/ahttpx/__init__.py @@ -0,0 +1,65 @@ +from .__version__ import __title__, __version__ +from ._client import * # Client +from ._content import * # Content, File, Files, Form, HTML, JSON, MultiPart, Text +from ._headers import * # Headers +from ._network import * # NetworkBackend, NetworkStream, timeout +from ._parsers import * # HTTPParser, ProtocolError +from ._pool import * # Connection, ConnectionPool, Transport +from ._quickstart import * # get, post, put, patch, delete +from ._response import * # Response +from ._request import * # Request +from ._streams import * # ByteStream, DuplexStream, FileStream, HTTPStream, Stream +from ._server import * # serve_http, run +from ._urlencode import * # quote, unquote, urldecode, urlencode +from 
._urls import * # QueryParams, URL + + +__all__ = [ + "__title__", + "__version__", + "ByteStream", + "Client", + "Connection", + "ConnectionPool", + "Content", + "delete", + "DuplexStream", + "File", + "FileStream", + "Files", + "Form", + "get", + "Headers", + "HTML", + "HTTPParser", + "HTTPStream", + "JSON", + "MultiPart", + "NetworkBackend", + "NetworkStream", + "open_connection", + "post", + "ProtocolError", + "put", + "patch", + "Response", + "Request", + "run", + "serve_http", + "Stream", + "Text", + "timeout", + "Transport", + "QueryParams", + "quote", + "unquote", + "URL", + "urldecode", + "urlencode", +] + + +__locals = locals() +for __name in __all__: + if not __name.startswith('__'): + setattr(__locals[__name], "__module__", "httpx") diff --git a/src/ahttpx/__version__.py b/src/ahttpx/__version__.py new file mode 100644 index 00000000..309fcb32 --- /dev/null +++ b/src/ahttpx/__version__.py @@ -0,0 +1,2 @@ +__title__ = "ahttpx" +__version__ = "1.0.dev3" \ No newline at end of file diff --git a/src/ahttpx/_client.py b/src/ahttpx/_client.py new file mode 100644 index 00000000..6326ac5d --- /dev/null +++ b/src/ahttpx/_client.py @@ -0,0 +1,156 @@ +import types +import typing + +from ._content import Content +from ._headers import Headers +from ._pool import ConnectionPool, Transport +from ._request import Request +from ._response import Response +from ._streams import Stream +from ._urls import URL + +__all__ = ["Client"] + + +class Client: + def __init__( + self, + url: URL | str | None = None, + headers: Headers | typing.Mapping[str, str] | None = None, + transport: Transport | None = None, + ): + if url is None: + url = "" + if headers is None: + headers = {"User-Agent": "dev"} + if transport is None: + transport = ConnectionPool() + + self.url = URL(url) + self.headers = Headers(headers) + self.transport = transport + self.via = RedirectMiddleware(self.transport) + + def build_request( + self, + method: str, + url: URL | str, + headers: Headers | 
typing.Mapping[str, str] | None = None, + content: Content | Stream | bytes | None = None, + ) -> Request: + return Request( + method=method, + url=self.url.join(url), + headers=self.headers.copy_update(headers), + content=content, + ) + + async def request( + self, + method: str, + url: URL | str, + headers: Headers | typing.Mapping[str, str] | None = None, + content: Content | Stream | bytes | None = None, + ) -> Response: + request = self.build_request(method, url, headers=headers, content=content) + async with await self.via.send(request) as response: + await response.read() + return response + + async def stream( + self, + method: str, + url: URL | str, + headers: Headers | typing.Mapping[str, str] | None = None, + content: Content | Stream | bytes | None = None, + ) -> Response: + request = self.build_request(method, url, headers=headers, content=content) + return await self.via.send(request) + + async def get( + self, + url: URL | str, + headers: Headers | typing.Mapping[str, str] | None = None, + ): + return await self.request("GET", url, headers=headers) + + async def post( + self, + url: URL | str, + headers: Headers | typing.Mapping[str, str] | None = None, + content: Content | Stream | bytes | None = None, + ): + return await self.request("POST", url, headers=headers, content=content) + + async def put( + self, + url: URL | str, + headers: Headers | typing.Mapping[str, str] | None = None, + content: Content | Stream | bytes | None = None, + ): + return await self.request("PUT", url, headers=headers, content=content) + + async def patch( + self, + url: URL | str, + headers: Headers | typing.Mapping[str, str] | None = None, + content: Content | Stream | bytes | None = None, + ): + return await self.request("PATCH", url, headers=headers, content=content) + + async def delete( + self, + url: URL | str, + headers: Headers | typing.Mapping[str, str] | None = None, + ): + return await self.request("DELETE", url, headers=headers) + + async def close(self): + 
await self.transport.close() + + async def __aenter__(self): + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: types.TracebackType | None = None + ): + await self.close() + + def __repr__(self): + return f"" + + +class RedirectMiddleware(Transport): + def __init__(self, transport: Transport) -> None: + self._transport = transport + + def is_redirect(self, response: Response) -> bool: + return ( + response.status_code in (301, 302, 303, 307, 308) + and "Location" in response.headers + ) + + def build_redirect_request(self, request: Request, response: Response) -> Request: + raise NotImplementedError() + + async def send(self, request: Request) -> Response: + while True: + response = await self._transport.send(request) + + if not self.is_redirect(response): + return response + + # If we have a redirect, then we read the body of the response. + # Ensures that the HTTP connection is available for a new + # request/response cycle. + await response.read() + await response.close() + + # We've made a request-response and now need to issue a redirect request. 
+ request = self.build_redirect_request(request, response) + + async def close(self): + pass diff --git a/src/ahttpx/_content.py b/src/ahttpx/_content.py new file mode 100644 index 00000000..e2c4aaa9 --- /dev/null +++ b/src/ahttpx/_content.py @@ -0,0 +1,378 @@ +import json +import os +import typing + +from ._streams import Stream, ByteStream, FileStream, MultiPartStream +from ._urlencode import urldecode, urlencode + +__all__ = [ + "Content", + "Form", + "File", + "Files", + "JSON", + "MultiPart", + "Text", + "HTML", +] + +# https://github.com/nginx/nginx/blob/master/conf/mime.types +_content_types = { + ".json": "application/json", + ".js": "application/javascript", + ".html": "text/html", + ".css": "text/css", + ".png": "image/png", + ".jpeg": "image/jpeg", + ".jpg": "image/jpeg", + ".gif": "image/gif", +} + + +class Content: + def encode(self) -> Stream: + raise NotImplementedError() + + def content_type(self) -> str: + raise NotImplementedError() + + +class Form(typing.Mapping[str, str], Content): + """ + HTML form data, as an immutable multi-dict. + Form parameters, as a multi-dict. 
+ """ + + def __init__( + self, + form: ( + typing.Mapping[str, str | typing.Sequence[str]] + | typing.Sequence[tuple[str, str]] + | str + | None + ) = None, + ) -> None: + d: dict[str, list[str]] = {} + + if form is None: + d = {} + elif isinstance(form, str): + d = urldecode(form) + elif isinstance(form, typing.Mapping): + # Convert dict inputs like: + # {"a": "123", "b": ["456", "789"]} + # To dict inputs where values are always lists, like: + # {"a": ["123"], "b": ["456", "789"]} + d = {k: [v] if isinstance(v, str) else list(v) for k, v in form.items()} + else: + # Convert list inputs like: + # [("a", "123"), ("a", "456"), ("b", "789")] + # To a dict representation, like: + # {"a": ["123", "456"], "b": ["789"]} + for k, v in form: + d.setdefault(k, []).append(v) + + self._dict = d + + # Content API + + def encode(self) -> Stream: + content = str(self).encode("ascii") + return ByteStream(content) + + def content_type(self) -> str: + return "application/x-www-form-urlencoded" + + # Dict operations + + def keys(self) -> typing.KeysView[str]: + return self._dict.keys() + + def values(self) -> typing.ValuesView[str]: + return {k: v[0] for k, v in self._dict.items()}.values() + + def items(self) -> typing.ItemsView[str, str]: + return {k: v[0] for k, v in self._dict.items()}.items() + + def get(self, key: str, default: typing.Any = None) -> typing.Any: + if key in self._dict: + return self._dict[key][0] + return default + + # Multi-dict operations + + def multi_items(self) -> list[tuple[str, str]]: + multi_items: list[tuple[str, str]] = [] + for k, v in self._dict.items(): + multi_items.extend([(k, i) for i in v]) + return multi_items + + def multi_dict(self) -> dict[str, list[str]]: + return {k: list(v) for k, v in self._dict.items()} + + def get_list(self, key: str) -> list[str]: + return list(self._dict.get(key, [])) + + # Update operations + + def copy_set(self, key: str, value: str) -> "Form": + d = self.multi_dict() + d[key] = [value] + return Form(d) + + def 
copy_append(self, key: str, value: str) -> "Form": + d = self.multi_dict() + d[key] = d.get(key, []) + [value] + return Form(d) + + def copy_remove(self, key: str) -> "Form": + d = self.multi_dict() + d.pop(key, None) + return Form(d) + + # Accessors & built-ins + + def __getitem__(self, key: str) -> str: + return self._dict[key][0] + + def __contains__(self, key: typing.Any) -> bool: + return key in self._dict + + def __iter__(self) -> typing.Iterator[str]: + return iter(self.keys()) + + def __len__(self) -> int: + return len(self._dict) + + def __bool__(self) -> bool: + return bool(self._dict) + + def __hash__(self) -> int: + return hash(str(self)) + + def __eq__(self, other: typing.Any) -> bool: + return ( + isinstance(other, Form) and + sorted(self.multi_items()) == sorted(other.multi_items()) + ) + + def __str__(self) -> str: + return urlencode(self.multi_dict()) + + def __repr__(self) -> str: + return f"" + + +class File(Content): + """ + Wrapper class used for files in uploads and multipart requests. + """ + + def __init__(self, path: str): + self._path = path + + def name(self) -> str: + return os.path.basename(self._path) + + def size(self) -> int: + return os.path.getsize(self._path) + + def encode(self) -> Stream: + return FileStream(self._path) + + def content_type(self) -> str: + _, ext = os.path.splitext(self._path) + ct = _content_types.get(ext, "application/octet-stream") + if ct.startswith('text/'): + ct += "; charset='utf-8'" + return ct + + def __lt__(self, other: typing.Any) -> bool: + return isinstance(other, File) and self._path < other._path + + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, File) and other._path == self._path + + def __repr__(self) -> str: + return f"" + + +class Files(typing.Mapping[str, File], Content): + """ + File parameters, as a multi-dict. 
+ """ + + def __init__( + self, + files: ( + typing.Mapping[str, File | typing.Sequence[File]] + | typing.Sequence[tuple[str, File]] + | None + ) = None, + boundary: str = '' + ) -> None: + d: dict[str, list[File]] = {} + + if files is None: + d = {} + elif isinstance(files, typing.Mapping): + d = {k: [v] if isinstance(v, File) else list(v) for k, v in files.items()} + else: + d = {} + for k, v in files: + d.setdefault(k, []).append(v) + + self._dict = d + self._boundary = boundary or os.urandom(16).hex() + + # Standard dict interface + def keys(self) -> typing.KeysView[str]: + return self._dict.keys() + + def values(self) -> typing.ValuesView[File]: + return {k: v[0] for k, v in self._dict.items()}.values() + + def items(self) -> typing.ItemsView[str, File]: + return {k: v[0] for k, v in self._dict.items()}.items() + + def get(self, key: str, default: typing.Any = None) -> typing.Any: + if key in self._dict: + return self._dict[key][0] + return default + + # Multi dict interface + def multi_items(self) -> list[tuple[str, File]]: + multi_items: list[tuple[str, File]] = [] + for k, v in self._dict.items(): + multi_items.extend([(k, i) for i in v]) + return multi_items + + def multi_dict(self) -> dict[str, list[File]]: + return {k: list(v) for k, v in self._dict.items()} + + def get_list(self, key: str) -> list[File]: + return list(self._dict.get(key, [])) + + # Content interface + def encode(self) -> Stream: + return MultiPart(files=self, boundary=self._boundary).encode() + + def content_type(self) -> str: + return f"multipart/form-data; boundary={self._boundary}" + + # Builtins + def __getitem__(self, key: str) -> File: + return self._dict[key][0] + + def __contains__(self, key: typing.Any) -> bool: + return key in self._dict + + def __iter__(self) -> typing.Iterator[str]: + return iter(self.keys()) + + def __len__(self) -> int: + return len(self._dict) + + def __bool__(self) -> bool: + return bool(self._dict) + + def __eq__(self, other: typing.Any) -> bool: + return ( + isinstance(other, 
Files) and + sorted(self.multi_items()) == sorted(other.multi_items()) + ) + + def __repr__(self) -> str: + return f"" + + +class JSON(Content): + def __init__(self, data: typing.Any) -> None: + self._data = data + + def encode(self) -> Stream: + content = json.dumps( + self._data, + ensure_ascii=False, + separators=(",", ":"), + allow_nan=False + ).encode("utf-8") + return ByteStream(content) + + def content_type(self) -> str: + return "application/json" + + def __repr__(self) -> str: + return f"" + + +class Text(Content): + def __init__(self, text: str) -> None: + self._text = text + + def encode(self) -> Stream: + content = self._text.encode("utf-8") + return ByteStream(content) + + def content_type(self) -> str: + return "text/plain; charset='utf-8'" + + def __repr__(self) -> str: + return f"" + + +class HTML(Content): + def __init__(self, text: str) -> None: + self._text = text + + def encode(self) -> Stream: + content = self._text.encode("utf-8") + return ByteStream(content) + + def content_type(self) -> str: + return "text/html; charset='utf-8'" + + def __repr__(self) -> str: + return f"" + + +class MultiPart(Content): + def __init__( + self, + form: ( + Form + | typing.Mapping[str, str | typing.Sequence[str]] + | typing.Sequence[tuple[str, str]] + | str + | None + ) = None, + files: ( + Files + | typing.Mapping[str, File | typing.Sequence[File]] + | typing.Sequence[tuple[str, File]] + | None + ) = None, + boundary: str | None = None + ): + self._form = form if isinstance(form , Form) else Form(form) + self._files = files if isinstance(files, Files) else Files(files) + self._boundary = os.urandom(16).hex() if boundary is None else boundary + + @property + def form(self) -> Form: + return self._form + + @property + def files(self) -> Files: + return self._files + + def encode(self) -> Stream: + form = [(key, value) for key, value in self._form.multi_items()] + files = [(key, file._path) for key, file in self._files.multi_items()] + return MultiPartStream(form, files, 
boundary=self._boundary) + + def content_type(self) -> str: + return f"multipart/form-data; boundary={self._boundary}" + + def __repr__(self) -> str: + return f"" diff --git a/src/ahttpx/_headers.py b/src/ahttpx/_headers.py new file mode 100644 index 00000000..dade8058 --- /dev/null +++ b/src/ahttpx/_headers.py @@ -0,0 +1,243 @@ +import re +import typing + + +__all__ = ["Headers"] + + +VALID_HEADER_CHARS = ( + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz" + "0123456789" + "!#$%&'*+-.^_`|~" +) + + +# TODO... +# +# * Comma folded values, eg. `Vary: ...` +# * Multiple Set-Cookie headers. +# * Non-ascii support. +# * Ordering, including `Host` header exception. + + +def headername(name: str) -> str: + if name.strip(VALID_HEADER_CHARS) or not name: + raise ValueError(f"Invalid HTTP header name {name!r}.") + return name + + +def headervalue(value: str) -> str: + value = value.strip(" ") + if not value or not value.isascii() or not value.isprintable(): + raise ValueError(f"Invalid HTTP header value {value!r}.") + return value + + +class Headers(typing.Mapping[str, str]): + def __init__( + self, + headers: typing.Mapping[str, str] | typing.Sequence[tuple[str, str]] | None = None, + ) -> None: + # {'accept': ('Accept', '*/*')} + d: dict[str, str] = {} + + if isinstance(headers, typing.Mapping): + # Headers({ + # 'Content-Length': '1024', + # 'Content-Type': 'text/plain; charset=utf-8', + # ) + d = {headername(k): headervalue(v) for k, v in headers.items()} + elif headers is not None: + # Headers([ + # ('Location', 'https://www.example.com'), + # ('Set-Cookie', 'session_id=3498jj489jhb98jn'), + # ]) + d = {headername(k): headervalue(v) for k, v in headers} + + self._dict = d + + def keys(self) -> typing.KeysView[str]: + """ + Return all the header keys. 
+ + Usage: + + h = httpx.Headers({"Accept": "*/*", "User-Agent": "python/httpx"}) + assert list(h.keys()) == ["Accept", "User-Agent"] + """ + return self._dict.keys() + + def values(self) -> typing.ValuesView[str]: + """ + Return all the header values. + + Usage: + + h = httpx.Headers({"Accept": "*/*", "User-Agent": "python/httpx"}) + assert list(h.values()) == ["*/*", "python/httpx"] + """ + return self._dict.values() + + def items(self) -> typing.ItemsView[str, str]: + """ + Return all headers as (key, value) tuples. + + Usage: + + h = httpx.Headers({"Accept": "*/*", "User-Agent": "python/httpx"}) + assert list(h.items()) == [("Accept", "*/*"), ("User-Agent", "python/httpx")] + """ + return self._dict.items() + + def get(self, key: str, default: typing.Any = None) -> typing.Any: + """ + Get the value of the header with the given name, or `default` if it + is not present. Header names are matched case-insensitively. + + Usage: + + h = httpx.Headers({"Accept": "*/*", "User-Agent": "python/httpx"}) + assert h.get("User-Agent") == "python/httpx" + """ + for k, v in self._dict.items(): + if k.lower() == key.lower(): + return v + return default + + def copy_set(self, key: str, value: str) -> "Headers": + """ + Return a new Headers instance, setting the value of a key. + + Usage: + + h = httpx.Headers({"Expires": "0"}) + h = h.copy_set("Expires", "Wed, 21 Oct 2015 07:28:00 GMT") + assert h == httpx.Headers({"Expires": "Wed, 21 Oct 2015 07:28:00 GMT"}) + """ + l = [] + seen = False + + # Either replace the existing header in place... + for k, v in self._dict.items(): + if k.lower() == key.lower(): + l.append((key, value)) + seen = True + else: + l.append((k, v)) + + # Or append... + if not seen: + l.append((key, value)) + + return Headers(l) + + def copy_remove(self, key: str) -> "Headers": + """ + Return a new Headers instance, removing the value of a key. 
+ + Usage: + + h = httpx.Headers({"Accept": "*/*"}) + h = h.copy_remove("Accept") + assert h == httpx.Headers({}) + """ + h = {k: v for k, v in self._dict.items() if k.lower() != key.lower()} + return Headers(h) + + def copy_update(self, update: "Headers" | typing.Mapping[str, str] | None) -> "Headers": + """ + Return a new Headers instance, updated with the given headers. + + Usage: + + h = httpx.Headers({"Accept": "*/*", "User-Agent": "python/httpx"}) + h = h.copy_update({"Accept-Encoding": "gzip"}) + assert h == httpx.Headers({"Accept": "*/*", "Accept-Encoding": "gzip", "User-Agent": "python/httpx"}) + """ + if update is None: + return self + + new = update if isinstance(update, Headers) else Headers(update) + + # Remove updated items using a case-insensitive approach... + keys = set([key.lower() for key in new.keys()]) + h = {k: v for k, v in self._dict.items() if k.lower() not in keys} + + # Perform the actual update... + h.update(dict(new)) + + return Headers(h) + + def __getitem__(self, key: str) -> str: + match = key.lower() + for k, v in self._dict.items(): + if k.lower() == match: + return v + raise KeyError(key) + + def __contains__(self, key: typing.Any) -> bool: + match = key.lower() + return any(k.lower() == match for k in self._dict.keys()) + + def __iter__(self) -> typing.Iterator[str]: + return iter(self.keys()) + + def __len__(self) -> int: + return len(self._dict) + + def __bool__(self) -> bool: + return bool(self._dict) + + def __eq__(self, other: typing.Any) -> bool: + self_lower = {k.lower(): v for k, v in self.items()} + other_lower = {k.lower(): v for k, v in Headers(other).items()} + return self_lower == other_lower + + def __repr__(self) -> str: + return f"" + + +def parse_opts_header(header: str) -> tuple[str, dict[str, str]]: + # The Content-Type header is described in RFC 2616 'Content-Type' + # https://datatracker.ietf.org/doc/html/rfc2616#section-14.17 + + # The 'type/subtype; parameter' format is described in RFC 2616 'Media Types' + # 
https://datatracker.ietf.org/doc/html/rfc2616#section-3.7 + + # Parameter quoting is described in RFC 2616 'Transfer Codings' + # https://datatracker.ietf.org/doc/html/rfc2616#section-3.6 + + header = header.strip() + content_type = '' + params = {} + + # Match the content type (up to the first semicolon or end) + match = re.match(r'^([^;]+)', header) + if match: + content_type = match.group(1).strip().lower() + rest = header[match.end():] + else: + return '', {} + + # Parse parameters, accounting for quoted strings + param_pattern = re.compile(r''' + ;\s* # Semicolon + optional whitespace + (?P[^=;\s]+) # Parameter key + = # Equal sign + (?P # Parameter value: + "(?:[^"\\]|\\.)*" # Quoted string with escapes + | # OR + [^;]* # Unquoted string (until semicolon) + ) + ''', re.VERBOSE) + + for match in param_pattern.finditer(rest): + key = match.group('key').lower() + value = match.group('value').strip() + if value.startswith('"') and value.endswith('"'): + # Remove surrounding quotes and unescape + value = re.sub(r'\\(.)', r'\1', value[1:-1]) + params[key] = value + + return content_type, params diff --git a/src/ahttpx/_network.py b/src/ahttpx/_network.py new file mode 100644 index 00000000..957e0361 --- /dev/null +++ b/src/ahttpx/_network.py @@ -0,0 +1,120 @@ +import asyncio +import ssl +import types +import typing + +import certifi + +from ._streams import Stream + + +__all__ = ["NetworkBackend", "NetworkStream", "timeout"] + + +class NetworkStream(Stream): + def __init__( + self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, address: str = '' + ) -> None: + self._reader = reader + self._writer = writer + self._address = address + self._tls = False + self._closed = False + + async def read(self, size: int = -1) -> bytes: + if size < 0: + size = 64 * 1024 + return await self._reader.read(size) + + async def write(self, buffer: bytes) -> None: + self._writer.write(buffer) + await self._writer.drain() + + async def close(self) -> None: + if not 
class NetworkBackend:
    """
    Opens client connections and listening servers on top of `asyncio`.

    `ssl_ctx` is the SSL context used by `connect_tls()`. If omitted, a
    default context is built from the `certifi` CA bundle.
    """

    def __init__(self, ssl_ctx: ssl.SSLContext | None = None):
        self._ssl_ctx = self.create_default_context() if ssl_ctx is None else ssl_ctx

    def create_default_context(self) -> ssl.SSLContext:
        """
        Return a default SSL context that trusts the `certifi` CA bundle.
        """
        import certifi
        return ssl.create_default_context(cafile=certifi.where())

    async def connect(self, host: str, port: int) -> NetworkStream:
        """
        Connect to the given address, returning a Stream instance.
        """
        address = f"{host}:{port}"
        reader, writer = await asyncio.open_connection(host, port)
        return NetworkStream(reader, writer, address=address)

    async def connect_tls(self, host: str, port: int, hostname: str = '') -> NetworkStream:
        """
        Connect to the given address and upgrade the connection to TLS,
        returning a Stream instance.

        `hostname` is the name the server certificate is matched against.
        If it is empty, `host` is used instead.
        """
        address = f"{host}:{port}"
        reader, writer = await asyncio.open_connection(host, port)
        try:
            # An empty server_hostname cannot be verified against, so fall
            # back to the host we actually connected to.
            await writer.start_tls(self._ssl_ctx, server_hostname=hostname or host)
        except BaseException:
            # Don't leak the plain-text socket if the handshake fails.
            writer.close()
            raise
        stream = NetworkStream(reader, writer, address=address)
        # NetworkStream reports this flag in its repr.
        stream._tls = True
        return stream

    async def serve(
        self,
        host: str,
        port: int,
        handler: typing.Callable[[NetworkStream], typing.Awaitable[None]],
    ) -> NetworkServer:
        """
        Start listening on the given address, returning a NetworkServer.

        `handler` is awaited once per accepted connection, and is passed a
        NetworkStream wrapping that connection.
        """
        async def callback(reader, writer):
            stream = NetworkStream(reader, writer)
            await handler(stream)

        server = await asyncio.start_server(callback, host, port)
        return NetworkServer(host, port, server)
class HTTPParser:
    """
    HTTP/1.1 message framing over a stream, tracked as a pair of
    send/receive states.

    `mode` is either 'CLIENT' or 'SERVER'. A client starts out ready to send
    a request line, a server starts out ready to receive one.

    Usage...

    client = HTTPParser(stream, mode='CLIENT')
    await client.send_method_line(b'GET', b'/', b'HTTP/1.1')
    await client.send_headers([(b'Host', b'www.example.com')])
    await client.send_body(b'')
    await client.recv_status_line()
    await client.recv_headers()
    await client.recv_body()
    await client.complete()
    await client.close()
    """
    def __init__(self, stream: Stream, mode: str) -> None:
        self.stream = stream
        self.parser = ReadAheadParser(stream)
        self.mode = {'CLIENT': Mode.CLIENT, 'SERVER': Mode.SERVER}[mode]

        # Track state...
        if self.mode == Mode.CLIENT:
            self.send_state: State = State.SEND_METHOD_LINE
            self.recv_state: State = State.WAIT
        else:
            self.recv_state = State.RECV_METHOD_LINE
            self.send_state = State.WAIT

        # Track message framing...
        # A length of `None` indicates 'Transfer-Encoding: chunked'.
        self.send_content_length: int | None = 0
        self.recv_content_length: int | None = 0
        self.send_seen_length = 0
        self.recv_seen_length = 0

        # Track connection keep alive...
        self.send_keep_alive = True
        self.recv_keep_alive = True

        # Special states...
        self.processing_1xx = False

    async def send_method_line(self, method: bytes, target: bytes, protocol: bytes) -> None:
        """
        Send the initial request line:

        >>> await p.send_method_line(b'GET', b'/', b'HTTP/1.1')

        Sending state will switch to SEND_HEADERS state.
        Receive state will switch to RECV_STATUS_LINE state.
        """
        if self.send_state != State.SEND_METHOD_LINE:
            msg = f"Called 'send_method_line' in invalid state {self.send_state}"
            raise ProtocolError(msg)

        # Send initial request line, eg. "GET / HTTP/1.1"
        if protocol != b'HTTP/1.1':
            raise ProtocolError("Sent unsupported protocol version")
        data = b" ".join([method, target, protocol]) + b"\r\n"
        await self.stream.write(data)

        self.send_state = State.SEND_HEADERS
        self.recv_state = State.RECV_STATUS_LINE

    async def send_status_line(self, protocol: bytes, status_code: int, reason: bytes) -> None:
        """
        Send the initial response line:

        >>> await p.send_status_line(b'HTTP/1.1', 200, b'OK')

        Sending state will switch to SEND_HEADERS state.
        """
        if self.send_state != State.SEND_STATUS_LINE:
            msg = f"Called 'send_status_line' in invalid state {self.send_state}"
            raise ProtocolError(msg)

        # Send initial response line, eg. "HTTP/1.1 200 OK"
        if protocol != b'HTTP/1.1':
            raise ProtocolError("Sent unsupported protocol version")
        status_code_bytes = str(status_code).encode('ascii')
        data = b" ".join([protocol, status_code_bytes, reason]) + b"\r\n"
        await self.stream.write(data)

        self.send_state = State.SEND_HEADERS

    async def send_headers(self, headers: list[tuple[bytes, bytes]]) -> None:
        """
        Send the message headers:

        >>> await p.send_headers([(b'Host', b'www.example.com')])

        Sending state will switch to SEND_BODY state.

        Raises `ProtocolError` if called in the wrong state, if a
        Content-Length value is invalid, or if a client omits 'Host'.
        """
        if self.send_state != State.SEND_HEADERS:
            msg = f"Called 'send_headers' in invalid state {self.send_state}"
            raise ProtocolError(msg)

        # Update header state
        seen_host = False
        for name, value in headers:
            lname = name.lower()
            if lname == b'host':
                seen_host = True
            elif lname == b'content-length':
                self.send_content_length = bounded_int(
                    value,
                    max_digits=20,
                    exc_text="Sent invalid Content-Length"
                )
            # Connection options and transfer-coding names are
            # case-insensitive, so compare the lowercased value.
            elif lname == b'connection' and value.lower() == b'close':
                self.send_keep_alive = False
            elif lname == b'transfer-encoding' and value.lower() == b'chunked':
                self.send_content_length = None

        if self.mode == Mode.CLIENT and not seen_host:
            raise ProtocolError("Request missing 'Host' header")

        # Send the headers, followed by the blank line that terminates them.
        lines = [name + b": " + value + b"\r\n" for name, value in headers]
        data = b"".join(lines) + b"\r\n"
        await self.stream.write(data)

        self.send_state = State.SEND_BODY

    async def send_body(self, body: bytes) -> None:
        """
        Send the message body. An empty bytes argument indicates the end of the stream:

        >>> await p.send_body(b'')

        Sending state will switch to DONE once the empty terminator is sent.

        Raises `ProtocolError` if called in the wrong state, or if the data
        sent does not agree with the declared Content-Length.
        """
        if self.send_state != State.SEND_BODY:
            msg = f"Called 'send_body' in invalid state {self.send_state}"
            raise ProtocolError(msg)

        if self.send_content_length is None:
            # Transfer-Encoding: chunked
            # An empty body is written as the zero-sized terminating chunk.
            self.send_seen_length += len(body)
            marker = f'{len(body):x}\r\n'.encode('ascii')
            await self.stream.write(marker + body + b'\r\n')

        else:
            # Content-Length: xxx
            self.send_seen_length += len(body)
            if self.send_seen_length > self.send_content_length:
                msg = 'Too much data sent for declared Content-Length'
                raise ProtocolError(msg)
            if self.send_seen_length < self.send_content_length and body == b'':
                msg = 'Not enough data sent for declared Content-Length'
                raise ProtocolError(msg)
            if body:
                await self.stream.write(body)

        if body == b'':
            # Handle body close
            self.send_state = State.DONE

    async def recv_method_line(self) -> tuple[bytes, bytes, bytes]:
        """
        Receive the initial request method line:

        >>> method, target, protocol = await p.recv_method_line()

        Receive state will switch to RECV_HEADERS.
        Sending state will switch to SEND_STATUS_LINE.

        Raises `ProtocolError` if called in the wrong state, or if the
        request line is malformed or uses an unsupported protocol version.
        """
        if self.recv_state != State.RECV_METHOD_LINE:
            msg = f"Called 'recv_method_line' in invalid state {self.recv_state}"
            raise ProtocolError(msg)

        # Read initial request line, eg. "GET / HTTP/1.1"
        exc_text = "reading request method line"
        line = await self.parser.read_until(b"\r\n", max_size=4096, exc_text=exc_text)
        parts = line.split(b" ", 2)
        if len(parts) != 3:
            # Report malformed input as a protocol error, rather than
            # leaking a ValueError from tuple unpacking.
            raise ProtocolError("Received malformed request line")
        method, target, protocol = parts
        if protocol != b'HTTP/1.1':
            raise ProtocolError("Received unsupported protocol version")

        self.recv_state = State.RECV_HEADERS
        self.send_state = State.SEND_STATUS_LINE
        return method, target, protocol

    async def recv_status_line(self) -> tuple[bytes, int, bytes]:
        """
        Receive the initial response status line:

        >>> protocol, status_code, reason_phrase = await p.recv_status_line()

        Receive state will switch to RECV_HEADERS.

        Raises `ProtocolError` if called in the wrong state, or if the
        status line is malformed, uses an unsupported protocol version,
        or carries an invalid status code.
        """
        if self.recv_state != State.RECV_STATUS_LINE:
            msg = f"Called 'recv_status_line' in invalid state {self.recv_state}"
            raise ProtocolError(msg)

        # Read initial response line, eg. "HTTP/1.1 200 OK"
        exc_text = "reading response status line"
        line = await self.parser.read_until(b"\r\n", max_size=4096, exc_text=exc_text)
        parts = line.split(b" ", 2)
        if len(parts) != 3:
            # Report malformed input as a protocol error, rather than
            # leaking a ValueError from tuple unpacking.
            raise ProtocolError("Received malformed status line")
        protocol, status_code_str, reason_phrase = parts
        if protocol != b'HTTP/1.1':
            raise ProtocolError("Received unsupported protocol version")

        status_code = bounded_int(
            status_code_str,
            max_digits=3,
            exc_text="Received invalid status code"
        )
        if status_code < 100:
            raise ProtocolError("Received invalid status code")
        # 1xx status codes precede the final response status code
        self.processing_1xx = status_code < 200

        self.recv_state = State.RECV_HEADERS
        return protocol, status_code, reason_phrase

    async def recv_headers(self) -> list[tuple[bytes, bytes]]:
        """
        Receive the message headers:

        >>> headers = await p.recv_headers()

        Receive state will switch to RECV_BODY by default.
        Receive state will revert to RECV_STATUS_LINE for interim 1xx responses.

        Raises `ProtocolError` if called in the wrong state, if a header
        line is malformed, if a Content-Length value is invalid, or if a
        server receives a request without 'Host'.
        """
        if self.recv_state != State.RECV_HEADERS:
            msg = f"Called 'recv_headers' in invalid state {self.recv_state}"
            raise ProtocolError(msg)

        # Read header lines up to the blank line that terminates them.
        headers = []
        exc_text = "reading response headers"
        while line := await self.parser.read_until(b"\r\n", max_size=4096, exc_text=exc_text):
            name, sep, value = line.partition(b":")
            if not sep:
                # A header line without a colon is a protocol error, not
                # a ValueError from tuple unpacking.
                raise ProtocolError("Received malformed header line")
            value = value.strip(b" ")
            headers.append((name, value))

        # Update header state
        seen_host = False
        for name, value in headers:
            lname = name.lower()
            if lname == b'host':
                seen_host = True
            elif lname == b'content-length':
                self.recv_content_length = bounded_int(
                    value,
                    max_digits=20,
                    exc_text="Received invalid Content-Length"
                )
            # Connection options and transfer-coding names are
            # case-insensitive, so compare the lowercased value.
            elif lname == b'connection' and value.lower() == b'close':
                self.recv_keep_alive = False
            elif lname == b'transfer-encoding' and value.lower() == b'chunked':
                self.recv_content_length = None

        if self.mode == Mode.SERVER and not seen_host:
            raise ProtocolError("Request missing 'Host' header")

        if self.processing_1xx:
            # 1xx status codes precede the final response status code
            self.processing_1xx = False
            self.recv_state = State.RECV_STATUS_LINE
        else:
            self.recv_state = State.RECV_BODY
        return headers

    async def recv_body(self) -> bytes:
        """
        Receive the message body. An empty byte string indicates the end of the stream:

        >>> buffer = bytearray()
        >>> while body := await p.recv_body():
        >>>     buffer.extend(body)

        Receive state will switch to DONE once the end of the body is reached.

        Raises `ProtocolError` if called in the wrong state, if chunk
        framing is invalid, or if the stream ends before the declared
        Content-Length has been received.
        """
        if self.recv_state != State.RECV_BODY:
            msg = f"Called 'recv_body' in invalid state {self.recv_state}"
            raise ProtocolError(msg)

        if self.recv_content_length is None:
            # Transfer-Encoding: chunked
            exc_text = 'reading chunk size'
            line = await self.parser.read_until(b"\r\n", max_size=4096, exc_text=exc_text)
            # Discard any chunk extensions following the size.
            sizestr, _, _ = line.partition(b";")

            exc_text = "Received invalid chunk size"
            size = bounded_hex(sizestr, max_digits=8, exc_text=exc_text)
            if size > 0:
                body = await self.parser.read(size=size)
                exc_text = 'reading chunk data'
                await self.parser.read_until(b"\r\n", max_size=2, exc_text=exc_text)
                self.recv_seen_length += len(body)
            else:
                body = b''
                exc_text = 'reading chunk termination'
                await self.parser.read_until(b"\r\n", max_size=2, exc_text=exc_text)

        else:
            # Content-Length: xxx
            remaining = self.recv_content_length - self.recv_seen_length
            size = min(remaining, 4096)
            body = await self.parser.read(size=size)
            self.recv_seen_length += len(body)
            if self.recv_seen_length < self.recv_content_length and body == b'':
                msg = 'Not enough data received for declared Content-Length'
                raise ProtocolError(msg)

        if body == b'':
            # Handle body close
            self.recv_state = State.DONE
        return body

    async def complete(self):
        """
        Finish the current request/response cycle.

        If both sides are DONE and both allow keep-alive, the parser is
        reset ready for the next cycle. Otherwise the stream is closed.
        """
        is_fully_complete = self.send_state == State.DONE and self.recv_state == State.DONE
        is_keepalive = self.send_keep_alive and self.recv_keep_alive

        if not (is_fully_complete and is_keepalive):
            await self.close()
            return

        if self.mode == Mode.CLIENT:
            self.send_state = State.SEND_METHOD_LINE
            self.recv_state = State.WAIT
        else:
            self.recv_state = State.RECV_METHOD_LINE
            self.send_state = State.WAIT

        self.send_content_length = 0
        self.recv_content_length = 0
        self.send_seen_length = 0
        self.recv_seen_length = 0
        self.send_keep_alive = True
        self.recv_keep_alive = True
        self.processing_1xx = False

    async def close(self):
        """
        Mark both sides as CLOSED and close the stream. Safe to call twice.
        """
        if self.send_state != State.CLOSED:
            self.send_state = State.CLOSED
            self.recv_state = State.CLOSED
            await self.stream.close()

    def is_idle(self) -> bool:
        """
        Return True if the parser is waiting at the start of a new cycle.
        """
        return (
            self.send_state == State.SEND_METHOD_LINE or
            self.recv_state == State.RECV_METHOD_LINE
        )

    def is_closed(self) -> bool:
        """
        Return True once `close()` has been called.
        """
        return self.send_state == State.CLOSED

    def description(self) -> str:
        """
        Return 'idle', 'closed' or 'active', based on the sending state.
        """
        return {
            State.SEND_METHOD_LINE: "idle",
            State.CLOSED: "closed",
        }.get(self.send_state, "active")

    def __repr__(self) -> str:
        cl_state = self.send_state.name
        sr_state = self.recv_state.name
        detail = f"client {cl_state}, server {sr_state}"
        return f'<HTTPParser [{detail}]>'
def bounded_int(intstr: bytes, max_digits: int, exc_text: str) -> int:
    """
    Parse a byte string of ASCII decimal digits into an int.

    Raises `ProtocolError(exc_text)` if the input is empty, is longer than
    `max_digits`, or contains anything other than the digits 0-9.
    """
    if len(intstr) > max_digits:
        # Length of bytestring exceeds maximum.
        raise ProtocolError(exc_text)
    if not intstr.isdigit():
        # Empty, or contains invalid characters. The explicit check matters
        # because `int()` would otherwise raise ValueError on empty input
        # and would accept signs, whitespace and underscores.
        raise ProtocolError(exc_text)

    return int(intstr)


def bounded_hex(hexstr: bytes, max_digits: int, exc_text: str) -> int:
    """
    Parse a byte string of ASCII hexadecimal digits into an int.

    Raises `ProtocolError(exc_text)` if the input is empty, is longer than
    `max_digits`, or contains anything other than 0-9, a-f or A-F.
    """
    if len(hexstr) > max_digits:
        # Length of bytestring exceeds maximum.
        raise ProtocolError(exc_text)
    if not hexstr or hexstr.strip(b'0123456789abcdefABCDEF'):
        # Empty, or contains invalid characters. Stripping valid digits from
        # both ends leaves a remainder only if an invalid byte is present.
        raise ProtocolError(exc_text)

    return int(hexstr, base=16)
+ async def send(self, request: Request) -> Response: + if self._closed: + raise RuntimeError("ConnectionPool is closed.") + + # TODO: concurrency limiting + await self._cleanup() + connection = await self._get_connection(request) + response = await connection.send(request) + return response + + async def close(self): + self._closed = True + closing = list(self._connections) + self._connections = [] + for conn in closing: + await conn.close() + + # Create or reuse connections as required... + async def _get_connection(self, request: Request) -> "Connection": + # Attempt to reuse an existing connection. + url = request.url + origin = URL(scheme=url.scheme, host=url.host, port=url.port) + now = time.monotonic() + for conn in self._connections: + if conn.origin() == origin and conn.is_idle() and not conn.is_expired(now): + return conn + + # Or else create a new connection. + conn = await open_connection( + origin, + hostname=request.headers["Host"], + backend=self._network_backend + ) + self._connections.append(conn) + return conn + + # Connection pool management... + async def _cleanup(self) -> None: + now = time.monotonic() + for conn in list(self._connections): + if conn.is_expired(now): + await conn.close() + if conn.is_closed(): + self._connections.remove(conn) + + @property + def connections(self) -> typing.List['Connection']: + return [c for c in self._connections] + + def description(self) -> str: + counts = {"active": 0} + for status in [c.description() for c in self._connections]: + counts[status] = counts.get(status, 0) + 1 + return ", ".join(f"{count} {status}" for status, count in counts.items()) + + # Builtins... 
class Connection(Transport):
    """
    A single HTTP/1.1 client connection to one origin, over a stream.
    """

    def __init__(self, stream: Stream, origin: URL | str):
        self._stream = stream
        self._origin = URL(origin)
        self._keepalive_duration = 5.0
        self._idle_expiry = time.monotonic() + self._keepalive_duration
        self._request_lock = Lock()
        self._parser = HTTPParser(stream, mode='CLIENT')

    # API for connection pool management...
    def origin(self) -> URL:
        return self._origin

    def is_idle(self) -> bool:
        return self._parser.is_idle()

    def is_expired(self, when: float) -> bool:
        # Only an idle connection can expire; `when` is a `time.monotonic()` value.
        return self._parser.is_idle() and when > self._idle_expiry

    def is_closed(self) -> bool:
        return self._parser.is_closed()

    def description(self) -> str:
        return self._parser.description()

    # API entry points...
    async def send(self, request: Request) -> Response:
        """
        Send the request and return a Response whose content streams the
        body from this connection.

        If sending the request or reading the response head fails, the
        connection is closed before the exception propagates.
        """
        try:
            await self._send_head(request)
            await self._send_body(request)
            code, headers = await self._recv_head()
        except BaseException:
            # A failure mid-cycle leaves the parser neither idle nor closed,
            # so the connection could never be reused or expired. Close it.
            await self._close()
            raise
        stream = HTTPStream(self._recv_body, self._complete)
        # TODO...
        return Response(code, headers=headers, content=stream)

    async def close(self) -> None:
        async with self._request_lock:
            await self._close()

    # Top-level API for working directly with a connection.
    async def request(
        self,
        method: str,
        url: URL | str,
        headers: Headers | typing.Mapping[str, str] | None = None,
        content: Content | Stream | bytes | None = None,
    ) -> Response:
        """
        Send a request to `url`, joined onto this connection's origin, and
        return the response with its body fully read.
        """
        url = self._origin.join(url)
        request = Request(method, url, headers=headers, content=content)
        async with await self.send(request) as response:
            await response.read()
        return response

    async def stream(
        self,
        method: str,
        url: URL | str,
        headers: Headers | typing.Mapping[str, str] | None = None,
        content: Content | Stream | bytes | None = None,
    ) -> Response:
        """
        Send a request to `url`, joined onto this connection's origin, and
        return the response without reading its body.
        """
        url = self._origin.join(url)
        request = Request(method, url, headers=headers, content=content)
        return await self.send(request)

    # Send the request...
    async def _send_head(self, request: Request) -> None:
        method = request.method.encode('ascii')
        target = request.url.target.encode('ascii')
        protocol = b'HTTP/1.1'
        await self._parser.send_method_line(method, target, protocol)
        headers = [
            (k.encode('ascii'), v.encode('ascii'))
            for k, v in request.headers.items()
        ]
        await self._parser.send_headers(headers)

    async def _send_body(self, request: Request) -> None:
        while data := await request.stream.read(64 * 1024):
            await self._parser.send_body(data)
        # An empty write tells the parser the body is complete.
        await self._parser.send_body(b'')

    # Receive the response...
    async def _recv_head(self) -> tuple[int, Headers]:
        # The parser returns to RECV_STATUS_LINE after the headers of an
        # interim 1xx response, so keep reading until the final response.
        while True:
            _, code, _ = await self._parser.recv_status_line()
            h = await self._parser.recv_headers()
            if code >= 200:
                break
        headers = Headers([
            (k.decode('ascii'), v.decode('ascii'))
            for k, v in h
        ])
        return code, headers

    async def _recv_body(self) -> bytes:
        return await self._parser.recv_body()

    # Request/response cycle complete...
    async def _complete(self) -> None:
        await self._parser.complete()
        # Restart the keep-alive window from the end of this cycle.
        self._idle_expiry = time.monotonic() + self._keepalive_duration

    async def _close(self) -> None:
        await self._parser.close()

    # Builtins...
    def __repr__(self) -> str:
        return f"<Connection [{self._origin} {self.description()}]>"

    async def __aenter__(self) -> "Connection":
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: types.TracebackType | None = None,
    ):
        await self.close()
content=content) + +async def patch( + url: URL | str, + headers: Headers | typing.Mapping[str, str] | None = None, + content: Content | Stream | bytes | None = None, +): + async with Client() as client: + return await client.request("PATCH", url, headers=headers, content=content) + +async def delete( + url: URL | str, + headers: Headers | typing.Mapping[str, str] | None = None, +): + async with Client() as client: + return await client.request("DELETE", url=url, headers=headers) diff --git a/src/ahttpx/_request.py b/src/ahttpx/_request.py new file mode 100644 index 00000000..78b82282 --- /dev/null +++ b/src/ahttpx/_request.py @@ -0,0 +1,93 @@ +import types +import typing + +from ._content import Content +from ._streams import ByteStream, Stream +from ._headers import Headers +from ._urls import URL + +__all__ = ["Request"] + + +class Request: + def __init__( + self, + method: str, + url: URL | str, + headers: Headers | typing.Mapping[str, str] | None = None, + content: Content | Stream | bytes | None = None, + ): + self.method = method + self.url = URL(url) + self.headers = Headers(headers) + self.stream: Stream = ByteStream(b"") + + # https://datatracker.ietf.org/doc/html/rfc2616#section-14.23 + # RFC 2616, Section 14.23, Host. + # + # A client MUST include a Host header field in all HTTP/1.1 request messages. + if "Host" not in self.headers: + self.headers = self.headers.copy_set("Host", self.url.netloc) + + if content is not None: + if isinstance(content, bytes): + self.stream = ByteStream(content) + elif isinstance(content, Stream): + self.stream = content + elif isinstance(content, Content): + ct = content.content_type() + self.stream = content.encode() + self.headers = self.headers.copy_set("Content-Type", ct) + else: + raise TypeError(f'Expected `Content | Stream | bytes | None` got {type(content)}') + + # https://datatracker.ietf.org/doc/html/rfc2616#section-4.3 + # RFC 2616, Section 4.3, Message Body. 
+ # + # The presence of a message-body in a request is signaled by the + # inclusion of a Content-Length or Transfer-Encoding header field in + # the request's message-headers. + content_length: int | None = self.stream.size + if content_length is None: + self.headers = self.headers.copy_set("Transfer-Encoding", "chunked") + elif content_length > 0: + self.headers = self.headers.copy_set("Content-Length", str(content_length)) + + elif method in ("POST", "PUT", "PATCH"): + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.2 + # RFC 7230, Section 3.3.2, Content Length. + # + # A user agent SHOULD send a Content-Length in a request message when no + # Transfer-Encoding is sent and the request method defines a meaning for + # an enclosed payload body. For example, a Content-Length header field is + # normally sent in a POST request even when the value is 0. + # (indicating an empty payload body). + self.headers = self.headers.copy_set("Content-Length", "0") + + @property + def body(self) -> bytes: + if not hasattr(self, '_body'): + raise RuntimeError("'.body' cannot be accessed without calling '.read()'") + return self._body + + async def read(self) -> bytes: + if not hasattr(self, '_body'): + self._body = await self.stream.read() + self.stream = ByteStream(self._body) + return self._body + + async def close(self) -> None: + await self.stream.close() + + async def __aenter__(self): + return self + + async def __aexit__(self, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: types.TracebackType | None = None + ): + await self.close() + + def __repr__(self): + return f"" diff --git a/src/ahttpx/_response.py b/src/ahttpx/_response.py new file mode 100644 index 00000000..db1de832 --- /dev/null +++ b/src/ahttpx/_response.py @@ -0,0 +1,158 @@ +import types +import typing + +from ._content import Content +from ._streams import ByteStream, Stream +from ._headers import Headers, parse_opts_header + +__all__ = 
["Response"] + +# We're using the same set as stdlib `http.HTTPStatus` here... +# +# https://github.com/python/cpython/blob/main/Lib/http/__init__.py +_codes = { + 100: "Continue", + 101: "Switching Protocols", + 102: "Processing", + 103: "Early Hints", + 200: "OK", + 201: "Created", + 202: "Accepted", + 203: "Non-Authoritative Information", + 204: "No Content", + 205: "Reset Content", + 206: "Partial Content", + 207: "Multi-Status", + 208: "Already Reported", + 226: "IM Used", + 300: "Multiple Choices", + 301: "Moved Permanently", + 302: "Found", + 303: "See Other", + 304: "Not Modified", + 305: "Use Proxy", + 307: "Temporary Redirect", + 308: "Permanent Redirect", + 400: "Bad Request", + 401: "Unauthorized", + 402: "Payment Required", + 403: "Forbidden", + 404: "Not Found", + 405: "Method Not Allowed", + 406: "Not Acceptable", + 407: "Proxy Authentication Required", + 408: "Request Timeout", + 409: "Conflict", + 410: "Gone", + 411: "Length Required", + 412: "Precondition Failed", + 413: "Content Too Large", + 414: "URI Too Long", + 415: "Unsupported Media Type", + 416: "Range Not Satisfiable", + 417: "Expectation Failed", + 418: "I'm a Teapot", + 421: "Misdirected Request", + 422: "Unprocessable Content", + 423: "Locked", + 424: "Failed Dependency", + 425: "Too Early", + 426: "Upgrade Required", + 428: "Precondition Required", + 429: "Too Many Requests", + 431: "Request Header Fields Too Large", + 451: "Unavailable For Legal Reasons", + 500: "Internal Server Error", + 501: "Not Implemented", + 502: "Bad Gateway", + 503: "Service Unavailable", + 504: "Gateway Timeout", + 505: "HTTP Version Not Supported", + 506: "Variant Also Negotiates", + 507: "Insufficient Storage", + 508: "Loop Detected", + 510: "Not Extended", + 511: "Network Authentication Required", +} + + +class Response: + def __init__( + self, + status_code: int, + *, + headers: Headers | typing.Mapping[str, str] | None = None, + content: Content | Stream | bytes | None = None, + ): + self.status_code = 
status_code + self.headers = Headers(headers) + self.stream: Stream = ByteStream(b"") + + if content is not None: + if isinstance(content, bytes): + self.stream = ByteStream(content) + elif isinstance(content, Stream): + self.stream = content + elif isinstance(content, Content): + ct = content.content_type() + self.stream = content.encode() + self.headers = self.headers.copy_set("Content-Type", ct) + else: + raise TypeError(f'Expected `Content | Stream | bytes | None` got {type(content)}') + + # https://datatracker.ietf.org/doc/html/rfc2616#section-4.3 + # RFC 2616, Section 4.3, Message Body. + # + # All 1xx (informational), 204 (no content), and 304 (not modified) responses + # MUST NOT include a message-body. All other responses do include a + # message-body, although it MAY be of zero length. + if status_code >= 200 and status_code != 204 and status_code != 304: + content_length: int | None = self.stream.size + if content_length is None: + self.headers = self.headers.copy_set("Transfer-Encoding", "chunked") + else: + self.headers = self.headers.copy_set("Content-Length", str(content_length)) + + @property + def reason_phrase(self): + return _codes.get(self.status_code, "Unknown Status Code") + + @property + def body(self) -> bytes: + if not hasattr(self, '_body'): + raise RuntimeError("'.body' cannot be accessed without calling '.read()'") + return self._body + + @property + def text(self) -> str: + if not hasattr(self, '_body'): + raise RuntimeError("'.text' cannot be accessed without calling '.read()'") + if not hasattr(self, '_text'): + ct = self.headers.get('Content-Type', '') + media, opts = parse_opts_header(ct) + charset = 'utf-8' + if media.startswith('text/'): + charset = opts.get('charset', 'utf-8') + self._text = self._body.decode(charset) + return self._text + + async def read(self) -> bytes: + if not hasattr(self, '_body'): + self._body = await self.stream.read() + return self._body + + async def close(self) -> None: + await self.stream.close() + + 
async def __aenter__(self): + return self + + async def __aexit__(self, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: types.TracebackType | None = None + ): + await self.close() + + def __repr__(self): + return f"" diff --git a/src/ahttpx/_server.py b/src/ahttpx/_server.py new file mode 100644 index 00000000..a9103cc9 --- /dev/null +++ b/src/ahttpx/_server.py @@ -0,0 +1,126 @@ +import contextlib +import logging +import time + +from ._content import Text +from ._parsers import HTTPParser +from ._request import Request +from ._response import Response +from ._network import NetworkBackend, sleep +from ._streams import HTTPStream + +__all__ = [ + "serve_http", "run" +] + +logger = logging.getLogger("httpx.server") + + +class ConnectionClosed(Exception): + pass + + +class HTTPConnection: + def __init__(self, stream, endpoint): + self._stream = stream + self._endpoint = endpoint + self._parser = HTTPParser(stream, mode='SERVER') + self._keepalive_duration = 5.0 + self._idle_expiry = time.monotonic() + self._keepalive_duration + + # API entry points... 
+ async def handle_requests(self): + try: + while not self._parser.is_closed(): + method, url, headers = await self._recv_head() + stream = HTTPStream(self._recv_body, self._complete) + # TODO: Handle endpoint exceptions + async with Request(method, url, headers=headers, content=stream) as request: + try: + response = await self._endpoint(request) + status_line = f"{request.method} {request.url.target} [{response.status_code} {response.reason_phrase}]" + logger.info(status_line) + except Exception: + logger.error("Internal Server Error", exc_info=True) + content = Text("Internal Server Error") + err = Response(code=500, content=content) + await self._send_head(err) + await self._send_body(err) + else: + await self._send_head(response) + await self._send_body(response) + except Exception: + logger.error("Internal Server Error", exc_info=True) + + async def close(self): + self._parser.close() + + # Receive the request... + async def _recv_head(self) -> tuple[str, str, list[tuple[str, str]]]: + method, target, _ = await self._parser.recv_method_line() + m = method.decode('ascii') + t = target.decode('ascii') + headers = await self._parser.recv_headers() + h = [ + (k.decode('latin-1'), v.decode('latin-1')) + for k, v in headers + ] + return m, t, h + + async def _recv_body(self): + return await self._parser.recv_body() + + # Return the response... + async def _send_head(self, response: Response): + protocol = b"HTTP/1.1" + status = response.status_code + reason = response.reason_phrase.encode('ascii') + await self._parser.send_status_line(protocol, status, reason) + headers = [ + (k.encode('ascii'), v.encode('ascii')) + for k, v in response.headers.items() + ] + await self._parser.send_headers(headers) + + async def _send_body(self, response: Response): + while data := await response.stream.read(64 * 1024): + await self._parser.send_body(data) + await self._parser.send_body(b'') + + # Start it all over again... 
+ async def _complete(self): + await self._parser.complete + self._idle_expiry = time.monotonic() + self._keepalive_duration + + +class HTTPServer: + def __init__(self, host, port): + self.url = f"http://{host}:{port}/" + + async def wait(self): + while(True): + await sleep(1) + + +@contextlib.asynccontextmanager +async def serve_http(endpoint): + async def handler(stream): + connection = HTTPConnection(stream, endpoint) + await connection.handle_requests() + + logging.basicConfig( + format="%(levelname)s [%(asctime)s] %(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.DEBUG + ) + + backend = NetworkBackend() + async with await backend.serve("127.0.0.1", 8080, handler) as server: + server = HTTPServer(server.host, server.port) + logger.info(f"Serving on {server.url} (Press CTRL+C to quit)") + yield server + + +async def run(app): + async with await serve_http(app) as server: + server.wait() diff --git a/src/ahttpx/_streams.py b/src/ahttpx/_streams.py new file mode 100644 index 00000000..d5e5ad0d --- /dev/null +++ b/src/ahttpx/_streams.py @@ -0,0 +1,235 @@ +import io +import types +import os + + +class Stream: + async def read(self, size: int=-1) -> bytes: + raise NotImplementedError() + + async def write(self, data: bytes) -> None: + raise NotImplementedError() + + async def close(self) -> None: + raise NotImplementedError() + + @property + def size(self) -> int | None: + return None + + async def __aenter__(self): + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: types.TracebackType | None = None + ): + await self.close() + + +class ByteStream(Stream): + def __init__(self, data: bytes = b''): + self._buffer = io.BytesIO(data) + self._size = len(data) + + async def read(self, size: int=-1) -> bytes: + return self._buffer.read(size) + + async def close(self) -> None: + self._buffer.close() + + @property + def size(self) -> int | None: + return 
self._size + + +class DuplexStream(Stream): + """ + DuplexStream supports both `read` and `write` operations, + which are applied to seperate buffers. + + This stream can be used for testing network parsers. + """ + + def __init__(self, data: bytes = b''): + self._read_buffer = io.BytesIO(data) + self._write_buffer = io.BytesIO() + + async def read(self, size: int=-1) -> bytes: + return self._read_buffer.read(size) + + async def write(self, buffer: bytes): + return self._write_buffer.write(buffer) + + async def close(self) -> None: + self._read_buffer.close() + self._write_buffer.close() + + def input_bytes(self) -> bytes: + return self._read_buffer.getvalue() + + def output_bytes(self) -> bytes: + return self._write_buffer.getvalue() + + +class FileStream(Stream): + def __init__(self, path): + self._path = path + self._fileobj = None + self._size = None + + async def read(self, size: int=-1) -> bytes: + if self._fileobj is None: + raise ValueError('I/O operation on unopened file') + return self._fileobj.read(size) + + async def open(self): + self._fileobj = open(self._path, 'rb') + self._size = os.path.getsize(self._path) + return self + + async def close(self) -> None: + if self._fileobj is not None: + self._fileobj.close() + + @property + def size(self) -> int | None: + return self._size + + async def __aenter__(self): + await self.open() + return self + + +class HTTPStream(Stream): + def __init__(self, next_chunk, complete): + self._next_chunk = next_chunk + self._complete = complete + self._buffer = io.BytesIO() + + async def read(self, size=-1) -> bytes: + sections = [] + length = 0 + + # If we have any data in the buffer read that and clear the buffer. + buffered = self._buffer.read() + if buffered: + sections.append(buffered) + length += len(buffered) + self._buffer.seek(0) + self._buffer.truncate(0) + + # Read each chunk in turn. 
+ while (size < 0) or (length < size): + section = await self._next_chunk() + sections.append(section) + length += len(section) + if section == b'': + break + + # If we've more data than requested, then push some back into the buffer. + output = b''.join(sections) + if size > -1 and len(output) > size: + output, remainder = output[:size], output[size:] + self._buffer.write(remainder) + self._buffer.seek(0) + + return output + + async def close(self) -> None: + self._buffer.close() + if self._complete is not None: + await self._complete() + + +class MultiPartStream(Stream): + def __init__(self, form: list[tuple[str, str]], files: list[tuple[str, str]], boundary=''): + self._form = list(form) + self._files = list(files) + self._boundary = boundary or os.urandom(16).hex() + # Mutable state... + self._form_progress = list(self._form) + self._files_progress = list(self._files) + self._filestream: FileStream | None = None + self._complete = False + self._buffer = io.BytesIO() + + async def read(self, size=-1) -> bytes: + sections = [] + length = 0 + + # If we have any data in the buffer read that and clear the buffer. + buffered = self._buffer.read() + if buffered: + sections.append(buffered) + length += len(buffered) + self._buffer.seek(0) + self._buffer.truncate(0) + + # Read each multipart section in turn. + while (size < 0) or (length < size): + section = await self._read_next_section() + sections.append(section) + length += len(section) + if section == b'': + break + + # If we've more data than requested, then push some back into the buffer. 
+ output = b''.join(sections) + if size > -1 and len(output) > size: + output, remainder = output[:size], output[size:] + self._buffer.write(remainder) + self._buffer.seek(0) + + return output + + async def _read_next_section(self) -> bytes: + if self._form_progress: + # return a form item + key, value = self._form_progress.pop(0) + name = key.translate({10: "%0A", 13: "%0D", 34: "%22"}) + return ( + f"--{self._boundary}\r\n" + f'Content-Disposition: form-data; name="{name}"\r\n' + f"\r\n" + f"{value}\r\n" + ).encode("utf-8") + elif self._files_progress and self._filestream is None: + # return start of a file item + key, value = self._files_progress.pop(0) + self._filestream = await FileStream(value).open() + name = key.translate({10: "%0A", 13: "%0D", 34: "%22"}) + filename = os.path.basename(value) + return ( + f"--{self._boundary}\r\n" + f'Content-Disposition: form-data; name="{name}"; filename="{filename}"\r\n' + f"\r\n" + ).encode("utf-8") + elif self._filestream is not None: + chunk = await self._filestream.read(64*1024) + if chunk != b'': + # return some bytes from file + return chunk + else: + # return end of file item + await self._filestream.close() + self._filestream = None + return b"\r\n" + elif not self._complete: + # return final section of multipart + self._complete = True + return f"--{self._boundary}--\r\n".encode("utf-8") + # return EOF marker + return b"" + + async def close(self) -> None: + if self._filestream is not None: + await self._filestream.close() + self._filestream = None + self._buffer.close() + + @property + def size(self) -> int | None: + return None diff --git a/src/ahttpx/_urlencode.py b/src/ahttpx/_urlencode.py new file mode 100644 index 00000000..1a83b620 --- /dev/null +++ b/src/ahttpx/_urlencode.py @@ -0,0 +1,85 @@ +import re + +__all__ = ["quote", "unquote", "urldecode", "urlencode"] + + +# Matchs a sequence of one or more '%xx' escapes. 
+PERCENT_ENCODED_REGEX = re.compile("(%[A-Fa-f0-9][A-Fa-f0-9])+") + +# https://datatracker.ietf.org/doc/html/rfc3986#section-2.3 +SAFE = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~" + + +def urlencode(multidict, safe=SAFE): + pairs = [] + for key, values in multidict.items(): + pairs.extend([(key, value) for value in values]) + + safe += "+" + pairs = [(k.replace(" ", "+"), v.replace(" ", "+")) for k, v in pairs] + + return "&".join( + f"{quote(key, safe)}={quote(val, safe)}" + for key, val in pairs + ) + + +def urldecode(string): + parts = [part.partition("=") for part in string.split("&") if part] + pairs = [ + (unquote(key), unquote(val)) + for key, _, val in parts + ] + + pairs = [(k.replace("+", " "), v.replace("+", " ")) for k, v in pairs] + + ret = {} + for k, v in pairs: + ret.setdefault(k, []).append(v) + return ret + + +def quote(string, safe=SAFE): + # Fast path if the string is already safe. + if not string.strip(safe): + return string + + # Replace any characters not in the safe set with '%xx' escape sequences. + return "".join([ + char if char in safe else percent(char) + for char in string + ]) + + +def unquote(string): + # Fast path if the string is not quoted. + if '%' not in string: + return string + + # Unquote. + parts = [] + current_position = 0 + for match in re.finditer(PERCENT_ENCODED_REGEX, string): + start_position, end_position = match.start(), match.end() + matched_text = match.group(0) + # Include any text up to the '%xx' escape sequence. + if start_position != current_position: + leading_text = string[current_position:start_position] + parts.append(leading_text) + + # Decode the '%xx' escape sequence. + hex = matched_text.replace('%', '') + decoded = bytes.fromhex(hex).decode('utf-8') + parts.append(decoded) + current_position = end_position + + # Include any text after the final '%xx' escape sequence. 
+ if current_position != len(string): + trailing_text = string[current_position:] + parts.append(trailing_text) + + return "".join(parts) + + +def percent(c): + return ''.join(f"%{b:02X}" for b in c.encode("utf-8")) diff --git a/src/ahttpx/_urlparse.py b/src/ahttpx/_urlparse.py new file mode 100644 index 00000000..612892fa --- /dev/null +++ b/src/ahttpx/_urlparse.py @@ -0,0 +1,534 @@ +""" +An implementation of `urlparse` that provides URL validation and normalization +as described by RFC3986. + +We rely on this implementation rather than the one in Python's stdlib, because: + +* It provides more complete URL validation. +* It properly differentiates between an empty querystring and an absent querystring, + to distinguish URLs with a trailing '?'. +* It handles scheme, hostname, port, and path normalization. +* It supports IDNA hostnames, normalizing them to their encoded form. +* The API supports passing individual components, as well as the complete URL string. + +Previously we relied on the excellent `rfc3986` package to handle URL parsing and +validation, but this module provides a simpler alternative, with less indirection +required. +""" + +import ipaddress +import re +import typing + + +class InvalidURL(ValueError): + pass + + +MAX_URL_LENGTH = 65536 + +# https://datatracker.ietf.org/doc/html/rfc3986.html#section-2.3 +UNRESERVED_CHARACTERS = ( + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~" +) +SUB_DELIMS = "!$&'()*+,;=" + +PERCENT_ENCODED_REGEX = re.compile("%[A-Fa-f0-9]{2}") + +# https://url.spec.whatwg.org/#percent-encoded-bytes + +# The fragment percent-encode set is the C0 control percent-encode set +# and U+0020 SPACE, U+0022 ("), U+003C (<), U+003E (>), and U+0060 (`). +FRAG_SAFE = "".join( + [chr(i) for i in range(0x20, 0x7F) if i not in (0x20, 0x22, 0x3C, 0x3E, 0x60)] +) + +# The query percent-encode set is the C0 control percent-encode set +# and U+0020 SPACE, U+0022 ("), U+0023 (#), U+003C (<), and U+003E (>). 
+QUERY_SAFE = "".join( + [chr(i) for i in range(0x20, 0x7F) if i not in (0x20, 0x22, 0x23, 0x3C, 0x3E)] +) + +# The path percent-encode set is the query percent-encode set +# and U+003F (?), U+0060 (`), U+007B ({), and U+007D (}). +PATH_SAFE = "".join( + [ + chr(i) + for i in range(0x20, 0x7F) + if i not in (0x20, 0x22, 0x23, 0x3C, 0x3E) + (0x3F, 0x60, 0x7B, 0x7D) + ] +) + +# The userinfo percent-encode set is the path percent-encode set +# and U+002F (/), U+003A (:), U+003B (;), U+003D (=), U+0040 (@), +# U+005B ([) to U+005E (^), inclusive, and U+007C (|). +USERNAME_SAFE = "".join( + [ + chr(i) + for i in range(0x20, 0x7F) + if i + not in (0x20, 0x22, 0x23, 0x3C, 0x3E) + + (0x3F, 0x60, 0x7B, 0x7D) + + (0x2F, 0x3A, 0x3B, 0x3D, 0x40, 0x5B, 0x5C, 0x5D, 0x5E, 0x7C) + ] +) +PASSWORD_SAFE = "".join( + [ + chr(i) + for i in range(0x20, 0x7F) + if i + not in (0x20, 0x22, 0x23, 0x3C, 0x3E) + + (0x3F, 0x60, 0x7B, 0x7D) + + (0x2F, 0x3A, 0x3B, 0x3D, 0x40, 0x5B, 0x5C, 0x5D, 0x5E, 0x7C) + ] +) +# Note... The terminology 'userinfo' percent-encode set in the WHATWG document +# is used for the username and password quoting. For the joint userinfo component +# we remove U+003A (:) from the safe set. +USERINFO_SAFE = "".join( + [ + chr(i) + for i in range(0x20, 0x7F) + if i + not in (0x20, 0x22, 0x23, 0x3C, 0x3E) + + (0x3F, 0x60, 0x7B, 0x7D) + + (0x2F, 0x3B, 0x3D, 0x40, 0x5B, 0x5C, 0x5D, 0x5E, 0x7C) + ] +) + + +# {scheme}: (optional) +# //{authority} (optional) +# {path} +# ?{query} (optional) +# #{fragment} (optional) +URL_REGEX = re.compile( + ( + r"(?:(?P{scheme}):)?" + r"(?://(?P{authority}))?" + r"(?P{path})" + r"(?:\?(?P{query}))?" + r"(?:#(?P{fragment}))?" + ).format( + scheme="([a-zA-Z][a-zA-Z0-9+.-]*)?", + authority="[^/?#]*", + path="[^?#]*", + query="[^#]*", + fragment=".*", + ) +) + +# {userinfo}@ (optional) +# {host} +# :{port} (optional) +AUTHORITY_REGEX = re.compile( + ( + r"(?:(?P{userinfo})@)?" r"(?P{host})" r":?(?P{port})?" 
+ ).format( + userinfo=".*", # Any character sequence. + host="(\\[.*\\]|[^:@]*)", # Either any character sequence excluding ':' or '@', + # or an IPv6 address enclosed within square brackets. + port=".*", # Any character sequence. + ) +) + + +# If we call urlparse with an individual component, then we need to regex +# validate that component individually. +# Note that we're duplicating the same strings as above. Shock! Horror!! +COMPONENT_REGEX = { + "scheme": re.compile("([a-zA-Z][a-zA-Z0-9+.-]*)?"), + "authority": re.compile("[^/?#]*"), + "path": re.compile("[^?#]*"), + "query": re.compile("[^#]*"), + "fragment": re.compile(".*"), + "userinfo": re.compile("[^@]*"), + "host": re.compile("(\\[.*\\]|[^:]*)"), + "port": re.compile(".*"), +} + + +# We use these simple regexs as a first pass before handing off to +# the stdlib 'ipaddress' module for IP address validation. +IPv4_STYLE_HOSTNAME = re.compile(r"^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+$") +IPv6_STYLE_HOSTNAME = re.compile(r"^\[.*\]$") + + +class ParseResult(typing.NamedTuple): + scheme: str + userinfo: str + host: str + port: int | None + path: str + query: str | None + fragment: str | None + + @property + def authority(self) -> str: + return "".join( + [ + f"{self.userinfo}@" if self.userinfo else "", + f"[{self.host}]" if ":" in self.host else self.host, + f":{self.port}" if self.port is not None else "", + ] + ) + + @property + def netloc(self) -> str: + return "".join( + [ + f"[{self.host}]" if ":" in self.host else self.host, + f":{self.port}" if self.port is not None else "", + ] + ) + + def copy_with(self, **kwargs: str | None) -> "ParseResult": + if not kwargs: + return self + + defaults = { + "scheme": self.scheme, + "authority": self.authority, + "path": self.path, + "query": self.query, + "fragment": self.fragment, + } + defaults.update(kwargs) + return urlparse("", **defaults) + + def __str__(self) -> str: + authority = self.authority + return "".join( + [ + f"{self.scheme}:" if self.scheme else "", + 
f"//{authority}" if authority else "", + self.path, + f"?{self.query}" if self.query is not None else "", + f"#{self.fragment}" if self.fragment is not None else "", + ] + ) + + +def urlparse(url: str = "", **kwargs: str | None) -> ParseResult: + # Initial basic checks on allowable URLs. + # --------------------------------------- + + # Hard limit the maximum allowable URL length. + if len(url) > MAX_URL_LENGTH: + raise InvalidURL("URL too long") + + # If a URL includes any ASCII control characters including \t, \r, \n, + # then treat it as invalid. + if any(char.isascii() and not char.isprintable() for char in url): + char = next(char for char in url if char.isascii() and not char.isprintable()) + idx = url.find(char) + error = ( + f"Invalid non-printable ASCII character in URL, {char!r} at position {idx}." + ) + raise InvalidURL(error) + + # Some keyword arguments require special handling. + # ------------------------------------------------ + + # Coerce "port" to a string, if it is provided as an integer. + if "port" in kwargs: + port = kwargs["port"] + kwargs["port"] = str(port) if isinstance(port, int) else port + + # Replace "netloc" with "host and "port". + if "netloc" in kwargs: + netloc = kwargs.pop("netloc") or "" + kwargs["host"], _, kwargs["port"] = netloc.partition(":") + + # Replace "username" and/or "password" with "userinfo". + if "username" in kwargs or "password" in kwargs: + username = quote(kwargs.pop("username", "") or "", safe=USERNAME_SAFE) + password = quote(kwargs.pop("password", "") or "", safe=PASSWORD_SAFE) + kwargs["userinfo"] = f"{username}:{password}" if password else username + + # Replace "raw_path" with "path" and "query". + if "raw_path" in kwargs: + raw_path = kwargs.pop("raw_path") or "" + kwargs["path"], seperator, kwargs["query"] = raw_path.partition("?") + if not seperator: + kwargs["query"] = None + + # Ensure that IPv6 "host" addresses are always escaped with "[...]". 
+ if "host" in kwargs: + host = kwargs.get("host") or "" + if ":" in host and not (host.startswith("[") and host.endswith("]")): + kwargs["host"] = f"[{host}]" + + # If any keyword arguments are provided, ensure they are valid. + # ------------------------------------------------------------- + + for key, value in kwargs.items(): + if value is not None: + if len(value) > MAX_URL_LENGTH: + raise InvalidURL(f"URL component '{key}' too long") + + # If a component includes any ASCII control characters including \t, \r, \n, + # then treat it as invalid. + if any(char.isascii() and not char.isprintable() for char in value): + char = next( + char for char in value if char.isascii() and not char.isprintable() + ) + idx = value.find(char) + error = ( + f"Invalid non-printable ASCII character in URL {key} component, " + f"{char!r} at position {idx}." + ) + raise InvalidURL(error) + + # Ensure that keyword arguments match as a valid regex. + if not COMPONENT_REGEX[key].fullmatch(value): + raise InvalidURL(f"Invalid URL component '{key}'") + + # The URL_REGEX will always match, but may have empty components. + url_match = URL_REGEX.match(url) + assert url_match is not None + url_dict = url_match.groupdict() + + # * 'scheme', 'authority', and 'path' may be empty strings. + # * 'query' may be 'None', indicating no trailing "?" portion. + # Any string including the empty string, indicates a trailing "?". + # * 'fragment' may be 'None', indicating no trailing "#" portion. + # Any string including the empty string, indicates a trailing "#". + scheme = kwargs.get("scheme", url_dict["scheme"]) or "" + authority = kwargs.get("authority", url_dict["authority"]) or "" + path = kwargs.get("path", url_dict["path"]) or "" + query = kwargs.get("query", url_dict["query"]) + frag = kwargs.get("fragment", url_dict["fragment"]) + + # The AUTHORITY_REGEX will always match, but may have empty components. 
+ authority_match = AUTHORITY_REGEX.match(authority) + assert authority_match is not None + authority_dict = authority_match.groupdict() + + # * 'userinfo' and 'host' may be empty strings. + # * 'port' may be 'None'. + userinfo = kwargs.get("userinfo", authority_dict["userinfo"]) or "" + host = kwargs.get("host", authority_dict["host"]) or "" + port = kwargs.get("port", authority_dict["port"]) + + # Normalize and validate each component. + # We end up with a parsed representation of the URL, + # with components that are plain ASCII bytestrings. + parsed_scheme: str = scheme.lower() + parsed_userinfo: str = quote(userinfo, safe=USERINFO_SAFE) + parsed_host: str = encode_host(host) + parsed_port: int | None = normalize_port(port, scheme) + + has_scheme = parsed_scheme != "" + has_authority = ( + parsed_userinfo != "" or parsed_host != "" or parsed_port is not None + ) + validate_path(path, has_scheme=has_scheme, has_authority=has_authority) + if has_scheme or has_authority: + path = normalize_path(path) + + parsed_path: str = quote(path, safe=PATH_SAFE) + parsed_query: str | None = None if query is None else quote(query, safe=QUERY_SAFE) + parsed_frag: str | None = None if frag is None else quote(frag, safe=FRAG_SAFE) + + # The parsed ASCII bytestrings are our canonical form. + # All properties of the URL are derived from these. + return ParseResult( + parsed_scheme, + parsed_userinfo, + parsed_host, + parsed_port, + parsed_path, + parsed_query, + parsed_frag, + ) + + +def encode_host(host: str) -> str: + if not host: + return "" + + elif IPv4_STYLE_HOSTNAME.match(host): + # Validate IPv4 hostnames like #.#.#.# + # + # From https://datatracker.ietf.org/doc/html/rfc3986/#section-3.2.2 + # + # IPv4address = dec-octet "." dec-octet "." dec-octet "." 
dec-octet + try: + ipaddress.IPv4Address(host) + except ipaddress.AddressValueError: + raise InvalidURL(f"Invalid IPv4 address: {host!r}") + return host + + elif IPv6_STYLE_HOSTNAME.match(host): + # Validate IPv6 hostnames like [...] + # + # From https://datatracker.ietf.org/doc/html/rfc3986/#section-3.2.2 + # + # "A host identified by an Internet Protocol literal address, version 6 + # [RFC3513] or later, is distinguished by enclosing the IP literal + # within square brackets ("[" and "]"). This is the only place where + # square bracket characters are allowed in the URI syntax." + try: + ipaddress.IPv6Address(host[1:-1]) + except ipaddress.AddressValueError: + raise InvalidURL(f"Invalid IPv6 address: {host!r}") + return host[1:-1] + + elif not host.isascii(): + try: + import idna # type: ignore + except ImportError: + raise InvalidURL( + f"Cannot handle URL with IDNA hostname: {host!r}. " + f"Package 'idna' is not installed." + ) + + # IDNA hostnames + try: + return idna.encode(host.lower()).decode("ascii") + except idna.IDNAError: + raise InvalidURL(f"Invalid IDNA hostname: {host!r}") + + # Regular ASCII hostnames + # + # From https://datatracker.ietf.org/doc/html/rfc3986/#section-3.2.2 + # + # reg-name = *( unreserved / pct-encoded / sub-delims ) + WHATWG_SAFE = '"`{}%|\\' + return quote(host.lower(), safe=SUB_DELIMS + WHATWG_SAFE) + + +def normalize_port(port: str | int | None, scheme: str) -> int | None: + # From https://tools.ietf.org/html/rfc3986#section-3.2.3 + # + # "A scheme may define a default port. For example, the "http" scheme + # defines a default port of "80", corresponding to its reserved TCP + # port number. The type of port designated by the port number (e.g., + # TCP, UDP, SCTP) is defined by the URI scheme. URI producers and + # normalizers should omit the port component and its ":" delimiter if + # port is empty or if its value would be the same as that of the + # scheme's default." 
+ if port is None or port == "": + return None + + try: + port_as_int = int(port) + except ValueError: + raise InvalidURL(f"Invalid port: {port!r}") + + # See https://url.spec.whatwg.org/#url-miscellaneous + default_port = {"ftp": 21, "http": 80, "https": 443, "ws": 80, "wss": 443}.get( + scheme + ) + if port_as_int == default_port: + return None + return port_as_int + + +def validate_path(path: str, has_scheme: bool, has_authority: bool) -> None: + """ + Path validation rules that depend on if the URL contains + a scheme or authority component. + + See https://datatracker.ietf.org/doc/html/rfc3986.html#section-3.3 + """ + if has_authority: + # If a URI contains an authority component, then the path component + # must either be empty or begin with a slash ("/") character." + if path and not path.startswith("/"): + raise InvalidURL("For absolute URLs, path must be empty or begin with '/'") + + if not has_scheme and not has_authority: + # If a URI does not contain an authority component, then the path cannot begin + # with two slash characters ("//"). + if path.startswith("//"): + raise InvalidURL("Relative URLs cannot have a path starting with '//'") + + # In addition, a URI reference (Section 4.1) may be a relative-path reference, + # in which case the first path segment cannot contain a colon (":") character. + if path.startswith(":"): + raise InvalidURL("Relative URLs cannot have a path starting with ':'") + + +def normalize_path(path: str) -> str: + """ + Drop "." and ".." segments from a URL path. + + For example: + + normalize_path("/path/./to/somewhere/..") == "/path/to" + """ + # Fast return when no '.' characters in the path. + if "." not in path: + return path + + components = path.split("/") + + # Fast return when no '.' or '..' components in the path. + if "." not in components and ".." 
not in components: + return path + + # https://datatracker.ietf.org/doc/html/rfc3986#section-5.2.4 + output: list[str] = [] + for component in components: + if component == ".": + pass + elif component == "..": + if output and output != [""]: + output.pop() + else: + output.append(component) + return "/".join(output) + + +def PERCENT(string: str) -> str: + return "".join([f"%{byte:02X}" for byte in string.encode("utf-8")]) + + +def percent_encoded(string: str, safe: str) -> str: + """ + Use percent-encoding to quote a string. + """ + NON_ESCAPED_CHARS = UNRESERVED_CHARACTERS + safe + + # Fast path for strings that don't need escaping. + if not string.rstrip(NON_ESCAPED_CHARS): + return string + + return "".join( + [char if char in NON_ESCAPED_CHARS else PERCENT(char) for char in string] + ) + + +def quote(string: str, safe: str) -> str: + """ + Use percent-encoding to quote a string, omitting existing '%xx' escape sequences. + + See: https://www.rfc-editor.org/rfc/rfc3986#section-2.1 + + * `string`: The string to be percent-escaped. + * `safe`: A string containing characters that may be treated as safe, and do not + need to be escaped. Unreserved characters are always treated as safe. + See: https://www.rfc-editor.org/rfc/rfc3986#section-2.3 + """ + parts = [] + current_position = 0 + for match in re.finditer(PERCENT_ENCODED_REGEX, string): + start_position, end_position = match.start(), match.end() + matched_text = match.group(0) + # Add any text up to the '%xx' escape sequence. + if start_position != current_position: + leading_text = string[current_position:start_position] + parts.append(percent_encoded(leading_text, safe=safe)) + + # Add the '%xx' escape sequence. + parts.append(matched_text) + current_position = end_position + + # Add any text after the final '%xx' escape sequence. 
+ if current_position != len(string): + trailing_text = string[current_position:] + parts.append(percent_encoded(trailing_text, safe=safe)) + + return "".join(parts) diff --git a/src/ahttpx/_urls.py b/src/ahttpx/_urls.py new file mode 100644 index 00000000..4ae4464e --- /dev/null +++ b/src/ahttpx/_urls.py @@ -0,0 +1,552 @@ +from __future__ import annotations + +import typing + +from ._urlparse import urlparse +from ._urlencode import unquote, urldecode, urlencode + +__all__ = ["QueryParams", "URL"] + + +class URL: + """ + url = httpx.URL("HTTPS://jo%40email.com:a%20secret@müller.de:1234/pa%20th?search=ab#anchorlink") + + assert url.scheme == "https" + assert url.username == "jo@email.com" + assert url.password == "a secret" + assert url.userinfo == b"jo%40email.com:a%20secret" + assert url.host == "müller.de" + assert url.raw_host == b"xn--mller-kva.de" + assert url.port == 1234 + assert url.netloc == b"xn--mller-kva.de:1234" + assert url.path == "/pa th" + assert url.query == b"?search=ab" + assert url.raw_path == b"/pa%20th?search=ab" + assert url.fragment == "anchorlink" + + The components of a URL are broken down like this: + + https://jo%40email.com:a%20secret@müller.de:1234/pa%20th?search=ab#anchorlink + [scheme] [ username ] [password] [ host ][port][ path ] [ query ] [fragment] + [ userinfo ] [ netloc ][ raw_path ] + + Note that: + + * `url.scheme` is normalized to always be lowercased. + + * `url.host` is normalized to always be lowercased. Internationalized domain + names are represented in unicode, without IDNA encoding applied. For instance: + + url = httpx.URL("http://中国.icom.museum") + assert url.host == "中国.icom.museum" + url = httpx.URL("http://xn--fiqs8s.icom.museum") + assert url.host == "中国.icom.museum" + + * `url.raw_host` is normalized to always be lowercased, and is IDNA encoded. 
+ + url = httpx.URL("http://中国.icom.museum") + assert url.raw_host == b"xn--fiqs8s.icom.museum" + url = httpx.URL("http://xn--fiqs8s.icom.museum") + assert url.raw_host == b"xn--fiqs8s.icom.museum" + + * `url.port` is either None or an integer. URLs that include the default port for + "http", "https", "ws", "wss", and "ftp" schemes have their port + normalized to `None`. + + assert httpx.URL("http://example.com") == httpx.URL("http://example.com:80") + assert httpx.URL("http://example.com").port is None + assert httpx.URL("http://example.com:80").port is None + + * `url.userinfo` is raw bytes, without URL escaping. Usually you'll want to work + with `url.username` and `url.password` instead, which handle the URL escaping. + + * `url.raw_path` is raw bytes of both the path and query, without URL escaping. + This portion is used as the target when constructing HTTP requests. Usually you'll + want to work with `url.path` instead. + + * `url.query` is raw bytes, without URL escaping. A URL query string portion can + only be properly URL escaped when decoding the parameter names and values + themselves. + """ + + def __init__(self, url: "URL" | str = "", **kwargs: typing.Any) -> None: + if kwargs: + allowed = { + "scheme": str, + "username": str, + "password": str, + "userinfo": bytes, + "host": str, + "port": int, + "netloc": str, + "path": str, + "query": bytes, + "raw_path": bytes, + "fragment": str, + "params": object, + } + + # Perform type checking for all supported keyword arguments. 
+ for key, value in kwargs.items(): + if key not in allowed: + message = f"{key!r} is an invalid keyword argument for URL()" + raise TypeError(message) + if value is not None and not isinstance(value, allowed[key]): + expected = allowed[key].__name__ + seen = type(value).__name__ + message = f"Argument {key!r} must be {expected} but got {seen}" + raise TypeError(message) + if isinstance(value, bytes): + kwargs[key] = value.decode("ascii") + + if "params" in kwargs: + # Replace any "params" keyword with the raw "query" instead. + # + # Ensure that empty params use `kwargs["query"] = None` rather + # than `kwargs["query"] = ""`, so that generated URLs do not + # include an empty trailing "?". + params = kwargs.pop("params") + kwargs["query"] = None if not params else str(QueryParams(params)) + + if isinstance(url, str): + self._uri_reference = urlparse(url, **kwargs) + elif isinstance(url, URL): + self._uri_reference = url._uri_reference.copy_with(**kwargs) + else: + raise TypeError( + "Invalid type for url. Expected str or httpx.URL," + f" got {type(url)}: {url!r}" + ) + + @property + def scheme(self) -> str: + """ + The URL scheme, such as "http", "https". + Always normalised to lowercase. + """ + return self._uri_reference.scheme + + @property + def userinfo(self) -> bytes: + """ + The URL userinfo as a raw bytestring. + For example: b"jo%40email.com:a%20secret". + """ + return self._uri_reference.userinfo.encode("ascii") + + @property + def username(self) -> str: + """ + The URL username as a string, with URL decoding applied. + For example: "jo@email.com" + """ + userinfo = self._uri_reference.userinfo + return unquote(userinfo.partition(":")[0]) + + @property + def password(self) -> str: + """ + The URL password as a string, with URL decoding applied. + For example: "a secret" + """ + userinfo = self._uri_reference.userinfo + return unquote(userinfo.partition(":")[2]) + + @property + def host(self) -> str: + """ + The URL host as a string. 
+ Always normalized to lowercase. Possibly IDNA encoded. + + Examples: + + url = httpx.URL("http://www.EXAMPLE.org") + assert url.host == "www.example.org" + + url = httpx.URL("http://中国.icom.museum") + assert url.host == "xn--fiqs8s" + + url = httpx.URL("http://xn--fiqs8s.icom.museum") + assert url.host == "xn--fiqs8s" + + url = httpx.URL("https://[::ffff:192.168.0.1]") + assert url.host == "::ffff:192.168.0.1" + """ + return self._uri_reference.host + + @property + def port(self) -> int | None: + """ + The URL port as an integer. + + Note that the URL class performs port normalization as per the WHATWG spec. + Default ports for "http", "https", "ws", "wss", and "ftp" schemes are always + treated as `None`. + + For example: + + assert httpx.URL("http://www.example.com") == httpx.URL("http://www.example.com:80") + assert httpx.URL("http://www.example.com:80").port is None + """ + return self._uri_reference.port + + @property + def netloc(self) -> str: + """ + Either `` or `:` as bytes. + Always normalized to lowercase, and IDNA encoded. + + This property may be used for generating the value of a request + "Host" header. + """ + return self._uri_reference.netloc + + @property + def path(self) -> str: + """ + The URL path as a string. Excluding the query string, and URL decoded. + + For example: + + url = httpx.URL("https://example.com/pa%20th") + assert url.path == "/pa th" + """ + path = self._uri_reference.path or "/" + return unquote(path) + + @property + def query(self) -> bytes: + """ + The URL query string, as raw bytes, excluding the leading b"?". + + This is necessarily a bytewise interface, because we cannot + perform URL decoding of this representation until we've parsed + the keys and values into a QueryParams instance. 
+ + For example: + + url = httpx.URL("https://example.com/?filter=some%20search%20terms") + assert url.query == b"filter=some%20search%20terms" + """ + query = self._uri_reference.query or "" + return query.encode("ascii") + + @property + def params(self) -> "QueryParams": + """ + The URL query parameters, neatly parsed and packaged into an immutable + multidict representation. + """ + return QueryParams(self._uri_reference.query) + + @property + def target(self) -> str: + """ + The complete URL path and query string as raw bytes. + Used as the target when constructing HTTP requests. + + For example: + + GET /users?search=some%20text HTTP/1.1 + Host: www.example.org + Connection: close + """ + target = self._uri_reference.path or "/" + if self._uri_reference.query is not None: + target += "?" + self._uri_reference.query + return target + + @property + def fragment(self) -> str: + """ + The URL fragments, as used in HTML anchors. + As a string, without the leading '#'. + """ + return unquote(self._uri_reference.fragment or "") + + @property + def is_absolute_url(self) -> bool: + """ + Return `True` for absolute URLs such as 'http://example.com/path', + and `False` for relative URLs such as '/path'. + """ + # We don't use `.is_absolute` from `rfc3986` because it treats + # URLs with a fragment portion as not absolute. + # What we actually care about is if the URL provides + # a scheme and hostname to which connections should be made. + return bool(self._uri_reference.scheme and self._uri_reference.host) + + @property + def is_relative_url(self) -> bool: + """ + Return `False` for absolute URLs such as 'http://example.com/path', + and `True` for relative URLs such as '/path'. + """ + return not self.is_absolute_url + + def copy_with(self, **kwargs: typing.Any) -> "URL": + """ + Copy this URL, returning a new URL with some components altered. + Accepts the same set of parameters as the components that are made + available via properties on the `URL` class. 
class QueryParams(typing.Mapping[str, str]):
    """
    URL query parameters, as a multi-dict.

    Every key maps to a list of one or more values. The plain mapping
    interface (`q[key]`, `.get()`, `.values()`, `.items()`) only exposes the
    first value of each key; use `.get_list()`, `.multi_items()` or
    `.multi_dict()` to access all of them.

    The `copy_*` methods never modify the instance they are called on, they
    always return a new `QueryParams`.
    """

    def __init__(
        self,
        params: (
            "QueryParams" | dict[str, str | list[str]] | list[tuple[str, str]] | str | None
        ) = None,
    ) -> None:
        """
        Build the query parameters from one of:

        * `None` - no parameters.
        * A `str` - a raw query string, parsed with `urldecode`.
        * A `QueryParams` - copied.
        * A `dict` - values may be a single string or a list of strings.
        * A list of `(key, value)` two-tuples - duplicate keys are allowed.
        """
        d: dict[str, list[str]] = {}

        if params is None:
            pass
        elif isinstance(params, str):
            d = urldecode(params)
        elif isinstance(params, QueryParams):
            d = params.multi_dict()
        elif isinstance(params, dict):
            # Convert dict inputs like:
            #     {"a": "123", "b": ["456", "789"]}
            # To dict inputs where values are always lists, like:
            #     {"a": ["123"], "b": ["456", "789"]}
            d = {k: [v] if isinstance(v, str) else list(v) for k, v in params.items()}
        else:
            # Convert list inputs like:
            #     [("a", "123"), ("a", "456"), ("b", "789")]
            # To a dict representation, like:
            #     {"a": ["123", "456"], "b": ["789"]}
            for k, v in params:
                d.setdefault(k, []).append(v)

        self._dict = d

    def keys(self) -> typing.KeysView[str]:
        """
        Return all the keys in the query params.

        Usage:

        q = httpx.QueryParams("a=123&a=456&b=789")
        assert list(q.keys()) == ["a", "b"]
        """
        return self._dict.keys()

    def values(self) -> typing.ValuesView[str]:
        """
        Return all the values in the query params. If a key occurs more than once
        only the first item for that key is returned.

        Usage:

        q = httpx.QueryParams("a=123&a=456&b=789")
        assert list(q.values()) == ["123", "789"]
        """
        return {k: v[0] for k, v in self._dict.items()}.values()

    def items(self) -> typing.ItemsView[str, str]:
        """
        Return all items in the query params. If a key occurs more than once
        only the first item for that key is returned.

        Usage:

        q = httpx.QueryParams("a=123&a=456&b=789")
        assert list(q.items()) == [("a", "123"), ("b", "789")]
        """
        return {k: v[0] for k, v in self._dict.items()}.items()

    def multi_items(self) -> list[tuple[str, str]]:
        """
        Return all items in the query params. Allow duplicate keys to occur.

        Usage:

        q = httpx.QueryParams("a=123&a=456&b=789")
        assert list(q.multi_items()) == [("a", "123"), ("a", "456"), ("b", "789")]
        """
        multi_items: list[tuple[str, str]] = []
        for k, v in self._dict.items():
            multi_items.extend([(k, i) for i in v])
        return multi_items

    def multi_dict(self) -> dict[str, list[str]]:
        """
        Return a copy of the params as a dict mapping each key to the list
        of all its values.

        Usage:

        q = httpx.QueryParams("a=123&a=456&b=789")
        assert q.multi_dict() == {"a": ["123", "456"], "b": ["789"]}
        """
        return {k: list(v) for k, v in self._dict.items()}

    def get(self, key: str, default: typing.Any = None) -> typing.Any:
        """
        Get a value from the query param for a given key. If the key occurs
        more than once, then only the first value is returned.
        Returns `default` when the key is not present.

        Usage:

        q = httpx.QueryParams("a=123&a=456&b=789")
        assert q.get("a") == "123"
        """
        if key in self._dict:
            return self._dict[key][0]
        return default

    def get_list(self, key: str) -> list[str]:
        """
        Get all values from the query param for a given key.
        Returns an empty list when the key is not present.

        Usage:

        q = httpx.QueryParams("a=123&a=456&b=789")
        assert q.get_list("a") == ["123", "456"]
        """
        return list(self._dict.get(key, []))

    def copy_set(self, key: str, value: str) -> "QueryParams":
        """
        Return a new QueryParams instance, setting the value of a key.
        Any existing values for that key are replaced.

        Usage:

        q = httpx.QueryParams("a=123")
        q = q.copy_set("a", "456")
        assert q == httpx.QueryParams("a=456")
        """
        q = QueryParams()
        q._dict = dict(self._dict)
        q._dict[key] = [value]
        return q

    def copy_append(self, key: str, value: str) -> "QueryParams":
        """
        Return a new QueryParams instance, setting or appending the value of a key.

        Usage:

        q = httpx.QueryParams("a=123")
        q = q.copy_append("a", "456")
        assert q == httpx.QueryParams("a=123&a=456")
        """
        q = QueryParams()
        q._dict = dict(self._dict)
        # `get_list` returns a fresh list, so the original instance's list of
        # values is never mutated.
        q._dict[key] = q.get_list(key) + [value]
        return q

    def copy_remove(self, key: str) -> "QueryParams":
        """
        Return a new QueryParams instance, removing the value of a key.
        Removing a key that is not present is not an error.

        Usage:

        q = httpx.QueryParams("a=123")
        q = q.copy_remove("a")
        assert q == httpx.QueryParams("")
        """
        q = QueryParams()
        q._dict = dict(self._dict)
        q._dict.pop(str(key), None)
        return q

    def copy_update(
        self,
        params: (
            "QueryParams" | dict[str, str | list[str]] | list[tuple[str, str]] | None
        ) = None,
    ) -> "QueryParams":
        """
        Return a new QueryParams instance, updated with the given params.
        Keys present in `params` replace all existing values for that key.

        Usage:

        q = httpx.QueryParams("a=123")
        q = q.copy_update({"b": "456"})
        assert q == httpx.QueryParams("a=123&b=456")

        q = httpx.QueryParams("a=123")
        q = q.copy_update({"a": "456", "b": "789"})
        assert q == httpx.QueryParams("a=456&b=789")
        """
        q = QueryParams(params)
        q._dict = {**self._dict, **q._dict}
        return q

    def __getitem__(self, key: str) -> str:
        return self._dict[key][0]

    def __contains__(self, key: typing.Any) -> bool:
        return key in self._dict

    def __iter__(self) -> typing.Iterator[str]:
        return iter(self.keys())

    def __len__(self) -> int:
        return len(self._dict)

    def __bool__(self) -> bool:
        return bool(self._dict)

    def __hash__(self) -> int:
        # Hash the same order-insensitive view that `__eq__` compares, so that
        # instances which compare equal always share a hash value.
        return hash(tuple(sorted(self.multi_items())))

    def __eq__(self, other: typing.Any) -> bool:
        # Equality ignores the order in which the items were supplied.
        if not isinstance(other, self.__class__):
            return False
        return sorted(self.multi_items()) == sorted(other.multi_items())

    def __str__(self) -> str:
        return urlencode(self.multi_dict())

    def __repr__(self) -> str:
        return f"<QueryParams {str(self)!r}>"
get, post, put, patch, delete +from ._response import * # Response +from ._request import * # Request +from ._streams import * # ByteStream, DuplexStream, FileStream, HTTPStream, Stream +from ._server import * # serve_http, run +from ._urlencode import * # quote, unquote, urldecode, urlencode +from ._urls import * # QueryParams, URL + + +__all__ = [ + "__title__", + "__version__", + "ByteStream", + "Client", + "Connection", + "ConnectionPool", + "Content", + "delete", + "DuplexStream", + "File", + "FileStream", + "Files", + "Form", + "get", + "Headers", + "HTML", + "HTTPParser", + "HTTPStream", + "JSON", + "MultiPart", + "NetworkBackend", + "NetworkStream", + "open_connection", + "post", + "ProtocolError", + "put", + "patch", + "Response", + "Request", + "run", + "serve_http", + "Stream", + "Text", + "timeout", + "Transport", + "QueryParams", + "quote", + "unquote", + "URL", + "urldecode", + "urlencode", +] + + +__locals = locals() +for __name in __all__: + if not __name.startswith('__'): + setattr(__locals[__name], "__module__", "httpx") diff --git a/src/httpx/__version__.py b/src/httpx/__version__.py new file mode 100644 index 00000000..ba1c14e7 --- /dev/null +++ b/src/httpx/__version__.py @@ -0,0 +1,2 @@ +__title__ = "httpx" +__version__ = "1.0.dev3" \ No newline at end of file diff --git a/src/httpx/_client.py b/src/httpx/_client.py new file mode 100644 index 00000000..2dd54fd3 --- /dev/null +++ b/src/httpx/_client.py @@ -0,0 +1,156 @@ +import types +import typing + +from ._content import Content +from ._headers import Headers +from ._pool import ConnectionPool, Transport +from ._request import Request +from ._response import Response +from ._streams import Stream +from ._urls import URL + +__all__ = ["Client"] + + +class Client: + def __init__( + self, + url: URL | str | None = None, + headers: Headers | typing.Mapping[str, str] | None = None, + transport: Transport | None = None, + ): + if url is None: + url = "" + if headers is None: + headers = 
{"User-Agent": "dev"} + if transport is None: + transport = ConnectionPool() + + self.url = URL(url) + self.headers = Headers(headers) + self.transport = transport + self.via = RedirectMiddleware(self.transport) + + def build_request( + self, + method: str, + url: URL | str, + headers: Headers | typing.Mapping[str, str] | None = None, + content: Content | Stream | bytes | None = None, + ) -> Request: + return Request( + method=method, + url=self.url.join(url), + headers=self.headers.copy_update(headers), + content=content, + ) + + def request( + self, + method: str, + url: URL | str, + headers: Headers | typing.Mapping[str, str] | None = None, + content: Content | Stream | bytes | None = None, + ) -> Response: + request = self.build_request(method, url, headers=headers, content=content) + with self.via.send(request) as response: + response.read() + return response + + def stream( + self, + method: str, + url: URL | str, + headers: Headers | typing.Mapping[str, str] | None = None, + content: Content | Stream | bytes | None = None, + ) -> Response: + request = self.build_request(method, url, headers=headers, content=content) + return self.via.send(request) + + def get( + self, + url: URL | str, + headers: Headers | typing.Mapping[str, str] | None = None, + ): + return self.request("GET", url, headers=headers) + + def post( + self, + url: URL | str, + headers: Headers | typing.Mapping[str, str] | None = None, + content: Content | Stream | bytes | None = None, + ): + return self.request("POST", url, headers=headers, content=content) + + def put( + self, + url: URL | str, + headers: Headers | typing.Mapping[str, str] | None = None, + content: Content | Stream | bytes | None = None, + ): + return self.request("PUT", url, headers=headers, content=content) + + def patch( + self, + url: URL | str, + headers: Headers | typing.Mapping[str, str] | None = None, + content: Content | Stream | bytes | None = None, + ): + return self.request("PATCH", url, headers=headers, 
content=content) + + def delete( + self, + url: URL | str, + headers: Headers | typing.Mapping[str, str] | None = None, + ): + return self.request("DELETE", url, headers=headers) + + def close(self): + self.transport.close() + + def __enter__(self): + return self + + def __exit__( + self, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: types.TracebackType | None = None + ): + self.close() + + def __repr__(self): + return f"" + + +class RedirectMiddleware(Transport): + def __init__(self, transport: Transport) -> None: + self._transport = transport + + def is_redirect(self, response: Response) -> bool: + return ( + response.status_code in (301, 302, 303, 307, 308) + and "Location" in response.headers + ) + + def build_redirect_request(self, request: Request, response: Response) -> Request: + raise NotImplementedError() + + def send(self, request: Request) -> Response: + while True: + response = self._transport.send(request) + + if not self.is_redirect(response): + return response + + # If we have a redirect, then we read the body of the response. + # Ensures that the HTTP connection is available for a new + # request/response cycle. + response.read() + response.close() + + # We've made a request-response and now need to issue a redirect request. 
+ request = self.build_redirect_request(request, response) + + def close(self): + pass diff --git a/src/httpx/_content.py b/src/httpx/_content.py new file mode 100644 index 00000000..e2c4aaa9 --- /dev/null +++ b/src/httpx/_content.py @@ -0,0 +1,378 @@ +import json +import os +import typing + +from ._streams import Stream, ByteStream, FileStream, MultiPartStream +from ._urlencode import urldecode, urlencode + +__all__ = [ + "Content", + "Form", + "File", + "Files", + "JSON", + "MultiPart", + "Text", + "HTML", +] + +# https://github.com/nginx/nginx/blob/master/conf/mime.types +_content_types = { + ".json": "application/json", + ".js": "application/javascript", + ".html": "text/html", + ".css": "text/css", + ".png": "image/png", + ".jpeg": "image/jpeg", + ".jpg": "image/jpeg", + ".gif": "image/gif", +} + + +class Content: + def encode(self) -> Stream: + raise NotImplementedError() + + def content_type(self) -> str: + raise NotImplementedError() + + +class Form(typing.Mapping[str, str], Content): + """ + HTML form data, as an immutable multi-dict. + Form parameters, as a multi-dict. 
+ """ + + def __init__( + self, + form: ( + typing.Mapping[str, str | typing.Sequence[str]] + | typing.Sequence[tuple[str, str]] + | str + | None + ) = None, + ) -> None: + d: dict[str, list[str]] = {} + + if form is None: + d = {} + elif isinstance(form, str): + d = urldecode(form) + elif isinstance(form, typing.Mapping): + # Convert dict inputs like: + # {"a": "123", "b": ["456", "789"]} + # To dict inputs where values are always lists, like: + # {"a": ["123"], "b": ["456", "789"]} + d = {k: [v] if isinstance(v, str) else list(v) for k, v in form.items()} + else: + # Convert list inputs like: + # [("a", "123"), ("a", "456"), ("b", "789")] + # To a dict representation, like: + # {"a": ["123", "456"], "b": ["789"]} + for k, v in form: + d.setdefault(k, []).append(v) + + self._dict = d + + # Content API + + def encode(self) -> Stream: + content = str(self).encode("ascii") + return ByteStream(content) + + def content_type(self) -> str: + return "application/x-www-form-urlencoded" + + # Dict operations + + def keys(self) -> typing.KeysView[str]: + return self._dict.keys() + + def values(self) -> typing.ValuesView[str]: + return {k: v[0] for k, v in self._dict.items()}.values() + + def items(self) -> typing.ItemsView[str, str]: + return {k: v[0] for k, v in self._dict.items()}.items() + + def get(self, key: str, default: typing.Any = None) -> typing.Any: + if key in self._dict: + return self._dict[key][0] + return default + + # Multi-dict operations + + def multi_items(self) -> list[tuple[str, str]]: + multi_items: list[tuple[str, str]] = [] + for k, v in self._dict.items(): + multi_items.extend([(k, i) for i in v]) + return multi_items + + def multi_dict(self) -> dict[str, list[str]]: + return {k: list(v) for k, v in self._dict.items()} + + def get_list(self, key: str) -> list[str]: + return list(self._dict.get(key, [])) + + # Update operations + + def copy_set(self, key: str, value: str) -> "Form": + d = self.multi_dict() + d[key] = [value] + return Form(d) + + def 
class Files(typing.Mapping[str, File], Content):
    """
    File parameters, as a multi-dict.

    Every key maps to a list of one or more `File` instances. The plain
    mapping interface (`files[key]`, `.get()`, `.values()`, `.items()`) only
    exposes the first file of each key; use `.get_list()`, `.multi_items()`
    or `.multi_dict()` to access all of them.

    * `files`: Either a mapping of keys to a `File` or a sequence of `File`,
      or a sequence of `(key, File)` two-tuples.
    * `boundary`: The multipart boundary to use. When empty, a random
      boundary is generated.
    """

    def __init__(
        self,
        files: (
            typing.Mapping[str, File | typing.Sequence[File]]
            | typing.Sequence[tuple[str, File]]
            | None
        ) = None,
        boundary: str = ''
    ) -> None:
        d: dict[str, list[File]] = {}

        if files is None:
            pass
        elif isinstance(files, typing.Mapping):
            # Normalise so that values are always lists of files.
            d = {k: [v] if isinstance(v, File) else list(v) for k, v in files.items()}
        else:
            # Group a list of (key, file) pairs by key.
            for k, v in files:
                d.setdefault(k, []).append(v)

        self._dict = d
        self._boundary = boundary or os.urandom(16).hex()

    # Standard dict interface
    def keys(self) -> typing.KeysView[str]:
        return self._dict.keys()

    def values(self) -> typing.ValuesView[File]:
        return {k: v[0] for k, v in self._dict.items()}.values()

    def items(self) -> typing.ItemsView[str, File]:
        return {k: v[0] for k, v in self._dict.items()}.items()

    def get(self, key: str, default: typing.Any = None) -> typing.Any:
        """
        Return the first file for `key`, or `default` if the key is not present.
        """
        if key in self._dict:
            return self._dict[key][0]
        return default

    # Multi dict interface
    def multi_items(self) -> list[tuple[str, File]]:
        """
        Return all `(key, file)` pairs, allowing duplicate keys to occur.
        """
        multi_items: list[tuple[str, File]] = []
        for k, v in self._dict.items():
            multi_items.extend([(k, i) for i in v])
        return multi_items

    def multi_dict(self) -> dict[str, list[File]]:
        """
        Return a copy of the files as a dict mapping each key to the list
        of all its files.
        """
        return {k: list(v) for k, v in self._dict.items()}

    def get_list(self, key: str) -> list[File]:
        """
        Return all files for `key`, or an empty list if the key is not present.
        """
        return list(self._dict.get(key, []))

    # Content interface
    def encode(self) -> Stream:
        # The body must be encoded with the same boundary that
        # `content_type()` advertises, otherwise it cannot be parsed.
        return MultiPart(files=self, boundary=self._boundary).encode()

    def content_type(self) -> str:
        return f"multipart/form-data; boundary={self._boundary}"

    # Builtins
    def __getitem__(self, key: str) -> File:
        return self._dict[key][0]

    def __contains__(self, key: typing.Any) -> bool:
        return key in self._dict

    def __iter__(self) -> typing.Iterator[str]:
        return iter(self.keys())

    def __len__(self) -> int:
        return len(self._dict)

    def __bool__(self) -> bool:
        return bool(self._dict)

    def __eq__(self, other: typing.Any) -> bool:
        # Equality ignores the order in which the items were supplied.
        return (
            isinstance(other, Files) and
            sorted(self.multi_items()) == sorted(other.multi_items())
        )

    def __repr__(self) -> str:
        return f"<Files {self.multi_items()!r}>"
boundary=self._boundary) + + def content_type(self) -> str: + return f"multipart/form-data; boundary={self._boundary}" + + def __repr__(self) -> str: + return f"" diff --git a/src/httpx/_headers.py b/src/httpx/_headers.py new file mode 100644 index 00000000..dade8058 --- /dev/null +++ b/src/httpx/_headers.py @@ -0,0 +1,243 @@ +import re +import typing + + +__all__ = ["Headers"] + + +VALID_HEADER_CHARS = ( + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz" + "0123456789" + "!#$%&'*+-.^_`|~" +) + + +# TODO... +# +# * Comma folded values, eg. `Vary: ...` +# * Multiple Set-Cookie headers. +# * Non-ascii support. +# * Ordering, including `Host` header exception. + + +def headername(name: str) -> str: + if name.strip(VALID_HEADER_CHARS) or not name: + raise ValueError(f"Invalid HTTP header name {name!r}.") + return name + + +def headervalue(value: str) -> str: + value = value.strip(" ") + if not value or not value.isascii() or not value.isprintable(): + raise ValueError(f"Invalid HTTP header value {value!r}.") + return value + + +class Headers(typing.Mapping[str, str]): + def __init__( + self, + headers: typing.Mapping[str, str] | typing.Sequence[tuple[str, str]] | None = None, + ) -> None: + # {'accept': ('Accept', '*/*')} + d: dict[str, str] = {} + + if isinstance(headers, typing.Mapping): + # Headers({ + # 'Content-Length': '1024', + # 'Content-Type': 'text/plain; charset=utf-8', + # ) + d = {headername(k): headervalue(v) for k, v in headers.items()} + elif headers is not None: + # Headers([ + # ('Location', 'https://www.example.com'), + # ('Set-Cookie', 'session_id=3498jj489jhb98jn'), + # ]) + d = {headername(k): headervalue(v) for k, v in headers} + + self._dict = d + + def keys(self) -> typing.KeysView[str]: + """ + Return all the header keys. 
def get(self, key: str, default: typing.Any = None) -> typing.Any:
    """
    Return the header value for the given name, or `default` if no header
    with that name is present. Names are compared case-insensitively.

    Usage:

    h = httpx.Headers({"Accept": "*/*", "User-Agent": "python/httpx"})
    assert h.get("user-agent") == "python/httpx"
    assert h.get("Cookie") is None
    """
    # Lowercase the requested name once, rather than on every comparison.
    match = key.lower()
    for name, value in self._dict.items():
        if name.lower() == match:
            return value
    return default
def parse_opts_header(header: str) -> tuple[str, dict[str, str]]:
    """
    Parse a header of the form `value; key=param; key="quoted param"`.

    Returns a two-tuple of the lowercased leading value and a dictionary of
    its parameters. Parameter names are lowercased. Quoted parameter values
    have their surrounding quotes removed and backslash escapes resolved.
    Returns `('', {})` if the header has no leading value.

    Usage:

    parse_opts_header('text/html; charset="utf-8"')
    # ('text/html', {'charset': 'utf-8'})
    """
    # The Content-Type header is described in RFC 2616 'Content-Type'
    # https://datatracker.ietf.org/doc/html/rfc2616#section-14.17

    # The 'type/subtype; parameter' format is described in RFC 2616 'Media Types'
    # https://datatracker.ietf.org/doc/html/rfc2616#section-3.7

    # Parameter quoting is described in RFC 2616 'Transfer Codings'
    # https://datatracker.ietf.org/doc/html/rfc2616#section-3.6

    header = header.strip()
    params: dict[str, str] = {}

    # Match the content type (up to the first semicolon or end)
    match = re.match(r'^([^;]+)', header)
    if match is None:
        return '', {}
    content_type = match.group(1).strip().lower()
    rest = header[match.end():]

    # Parse parameters, accounting for quoted strings.
    # The groups must be named, since they are looked up by name below.
    param_pattern = re.compile(r'''
        ;\s*                    # Semicolon + optional whitespace
        (?P<key>[^=;\s]+)       # Parameter key
        =                       # Equal sign
        (?P<value>              # Parameter value:
            "(?:[^"\\]|\\.)*"   #   Quoted string with escapes
            |                   #   OR
            [^;]*               #   Unquoted string (until semicolon)
        )
    ''', re.VERBOSE)

    for param in param_pattern.finditer(rest):
        key = param.group('key').lower()
        value = param.group('value').strip()
        if value.startswith('"') and value.endswith('"'):
            # Remove surrounding quotes and unescape
            value = re.sub(r'\\(.)', r'\1', value[1:-1])
        params[key] = value

    return content_type, params
class NetworkStream(Stream):
    """
    A `Stream` that reads from and writes to a connected socket.

    Before each blocking socket call the socket timeout is set from
    `get_current_timeout()`.
    """

    def __init__(self, sock: socket.socket, address: tuple[str, int]) -> None:
        self._socket = sock
        self._address = address
        self._is_tls = False
        self._is_closed = False

    @property
    def host(self) -> str:
        return self._address[0]

    @property
    def port(self) -> int:
        return self._address[1]

    def read(self, size: int = -1) -> bytes:
        """
        Receive and return up to `size` bytes from the socket.
        A negative size reads up to 64KB.
        """
        if size < 0:
            size = 64 * 1024
        timeout = get_current_timeout()
        self._socket.settimeout(timeout)
        return self._socket.recv(size)

    def write(self, buffer: bytes) -> None:
        """
        Send all of `buffer`, repeating the send for as long as the socket
        only accepts part of the remaining data.
        """
        # Slice a memoryview rather than the bytes object, so a partial send
        # does not copy the whole unsent remainder on every iteration.
        view = memoryview(buffer)
        sent = 0
        while sent < len(view):
            # The timeout is re-read for every send call.
            timeout = get_current_timeout()
            self._socket.settimeout(timeout)
            sent += self._socket.send(view[sent:])

    def close(self) -> None:
        """Close the socket. Calling this more than once is harmless."""
        if not self._is_closed:
            self._is_closed = True
            self._socket.close()

    def __repr__(self) -> str:
        description = ""
        description += " TLS" if self._is_tls else ""
        description += " CLOSED" if self._is_closed else ""
        return f"<NetworkStream {self.host}:{self.port}{description}>"

    def __del__(self) -> None:
        if not self._is_closed:
            import warnings
            warnings.warn("NetworkStream was garbage collected without being closed.")

    def __enter__(self) -> "NetworkStream":
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: types.TracebackType | None = None,
    ) -> None:
        self.close()
class NetworkServer:
    """
    Accepts connections from a `NetworkListener` and runs `handler` for each
    one on a thread pool.

    Used as a context manager: entering starts the accept loop, exiting
    closes the listener and waits for the pool to finish.
    """

    def __init__(
        self,
        listener: "NetworkListener",
        handler: typing.Callable[["NetworkStream"], None],
    ) -> None:
        self.listener = listener
        self.handler = handler
        self._max_workers = 5
        self._executor: concurrent.futures.ThreadPoolExecutor | None = None
        self._thread = None
        # An empty list instance, not the `list[NetworkStream]` type alias.
        self._streams: list[NetworkStream] = []

    @property
    def host(self):
        return self.listener.host

    @property
    def port(self):
        return self.listener.port

    def __enter__(self):
        self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=self._max_workers)
        # The accept loop itself occupies one of the pool's workers.
        self._executor.submit(self._serve)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            # Closing the listener makes `accept()` return None, which ends `_serve`.
            self.listener.close()
        finally:
            # Always release the worker threads, even if closing the listener failed.
            if self._executor is not None:
                self._executor.shutdown(wait=True)

    def _serve(self):
        # Runs until `accept()` returns None, ie. the listener has been closed.
        while stream := self.listener.accept():
            self._executor.submit(self._handler, stream)

    def _handler(self, stream):
        # The stream is always closed once the handler returns or raises.
        try:
            self.handler(stream)
        finally:
            stream.close()
def listen(self, host: str, port: int) -> NetworkListener:
    """
    Listen on the given address, returning a NetworkListener instance.

    The listening socket is non-blocking, uses SO_REUSEADDR, and has a
    backlog of 5. Any error raised while configuring or binding the socket
    propagates to the caller after the socket has been closed.
    """
    address = (host, port)
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind(address)
        sock.listen(5)
        sock.setblocking(False)
    except BaseException:
        # Don't leak the descriptor if setup fails, eg. address already in use.
        sock.close()
        raise
    return NetworkListener(sock, address)
class State(enum.Enum):
    # The states used by `HTTPParser` for its `send_state` and `recv_state`.
    # Each side of the exchange is tracked independently.

    # Not yet active. Initial `recv_state` of a client, `send_state` of a server.
    WAIT = 0
    # Sending side...
    SEND_METHOD_LINE = 1  # Required by `send_method_line()`. Initial client send state.
    SEND_STATUS_LINE = 2  # Required by `send_status_line()`. Set by `recv_method_line()`.
    SEND_HEADERS = 3      # Required by `send_headers()`.
    SEND_BODY = 4         # Required by `send_body()`.
    # Receiving side...
    RECV_METHOD_LINE = 5  # Required by `recv_method_line()`. Initial server receive state.
    RECV_STATUS_LINE = 6  # Required by `recv_status_line()`. Set by `send_method_line()`.
    RECV_HEADERS = 7      # Required by `recv_headers()`.
    RECV_BODY = 8         # Required by `recv_body()`.
    # Set once an empty body chunk has been sent or received.
    DONE = 9
    # Set on both sides by `close()`.
    CLOSED = 10
def send_method_line(self, method: bytes, target: bytes, protocol: bytes) -> None:
    """
    Write the request line to the stream:

    >>> p.send_method_line(b'GET', b'/', b'HTTP/1.1')

    Afterwards the sending state is SEND_HEADERS and the receiving
    state is RECV_STATUS_LINE.

    Raises `ProtocolError` if called in any state other than
    SEND_METHOD_LINE, or if the protocol is not HTTP/1.1.
    """
    current = self.send_state
    if current != State.SEND_METHOD_LINE:
        raise ProtocolError(f"Called 'send_method_line' in invalid state {current}")

    if protocol != b'HTTP/1.1':
        raise ProtocolError("Sent unsupported protocol version")

    # eg. b"GET / HTTP/1.1\r\n"
    request_line = b"%b %b %b\r\n" % (method, target, protocol)
    self.stream.write(request_line)

    self.send_state = State.SEND_HEADERS
    self.recv_state = State.RECV_STATUS_LINE
def send_headers(self, headers: list[tuple[bytes, bytes]]) -> None:
    """
    Send the message headers, followed by the blank line that ends them:

    >>> p.send_headers([(b'Host', b'www.example.com')])

    Sending state will switch to SEND_BODY state.

    The headers also determine the outgoing message framing
    (Content-Length or chunked) and whether the connection is kept alive.

    Raises `ProtocolError` if called in any state other than SEND_HEADERS,
    if Content-Length is invalid, or if a client omits the Host header.
    """
    if self.send_state != State.SEND_HEADERS:
        msg = f"Called 'send_headers' in invalid state {self.send_state}"
        raise ProtocolError(msg)

    # Update header state
    seen_host = False
    for name, value in headers:
        lname = name.lower()
        if lname == b'host':
            seen_host = True
        elif lname == b'content-length':
            self.send_content_length = bounded_int(
                value,
                max_digits=20,
                exc_text="Sent invalid Content-Length"
            )
        elif lname in (b'connection', b'transfer-encoding'):
            # Both headers hold a comma separated list of case-insensitive
            # tokens, so `Close` or `gzip, chunked` must match too.
            tokens = [token.strip().lower() for token in value.split(b',')]
            if lname == b'connection' and b'close' in tokens:
                self.send_keep_alive = False
            elif lname == b'transfer-encoding' and b'chunked' in tokens:
                # `None` marks the body as chunked rather than length-delimited.
                self.send_content_length = None

    if self.mode == Mode.CLIENT and not seen_host:
        raise ProtocolError("Request missing 'Host' header")

    # Send the header lines, then the blank line terminating the block.
    lines = [name + b": " + value + b"\r\n" for name, value in headers]
    data = b"".join(lines) + b"\r\n"
    self.stream.write(data)

    self.send_state = State.SEND_BODY
def recv_method_line(self) -> tuple[bytes, bytes, bytes]:
    """
    Receive the initial request method line:

    >>> method, target, protocol = p.recv_method_line()

    Receive state will switch to RECV_HEADERS, and the sending state
    will switch to SEND_STATUS_LINE.

    Raises `ProtocolError` if called in any state other than
    RECV_METHOD_LINE, if the line does not have three space separated
    parts, or if the protocol is not HTTP/1.1.
    """
    if self.recv_state != State.RECV_METHOD_LINE:
        msg = f"Called 'recv_method_line' in invalid state {self.recv_state}"
        raise ProtocolError(msg)

    # Read initial request line, eg. "GET / HTTP/1.1"
    exc_text = "reading request method line"
    line = self.parser.read_until(b"\r\n", max_size=4096, exc_text=exc_text)

    # The line comes from the peer. Report a short line as a protocol
    # error, rather than letting tuple unpacking raise ValueError.
    parts = line.split(b" ", 2)
    if len(parts) != 3:
        raise ProtocolError("Received malformed request method line")
    method, target, protocol = parts
    if protocol != b'HTTP/1.1':
        raise ProtocolError("Received unsupported protocol version")

    self.recv_state = State.RECV_HEADERS
    self.send_state = State.SEND_STATUS_LINE
    return method, target, protocol
def recv_headers(self) -> list[tuple[bytes, bytes]]:
    """
    Receive the message headers:

    >>> headers = p.recv_headers()

    Receive state will switch to RECV_BODY by default.
    Receive state will revert to RECV_STATUS_LINE for interim 1xx responses.

    The headers also determine the incoming message framing
    (Content-Length or chunked) and whether the connection is kept alive.

    Raises `ProtocolError` if called in any state other than RECV_HEADERS,
    if a header line has no colon, if Content-Length is invalid, or if a
    request received by a server omits the Host header.
    """
    if self.recv_state != State.RECV_HEADERS:
        msg = f"Called 'recv_headers' in invalid state {self.recv_state}"
        raise ProtocolError(msg)

    # Read header lines, up to the empty line that terminates the block.
    headers = []
    exc_text = "reading response headers"
    while line := self.parser.read_until(b"\r\n", max_size=4096, exc_text=exc_text):
        # The line comes from the peer. Report a missing colon as a
        # protocol error, rather than letting unpacking raise ValueError.
        name, colon, value = line.partition(b":")
        if not colon:
            raise ProtocolError("Received malformed header line")
        headers.append((name, value.strip(b" ")))

    # Update header state
    seen_host = False
    for name, value in headers:
        lname = name.lower()
        if lname == b'host':
            seen_host = True
        elif lname == b'content-length':
            self.recv_content_length = bounded_int(
                value,
                max_digits=20,
                exc_text="Received invalid Content-Length"
            )
        elif lname in (b'connection', b'transfer-encoding'):
            # Both headers hold a comma separated list of case-insensitive
            # tokens, so `Close` or `gzip, chunked` must match too.
            tokens = [token.strip().lower() for token in value.split(b',')]
            if lname == b'connection' and b'close' in tokens:
                self.recv_keep_alive = False
            elif lname == b'transfer-encoding' and b'chunked' in tokens:
                # `None` marks the body as chunked rather than length-delimited.
                self.recv_content_length = None

    if self.mode == Mode.SERVER and not seen_host:
        raise ProtocolError("Request missing 'Host' header")

    if self.processing_1xx:
        # 1xx status codes precede the final response status code
        self.processing_1xx = False
        self.recv_state = State.RECV_STATUS_LINE
    else:
        self.recv_state = State.RECV_BODY
    return headers
def complete(self):
    """
    Finish the current request/response cycle.

    If both sides are DONE and neither side disabled keep-alive, the
    parser is reset to its initial state for the current mode, ready for
    another cycle. Otherwise the parser is closed.
    """
    both_done = (self.send_state, self.recv_state) == (State.DONE, State.DONE)
    reusable = both_done and self.send_keep_alive and self.recv_keep_alive
    if not reusable:
        self.close()
        return

    # Back to the initial states for this mode.
    if self.mode == Mode.CLIENT:
        self.send_state, self.recv_state = State.SEND_METHOD_LINE, State.WAIT
    else:
        self.send_state, self.recv_state = State.WAIT, State.RECV_METHOD_LINE

    # Clear the per-message framing, keep-alive and 1xx tracking.
    self.send_content_length = self.recv_content_length = 0
    self.send_seen_length = self.recv_seen_length = 0
    self.send_keep_alive = self.recv_keep_alive = True
    self.processing_1xx = False
def read(self, size: int) -> bytes:
    """
    Read from the stream until 'size' bytes are available or the stream
    closes, and return at most 'size' bytes.

    * Any surplus bytes are pushed back for the next read.
    * Returns b'' to indicate connection close.
    """
    pieces = []
    total = 0
    while total < size:
        chunk = self._read_some()
        if not chunk:
            # The stream has closed; return whatever was collected.
            break
        pieces.append(chunk)
        total += len(chunk)

    data = b"".join(pieces)
    if total > size:
        # Keep the overshoot buffered for the next caller.
        self._push_back(data[size:])
        data = data[:size]
    return data
def bounded_int(intstr: bytes, max_digits: int, exc_text: str) -> int:
    """
    Parse a non-negative decimal integer from a bytestring.

    Raises `ProtocolError(exc_text)` if the input is empty, is longer than
    `max_digits`, or contains anything other than ASCII digits.
    """
    if len(intstr) > max_digits:
        # Length of bytestring exceeds maximum.
        raise ProtocolError(exc_text)
    if not intstr.isdigit():
        # Empty, or contains invalid characters. The empty case must be
        # rejected here, otherwise `int(b'')` raises ValueError.
        raise ProtocolError(exc_text)

    return int(intstr)
class Transport:
    """
    Base class for objects that can send a `Request` and return a `Response`.

    Subclasses implement `send()`. `request()` and `stream()` are built on it.
    """

    def send(self, request: Request) -> Response:
        """Send the request and return the response. Must be overridden."""
        raise NotImplementedError()

    def close(self):
        """Release any resources held. The base implementation does nothing."""
        return None

    def request(
        self,
        method: str,
        url: URL | str,
        headers: Headers | dict[str, str] | None = None,
        content: Content | Stream | bytes | None = None,
    ) -> Response:
        """
        Send a request and read the response body before returning.
        """
        outgoing = Request(method, url, headers=headers, content=content)
        with self.send(outgoing) as response:
            response.read()
        return response

    def stream(
        self,
        method: str,
        url: URL | str,
        headers: Headers | dict[str, str] | None = None,
        content: Content | Stream | bytes | None = None,
    ) -> Response:
        """
        Send a request and return the response without reading its body.
        """
        return self.send(Request(method, url, headers=headers, content=content))
def description(self) -> str:
    """
    Summarise the pooled connections by status, eg. "1 active, 2 idle".

    The "active" count is always included, even when it is zero.
    """
    tally: dict[str, int] = {"active": 0}
    for conn in self._connections:
        label = conn.description()
        tally.setdefault(label, 0)
        tally[label] += 1
    parts = [f"{count} {status}" for status, count in tally.items()]
    return ", ".join(parts)
def send(self, request: Request) -> Response:
    """
    Send the request over this connection and return the response.

    The response head is read before returning. The response body is
    exposed as a stream which reads from the connection on demand.

    If sending the request or reading the response head fails, the
    connection is closed before the exception is re-raised.
    """
    try:
        self._send_head(request)
        self._send_body(request)
        code, headers = self._recv_head()
    except BaseException:
        # A failure part-way through the exchange leaves the connection in
        # an unknown state. Close it so that it is never reused.
        self._close()
        raise

    stream = HTTPStream(self._recv_body, self._complete)
    return Response(code, headers=headers, content=stream)
+ def request( + self, + method: str, + url: URL | str, + headers: Headers | typing.Mapping[str, str] | None = None, + content: Content | Stream | bytes | None = None, + ) -> Response: + url = self._origin.join(url) + request = Request(method, url, headers=headers, content=content) + with self.send(request) as response: + response.read() + return response + + def stream( + self, + method: str, + url: URL | str, + headers: Headers | typing.Mapping[str, str] | None = None, + content: Content | Stream | bytes | None = None, + ) -> Response: + url = self._origin.join(url) + request = Request(method, url, headers=headers, content=content) + return self.send(request) + + # Send the request... + def _send_head(self, request: Request) -> None: + method = request.method.encode('ascii') + target = request.url.target.encode('ascii') + protocol = b'HTTP/1.1' + self._parser.send_method_line(method, target, protocol) + headers = [ + (k.encode('ascii'), v.encode('ascii')) + for k, v in request.headers.items() + ] + self._parser.send_headers(headers) + + def _send_body(self, request: Request) -> None: + while data := request.stream.read(64 * 1024): + self._parser.send_body(data) + self._parser.send_body(b'') + + # Receive the response... + def _recv_head(self) -> tuple[int, Headers]: + _, code, _ = self._parser.recv_status_line() + h = self._parser.recv_headers() + headers = Headers([ + (k.decode('ascii'), v.decode('ascii')) + for k, v in h + ]) + return code, headers + + def _recv_body(self) -> bytes: + return self._parser.recv_body() + + # Request/response cycle complete... + def _complete(self) -> None: + self._parser.complete() + self._idle_expiry = time.monotonic() + self._keepalive_duration + + def _close(self) -> None: + self._parser.close() + + # Builtins... 
+ def __repr__(self) -> str: + return f"" + + def __enter__(self) -> "Connection": + return self + + def __exit__( + self, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: types.TracebackType | None = None, + ): + self.close() + + +def open_connection( + url: URL | str, + hostname: str = '', + backend: NetworkBackend | None = None, + ) -> Connection: + + if isinstance(url, str): + url = URL(url) + + if url.scheme not in ("http", "https"): + raise ValueError("URL scheme must be 'http://' or 'https://'.") + if backend is None: + backend = NetworkBackend() + + host = url.host + port = url.port or {"http": 80, "https": 443}[url.scheme] + + if url.scheme == "https": + stream = backend.connect_tls(host, port, hostname) + else: + stream = backend.connect(host, port) + + return Connection(stream, url) diff --git a/src/httpx/_quickstart.py b/src/httpx/_quickstart.py new file mode 100644 index 00000000..1a975301 --- /dev/null +++ b/src/httpx/_quickstart.py @@ -0,0 +1,49 @@ +import typing + +from ._client import Client +from ._content import Content +from ._headers import Headers +from ._streams import Stream +from ._urls import URL + + +__all__ = ['get', 'post', 'put', 'patch', 'delete'] + + +def get( + url: URL | str, + headers: Headers | typing.Mapping[str, str] | None = None, +): + with Client() as client: + return client.request("GET", url=url, headers=headers) + +def post( + url: URL | str, + headers: Headers | typing.Mapping[str, str] | None = None, + content: Content | Stream | bytes | None = None, +): + with Client() as client: + return client.request("POST", url, headers=headers, content=content) + +def put( + url: URL | str, + headers: Headers | typing.Mapping[str, str] | None = None, + content: Content | Stream | bytes | None = None, +): + with Client() as client: + return client.request("PUT", url, headers=headers, content=content) + +def patch( + url: URL | str, + headers: Headers | typing.Mapping[str, str] | 
class Request:
    """
    An HTTP request: method, URL, headers and a body stream.

    The constructor fills in the `Host` header when absent and derives the
    body framing headers (`Content-Length` / `Transfer-Encoding`) from the
    size reported by the body stream.
    """

    def __init__(
        self,
        method: str,
        url: URL | str,
        headers: Headers | typing.Mapping[str, str] | None = None,
        content: Content | Stream | bytes | None = None,
    ):
        self.method = method
        self.url = URL(url)
        self.headers = Headers(headers)
        self.stream: Stream = ByteStream(b"")

        # https://datatracker.ietf.org/doc/html/rfc2616#section-14.23
        # RFC 2616, Section 14.23, Host.
        #
        # A client MUST include a Host header field in all HTTP/1.1 request messages.
        if "Host" not in self.headers:
            self.headers = self.headers.copy_set("Host", self.url.netloc)

        # https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.2
        # RFC 7230, Section 3.3.2, Content Length.
        #
        # A user agent SHOULD send a Content-Length in a request message when no
        # Transfer-Encoding is sent and the request method defines a meaning for
        # an enclosed payload body. For example, a Content-Length header field is
        # normally sent in a POST request even when the value is 0.
        # (indicating an empty payload body).
        expects_body = method in ("POST", "PUT", "PATCH")

        if content is not None:
            if isinstance(content, bytes):
                self.stream = ByteStream(content)
            elif isinstance(content, Stream):
                self.stream = content
            elif isinstance(content, Content):
                ct = content.content_type()
                self.stream = content.encode()
                self.headers = self.headers.copy_set("Content-Type", ct)
            else:
                raise TypeError(f'Expected `Content | Stream | bytes | None` got {type(content)}')

            # https://datatracker.ietf.org/doc/html/rfc2616#section-4.3
            # RFC 2616, Section 4.3, Message Body.
            #
            # The presence of a message-body in a request is signaled by the
            # inclusion of a Content-Length or Transfer-Encoding header field in
            # the request's message-headers.
            content_length: int | None = self.stream.size
            if content_length is None:
                self.headers = self.headers.copy_set("Transfer-Encoding", "chunked")
            elif content_length > 0 or expects_body:
                # An explicitly empty body on POST/PUT/PATCH still gets
                # "Content-Length: 0", the same as passing no content at all.
                self.headers = self.headers.copy_set("Content-Length", str(content_length))

        elif expects_body:
            self.headers = self.headers.copy_set("Content-Length", "0")

    @property
    def body(self) -> bytes:
        """The request body. Raises RuntimeError unless `.read()` was called first."""
        if not hasattr(self, '_body'):
            raise RuntimeError("'.body' cannot be accessed without calling '.read()'")
        return self._body

    def read(self) -> bytes:
        """Read the whole body once, cache it, and return it."""
        if not hasattr(self, '_body'):
            self._body = self.stream.read()
            # Swap in an in-memory stream over the cached bytes so the body
            # can still be read from `self.stream` afterwards.
            self.stream = ByteStream(self._body)
        return self._body

    def close(self) -> None:
        """Close the current body stream."""
        self.stream.close()

    def __enter__(self):
        return self

    def __exit__(self,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: types.TracebackType | None = None
    ):
        self.close()

    def __repr__(self):
        return f""
+# +# https://github.com/python/cpython/blob/main/Lib/http/__init__.py +_codes = { + 100: "Continue", + 101: "Switching Protocols", + 102: "Processing", + 103: "Early Hints", + 200: "OK", + 201: "Created", + 202: "Accepted", + 203: "Non-Authoritative Information", + 204: "No Content", + 205: "Reset Content", + 206: "Partial Content", + 207: "Multi-Status", + 208: "Already Reported", + 226: "IM Used", + 300: "Multiple Choices", + 301: "Moved Permanently", + 302: "Found", + 303: "See Other", + 304: "Not Modified", + 305: "Use Proxy", + 307: "Temporary Redirect", + 308: "Permanent Redirect", + 400: "Bad Request", + 401: "Unauthorized", + 402: "Payment Required", + 403: "Forbidden", + 404: "Not Found", + 405: "Method Not Allowed", + 406: "Not Acceptable", + 407: "Proxy Authentication Required", + 408: "Request Timeout", + 409: "Conflict", + 410: "Gone", + 411: "Length Required", + 412: "Precondition Failed", + 413: "Content Too Large", + 414: "URI Too Long", + 415: "Unsupported Media Type", + 416: "Range Not Satisfiable", + 417: "Expectation Failed", + 418: "I'm a Teapot", + 421: "Misdirected Request", + 422: "Unprocessable Content", + 423: "Locked", + 424: "Failed Dependency", + 425: "Too Early", + 426: "Upgrade Required", + 428: "Precondition Required", + 429: "Too Many Requests", + 431: "Request Header Fields Too Large", + 451: "Unavailable For Legal Reasons", + 500: "Internal Server Error", + 501: "Not Implemented", + 502: "Bad Gateway", + 503: "Service Unavailable", + 504: "Gateway Timeout", + 505: "HTTP Version Not Supported", + 506: "Variant Also Negotiates", + 507: "Insufficient Storage", + 508: "Loop Detected", + 510: "Not Extended", + 511: "Network Authentication Required", +} + + +class Response: + def __init__( + self, + status_code: int, + *, + headers: Headers | typing.Mapping[str, str] | None = None, + content: Content | Stream | bytes | None = None, + ): + self.status_code = status_code + self.headers = Headers(headers) + self.stream: Stream = 
ByteStream(b"") + + if content is not None: + if isinstance(content, bytes): + self.stream = ByteStream(content) + elif isinstance(content, Stream): + self.stream = content + elif isinstance(content, Content): + ct = content.content_type() + self.stream = content.encode() + self.headers = self.headers.copy_set("Content-Type", ct) + else: + raise TypeError(f'Expected `Content | Stream | bytes | None` got {type(content)}') + + # https://datatracker.ietf.org/doc/html/rfc2616#section-4.3 + # RFC 2616, Section 4.3, Message Body. + # + # All 1xx (informational), 204 (no content), and 304 (not modified) responses + # MUST NOT include a message-body. All other responses do include a + # message-body, although it MAY be of zero length. + if status_code >= 200 and status_code != 204 and status_code != 304: + content_length: int | None = self.stream.size + if content_length is None: + self.headers = self.headers.copy_set("Transfer-Encoding", "chunked") + else: + self.headers = self.headers.copy_set("Content-Length", str(content_length)) + + @property + def reason_phrase(self): + return _codes.get(self.status_code, "Unknown Status Code") + + @property + def body(self) -> bytes: + if not hasattr(self, '_body'): + raise RuntimeError("'.body' cannot be accessed without calling '.read()'") + return self._body + + @property + def text(self) -> str: + if not hasattr(self, '_body'): + raise RuntimeError("'.text' cannot be accessed without calling '.read()'") + if not hasattr(self, '_text'): + ct = self.headers.get('Content-Type', '') + media, opts = parse_opts_header(ct) + charset = 'utf-8' + if media.startswith('text/'): + charset = opts.get('charset', 'utf-8') + self._text = self._body.decode(charset) + return self._text + + def read(self) -> bytes: + if not hasattr(self, '_body'): + self._body = self.stream.read() + return self._body + + def close(self) -> None: + self.stream.close() + + def __enter__(self): + return self + + def __exit__(self, + exc_type: type[BaseException] | 
class HTTPConnection:
    """
    Server side of a single HTTP/1.1 connection.

    Reads requests from `stream` with an `HTTPParser` in 'SERVER' mode, passes
    each one to `endpoint`, and writes the returned response back.
    """

    def __init__(self, stream, endpoint):
        self._stream = stream
        self._endpoint = endpoint
        self._parser = HTTPParser(stream, mode='SERVER')
        self._keepalive_duration = 5.0
        self._idle_expiry = time.monotonic() + self._keepalive_duration

    # API entry points...
    def handle_requests(self):
        """
        Serve requests until the parser reports the connection closed.

        An exception raised by the endpoint is logged and answered with a
        500 response. Any other exception is logged and ends the loop.
        """
        try:
            while not self._parser.is_closed():
                method, url, headers = self._recv_head()
                stream = HTTPStream(self._recv_body, self._complete)
                with Request(method, url, headers=headers, content=stream) as request:
                    try:
                        response = self._endpoint(request)
                        status_line = f"{request.method} {request.url.target} [{response.status_code} {response.reason_phrase}]"
                        logger.info(status_line)
                    except Exception:
                        logger.error("Internal Server Error", exc_info=True)
                        content = Text("Internal Server Error")
                        # `Response` takes the status code as its first
                        # positional parameter; it has no `code=` keyword.
                        err = Response(500, content=content)
                        self._send_head(err)
                        self._send_body(err)
                    else:
                        self._send_head(response)
                        self._send_body(response)
        except Exception:
            logger.error("Internal Server Error", exc_info=True)

    def close(self):
        """Close the underlying parser."""
        self._parser.close()

    # Receive the request...
    def _recv_head(self) -> tuple[str, str, list[tuple[str, str]]]:
        method, target, _ = self._parser.recv_method_line()
        m = method.decode('ascii')
        t = target.decode('ascii')
        headers = self._parser.recv_headers()
        h = [
            (k.decode('latin-1'), v.decode('latin-1'))
            for k, v in headers
        ]
        return m, t, h

    def _recv_body(self):
        return self._parser.recv_body()

    # Return the response...
    def _send_head(self, response: "Response"):
        protocol = b"HTTP/1.1"
        status = response.status_code
        reason = response.reason_phrase.encode('ascii')
        self._parser.send_status_line(protocol, status, reason)
        headers = [
            (k.encode('ascii'), v.encode('ascii'))
            for k, v in response.headers.items()
        ]
        self._parser.send_headers(headers)

    def _send_body(self, response: "Response"):
        # Forward the body in 64 KiB reads; the empty write marks end-of-body.
        while data := response.stream.read(64 * 1024):
            self._parser.send_body(data)
        self._parser.send_body(b'')

    # Start it all over again...
    def _complete(self):
        # Must be *called*: a bare attribute reference is a no-op and the
        # parser would never be told the cycle has finished.
        self._parser.complete()
        self._idle_expiry = time.monotonic() + self._keepalive_duration
class ByteStream(Stream):
    """A readable stream over an in-memory bytes object with a known size."""

    def __init__(self, data: bytes = b''):
        # The length is recorded up front so `size` stays available
        # however much of the buffer has been consumed.
        self._size = len(data)
        self._buffer = io.BytesIO(data)

    def read(self, size: int = -1) -> bytes:
        """Read up to `size` bytes from the buffer; a negative `size` reads the rest."""
        chunk = self._buffer.read(size)
        return chunk

    def close(self) -> None:
        """Close the underlying buffer."""
        self._buffer.close()

    @property
    def size(self) -> int | None:
        """Length in bytes of the data the stream was created with."""
        return self._size
class FileStream(Stream):
    """
    A readable stream over a file on disk.

    The file is not opened on construction. Call `open()` or use the
    instance as a context manager before reading. `size` is `None` until
    the file has been opened.
    """

    def __init__(self, path):
        self._path = path
        self._fileobj = None
        self._size = None

    def read(self, size: int = -1) -> bytes:
        """Read up to `size` bytes. Raises ValueError if the file is not open."""
        if self._fileobj is None:
            raise ValueError('I/O operation on unopened file')
        return self._fileobj.read(size)

    def open(self):
        """
        Open the file for binary reading and record its size; returns `self`.

        Calling `open()` on an already open stream is a no-op, so
        `with FileStream(path).open() as f:` does not leak a file handle.
        """
        if self._fileobj is None:
            fileobj = open(self._path, 'rb')
            try:
                self._size = os.path.getsize(self._path)
            except OSError:
                # Don't leave the handle open if the size lookup fails.
                fileobj.close()
                raise
            self._fileobj = fileobj
        return self

    def close(self) -> None:
        """Close the file if it is open. Safe to call more than once."""
        if self._fileobj is not None:
            self._fileobj.close()
            # Drop the handle so the stream reports "unopened" and can be
            # reopened, rather than holding a closed file object.
            self._fileobj = None

    @property
    def size(self) -> int | None:
        """File size in bytes as measured by `open()`, or `None` before that."""
        return self._size

    def __enter__(self):
        self.open()
        return self
class MultiPartStream(Stream):
    """
    A readable stream producing a `multipart/form-data` body.

    Form fields are emitted first, then each file (header, contents in
    64 KiB reads, trailing CRLF), then the closing boundary line.
    """

    # Characters escaped inside quoted Content-Disposition parameter values:
    # LF, CR and the double quote.
    _PARAM_ESCAPES = {10: "%0A", 13: "%0D", 34: "%22"}

    def __init__(self, form: list[tuple[str, str]], files: list[tuple[str, str]], boundary=''):
        self._form = list(form)
        self._files = list(files)
        # A random boundary is generated when none is supplied.
        self._boundary = boundary or os.urandom(16).hex()
        # Mutable state...
        self._form_progress = list(self._form)
        self._files_progress = list(self._files)
        self._filestream: FileStream | None = None
        self._complete = False
        self._buffer = io.BytesIO()

    def read(self, size=-1) -> bytes:
        """
        Read up to `size` bytes of the encoded body; a negative `size` reads it all.

        Returns `b''` once the body is exhausted.
        """
        sections = []
        length = 0

        # If we have any data in the buffer read that and clear the buffer.
        buffered = self._buffer.read()
        if buffered:
            sections.append(buffered)
            length += len(buffered)
            self._buffer.seek(0)
            self._buffer.truncate(0)

        # Read each multipart section in turn.
        while (size < 0) or (length < size):
            section = self._read_next_section()
            sections.append(section)
            length += len(section)
            if section == b'':
                break

        # If we've more data than requested, then push some back into the buffer.
        output = b''.join(sections)
        if size > -1 and len(output) > size:
            output, remainder = output[:size], output[size:]
            self._buffer.write(remainder)
            self._buffer.seek(0)

        return output

    def _read_next_section(self) -> bytes:
        """Return the next piece of the body, or `b''` when nothing is left."""
        if self._form_progress:
            # return a form item
            key, value = self._form_progress.pop(0)
            name = key.translate(self._PARAM_ESCAPES)
            return (
                f"--{self._boundary}\r\n"
                f'Content-Disposition: form-data; name="{name}"\r\n'
                f"\r\n"
                f"{value}\r\n"
            ).encode("utf-8")
        elif self._files_progress and self._filestream is None:
            # return start of a file item
            key, value = self._files_progress.pop(0)
            self._filestream = FileStream(value).open()
            name = key.translate(self._PARAM_ESCAPES)
            # Escape the filename the same way as the field name: a quote or
            # line break in it would otherwise end the parameter early or
            # inject extra header lines.
            filename = os.path.basename(value).translate(self._PARAM_ESCAPES)
            return (
                f"--{self._boundary}\r\n"
                f'Content-Disposition: form-data; name="{name}"; filename="{filename}"\r\n'
                f"\r\n"
            ).encode("utf-8")
        elif self._filestream is not None:
            chunk = self._filestream.read(64*1024)
            if chunk != b'':
                # return some bytes from file
                return chunk
            else:
                # return end of file item
                self._filestream.close()
                self._filestream = None
                return b"\r\n"
        elif not self._complete:
            # return final section of multipart
            self._complete = True
            return f"--{self._boundary}--\r\n".encode("utf-8")
        # return EOF marker
        return b""

    def close(self) -> None:
        """Close any file currently being streamed and the internal buffer."""
        if self._filestream is not None:
            self._filestream.close()
            self._filestream = None
        self._buffer.close()

    @property
    def size(self) -> int | None:
        """Always `None`: the encoded length is not computed in advance."""
        return None
PERCENT_ENCODED_REGEX = re.compile("(%[A-Fa-f0-9][A-Fa-f0-9])+")

# https://datatracker.ietf.org/doc/html/rfc3986#section-2.3
SAFE = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"


def urlencode(multidict, safe=SAFE):
    """
    Encode a mapping of `key -> list of values` as a form-style query string.

    Spaces are written as '+'. A literal '+' in a key or value is
    percent-encoded (unless the caller lists it in `safe`), so it is not
    read back as a space by `urldecode`.
    """
    # Let spaces survive quoting, then translate them to '+' afterwards.
    # Anything that was a '+' in the input has already become '%2B'.
    form_safe = safe + " "

    def encode(text):
        return quote(text, form_safe).replace(" ", "+")

    return "&".join(
        f"{encode(key)}={encode(value)}"
        for key, values in multidict.items()
        for value in values
    )


def urldecode(string):
    """
    Decode a form-style query string into a dict of `key -> list of values`.

    Empty segments between '&' separators are skipped. A segment without
    '=' yields an empty-string value.
    """
    ret = {}
    for part in string.split("&"):
        if not part:
            continue
        key, _, val = part.partition("=")
        # Translate '+' to space *before* unquoting, so that an escaped
        # plus sign ('%2B') decodes to a literal '+'.
        key = unquote(key.replace("+", " "))
        val = unquote(val.replace("+", " "))
        ret.setdefault(key, []).append(val)
    return ret


def quote(string, safe=SAFE):
    """Percent-encode every character of `string` that is not in `safe`."""
    # Fast path if the string is already safe.
    if not string.strip(safe):
        return string

    # Replace any characters not in the safe set with '%xx' escape sequences.
    return "".join([
        char if char in safe else percent(char)
        for char in string
    ])


def unquote(string):
    """
    Decode '%xx' escape sequences in `string` as UTF-8.

    Each run of consecutive escapes is decoded as a unit, so multi-byte
    characters are handled. A run that is not valid UTF-8 raises
    `UnicodeDecodeError`.
    """
    # Fast path if the string is not quoted.
    if '%' not in string:
        return string

    def decode_run(match):
        hex_digits = match.group(0).replace('%', '')
        return bytes.fromhex(hex_digits).decode('utf-8')

    return PERCENT_ENCODED_REGEX.sub(decode_run, string)


def percent(c):
    """Return the '%XX' escapes for the UTF-8 bytes of `c`."""
    return ''.join(f"%{b:02X}" for b in c.encode("utf-8"))
# {scheme}:      (optional)
# //{authority}  (optional)
# {path}
# ?{query}       (optional)
# #{fragment}    (optional)
#
# Each component is captured in a named group, because the parser reads the
# match through `groupdict()` using these exact keys.
URL_REGEX = re.compile(
    (
        r"(?:(?P<scheme>{scheme}):)?"
        r"(?://(?P<authority>{authority}))?"
        r"(?P<path>{path})"
        r"(?:\?(?P<query>{query}))?"
        r"(?:#(?P<fragment>{fragment}))?"
    ).format(
        scheme="([a-zA-Z][a-zA-Z0-9+.-]*)?",
        authority="[^/?#]*",
        path="[^?#]*",
        query="[^#]*",
        fragment=".*",
    )
)

# {userinfo}@    (optional)
# {host}
# :{port}        (optional)
AUTHORITY_REGEX = re.compile(
    (
        r"(?:(?P<userinfo>{userinfo})@)?" r"(?P<host>{host})" r":?(?P<port>{port})?"
    ).format(
        userinfo=".*",  # Any character sequence.
        host="(\\[.*\\]|[^:@]*)",  # Either any character sequence excluding ':' or '@',
        # or an IPv6 address enclosed within square brackets.
        port=".*",  # Any character sequence.
    )
)
def urlparse(url: str = "", **kwargs: str | None) -> ParseResult:
    """
    Parse, validate and normalize a URL into a `ParseResult`.

    `url` is the complete URL string. Individual components may also be
    passed (or overridden) as keyword arguments: "scheme", "authority",
    "userinfo", "host", "port", "path", "query", "fragment", plus the
    convenience keys "netloc", "username", "password" and "raw_path",
    which are rewritten into the former.

    Raises `InvalidURL` when the URL or any component is too long, contains
    a non-printable ASCII character, or fails validation.
    """
    # Initial basic checks on allowable URLs.
    # ---------------------------------------

    # Hard limit the maximum allowable URL length.
    if len(url) > MAX_URL_LENGTH:
        raise InvalidURL("URL too long")

    # If a URL includes any ASCII control characters including \t, \r, \n,
    # then treat it as invalid.
    bad_char = next(
        (char for char in url if char.isascii() and not char.isprintable()), None
    )
    if bad_char is not None:
        idx = url.find(bad_char)
        error = (
            f"Invalid non-printable ASCII character in URL, {bad_char!r} at position {idx}."
        )
        raise InvalidURL(error)

    # Some keyword arguments require special handling.
    # ------------------------------------------------

    # Coerce "port" to a string, if it is provided as an integer.
    if "port" in kwargs:
        port = kwargs["port"]
        kwargs["port"] = str(port) if isinstance(port, int) else port

    # Replace "netloc" with "host and "port".
    if "netloc" in kwargs:
        netloc = kwargs.pop("netloc") or ""
        # Split on the *last* colon, and only treat it as the port separator
        # when it sits outside any "[...]" IPv6 literal. Splitting on the
        # first colon would cut "[::1]:8080" in the middle of the address.
        netloc_host, separator, netloc_port = netloc.rpartition(":")
        if not separator or "]" in netloc_port:
            netloc_host, netloc_port = netloc, ""
        kwargs["host"], kwargs["port"] = netloc_host, netloc_port

    # Replace "username" and/or "password" with "userinfo".
    if "username" in kwargs or "password" in kwargs:
        username = quote(kwargs.pop("username", "") or "", safe=USERNAME_SAFE)
        password = quote(kwargs.pop("password", "") or "", safe=PASSWORD_SAFE)
        kwargs["userinfo"] = f"{username}:{password}" if password else username

    # Replace "raw_path" with "path" and "query".
    if "raw_path" in kwargs:
        raw_path = kwargs.pop("raw_path") or ""
        kwargs["path"], seperator, kwargs["query"] = raw_path.partition("?")
        if not seperator:
            kwargs["query"] = None

    # Ensure that IPv6 "host" addresses are always escaped with "[...]".
    if "host" in kwargs:
        host = kwargs.get("host") or ""
        if ":" in host and not (host.startswith("[") and host.endswith("]")):
            kwargs["host"] = f"[{host}]"

    # If any keyword arguments are provided, ensure they are valid.
    # -------------------------------------------------------------

    for key, value in kwargs.items():
        if value is not None:
            if len(value) > MAX_URL_LENGTH:
                raise InvalidURL(f"URL component '{key}' too long")

            # If a component includes any ASCII control characters including \t, \r, \n,
            # then treat it as invalid.
            bad_char = next(
                (char for char in value if char.isascii() and not char.isprintable()),
                None,
            )
            if bad_char is not None:
                idx = value.find(bad_char)
                error = (
                    f"Invalid non-printable ASCII character in URL {key} component, "
                    f"{bad_char!r} at position {idx}."
                )
                raise InvalidURL(error)

            # Ensure that keyword arguments match as a valid regex.
            if not COMPONENT_REGEX[key].fullmatch(value):
                raise InvalidURL(f"Invalid URL component '{key}'")

    # The URL_REGEX will always match, but may have empty components.
    url_match = URL_REGEX.match(url)
    assert url_match is not None
    url_dict = url_match.groupdict()

    # * 'scheme', 'authority', and 'path' may be empty strings.
    # * 'query' may be 'None', indicating no trailing "?" portion.
    #   Any string including the empty string, indicates a trailing "?".
    # * 'fragment' may be 'None', indicating no trailing "#" portion.
    #   Any string including the empty string, indicates a trailing "#".
    scheme = kwargs.get("scheme", url_dict["scheme"]) or ""
    authority = kwargs.get("authority", url_dict["authority"]) or ""
    path = kwargs.get("path", url_dict["path"]) or ""
    query = kwargs.get("query", url_dict["query"])
    frag = kwargs.get("fragment", url_dict["fragment"])

    # The AUTHORITY_REGEX will always match, but may have empty components.
    authority_match = AUTHORITY_REGEX.match(authority)
    assert authority_match is not None
    authority_dict = authority_match.groupdict()

    # * 'userinfo' and 'host' may be empty strings.
    # * 'port' may be 'None'.
    userinfo = kwargs.get("userinfo", authority_dict["userinfo"]) or ""
    host = kwargs.get("host", authority_dict["host"]) or ""
    port = kwargs.get("port", authority_dict["port"])

    # Normalize and validate each component.
    # We end up with a parsed representation of the URL,
    # with components that are plain ASCII bytestrings.
    parsed_scheme: str = scheme.lower()
    parsed_userinfo: str = quote(userinfo, safe=USERINFO_SAFE)
    parsed_host: str = encode_host(host)
    # Use the lower-cased scheme for the default-port check, so that
    # "HTTP://host:80" drops the port just as "http://host:80" does.
    parsed_port: int | None = normalize_port(port, parsed_scheme)

    has_scheme = parsed_scheme != ""
    has_authority = (
        parsed_userinfo != "" or parsed_host != "" or parsed_port is not None
    )
    validate_path(path, has_scheme=has_scheme, has_authority=has_authority)
    if has_scheme or has_authority:
        path = normalize_path(path)

    parsed_path: str = quote(path, safe=PATH_SAFE)
    parsed_query: str | None = None if query is None else quote(query, safe=QUERY_SAFE)
    parsed_frag: str | None = None if frag is None else quote(frag, safe=FRAG_SAFE)

    # The parsed ASCII bytestrings are our canonical form.
    # All properties of the URL are derived from these.
    return ParseResult(
        parsed_scheme,
        parsed_userinfo,
        parsed_host,
        parsed_port,
        parsed_path,
        parsed_query,
        parsed_frag,
    )
dec-octet + try: + ipaddress.IPv4Address(host) + except ipaddress.AddressValueError: + raise InvalidURL(f"Invalid IPv4 address: {host!r}") + return host + + elif IPv6_STYLE_HOSTNAME.match(host): + # Validate IPv6 hostnames like [...] + # + # From https://datatracker.ietf.org/doc/html/rfc3986/#section-3.2.2 + # + # "A host identified by an Internet Protocol literal address, version 6 + # [RFC3513] or later, is distinguished by enclosing the IP literal + # within square brackets ("[" and "]"). This is the only place where + # square bracket characters are allowed in the URI syntax." + try: + ipaddress.IPv6Address(host[1:-1]) + except ipaddress.AddressValueError: + raise InvalidURL(f"Invalid IPv6 address: {host!r}") + return host[1:-1] + + elif not host.isascii(): + try: + import idna # type: ignore + except ImportError: + raise InvalidURL( + f"Cannot handle URL with IDNA hostname: {host!r}. " + f"Package 'idna' is not installed." + ) + + # IDNA hostnames + try: + return idna.encode(host.lower()).decode("ascii") + except idna.IDNAError: + raise InvalidURL(f"Invalid IDNA hostname: {host!r}") + + # Regular ASCII hostnames + # + # From https://datatracker.ietf.org/doc/html/rfc3986/#section-3.2.2 + # + # reg-name = *( unreserved / pct-encoded / sub-delims ) + WHATWG_SAFE = '"`{}%|\\' + return quote(host.lower(), safe=SUB_DELIMS + WHATWG_SAFE) + + +def normalize_port(port: str | int | None, scheme: str) -> int | None: + # From https://tools.ietf.org/html/rfc3986#section-3.2.3 + # + # "A scheme may define a default port. For example, the "http" scheme + # defines a default port of "80", corresponding to its reserved TCP + # port number. The type of port designated by the port number (e.g., + # TCP, UDP, SCTP) is defined by the URI scheme. URI producers and + # normalizers should omit the port component and its ":" delimiter if + # port is empty or if its value would be the same as that of the + # scheme's default." 
+ if port is None or port == "": + return None + + try: + port_as_int = int(port) + except ValueError: + raise InvalidURL(f"Invalid port: {port!r}") + + # See https://url.spec.whatwg.org/#url-miscellaneous + default_port = {"ftp": 21, "http": 80, "https": 443, "ws": 80, "wss": 443}.get( + scheme + ) + if port_as_int == default_port: + return None + return port_as_int + + +def validate_path(path: str, has_scheme: bool, has_authority: bool) -> None: + """ + Path validation rules that depend on if the URL contains + a scheme or authority component. + + See https://datatracker.ietf.org/doc/html/rfc3986.html#section-3.3 + """ + if has_authority: + # If a URI contains an authority component, then the path component + # must either be empty or begin with a slash ("/") character." + if path and not path.startswith("/"): + raise InvalidURL("For absolute URLs, path must be empty or begin with '/'") + + if not has_scheme and not has_authority: + # If a URI does not contain an authority component, then the path cannot begin + # with two slash characters ("//"). + if path.startswith("//"): + raise InvalidURL("Relative URLs cannot have a path starting with '//'") + + # In addition, a URI reference (Section 4.1) may be a relative-path reference, + # in which case the first path segment cannot contain a colon (":") character. + if path.startswith(":"): + raise InvalidURL("Relative URLs cannot have a path starting with ':'") + + +def normalize_path(path: str) -> str: + """ + Drop "." and ".." segments from a URL path. + + For example: + + normalize_path("/path/./to/somewhere/..") == "/path/to" + """ + # Fast return when no '.' characters in the path. + if "." not in path: + return path + + components = path.split("/") + + # Fast return when no '.' or '..' components in the path. + if "." not in components and ".." 
not in components: + return path + + # https://datatracker.ietf.org/doc/html/rfc3986#section-5.2.4 + output: list[str] = [] + for component in components: + if component == ".": + pass + elif component == "..": + if output and output != [""]: + output.pop() + else: + output.append(component) + return "/".join(output) + + +def PERCENT(string: str) -> str: + return "".join([f"%{byte:02X}" for byte in string.encode("utf-8")]) + + +def percent_encoded(string: str, safe: str) -> str: + """ + Use percent-encoding to quote a string. + """ + NON_ESCAPED_CHARS = UNRESERVED_CHARACTERS + safe + + # Fast path for strings that don't need escaping. + if not string.rstrip(NON_ESCAPED_CHARS): + return string + + return "".join( + [char if char in NON_ESCAPED_CHARS else PERCENT(char) for char in string] + ) + + +def quote(string: str, safe: str) -> str: + """ + Use percent-encoding to quote a string, omitting existing '%xx' escape sequences. + + See: https://www.rfc-editor.org/rfc/rfc3986#section-2.1 + + * `string`: The string to be percent-escaped. + * `safe`: A string containing characters that may be treated as safe, and do not + need to be escaped. Unreserved characters are always treated as safe. + See: https://www.rfc-editor.org/rfc/rfc3986#section-2.3 + """ + parts = [] + current_position = 0 + for match in re.finditer(PERCENT_ENCODED_REGEX, string): + start_position, end_position = match.start(), match.end() + matched_text = match.group(0) + # Add any text up to the '%xx' escape sequence. + if start_position != current_position: + leading_text = string[current_position:start_position] + parts.append(percent_encoded(leading_text, safe=safe)) + + # Add the '%xx' escape sequence. + parts.append(matched_text) + current_position = end_position + + # Add any text after the final '%xx' escape sequence. 
class URL:
    """
    A parsed and normalized URL.

        url = httpx.URL("HTTPS://jo%40email.com:a%20secret@müller.de:1234/pa%20th?search=ab#anchorlink")

        assert url.scheme == "https"
        assert url.username == "jo@email.com"
        assert url.password == "a secret"
        assert url.userinfo == b"jo%40email.com:a%20secret"
        assert url.host == "xn--mller-kva.de"
        assert url.port == 1234
        assert url.path == "/pa th"
        assert url.query == b"search=ab"
        assert url.target == "/pa%20th?search=ab"
        assert url.fragment == "anchorlink"

    The components of a URL are broken down like this:

          https://jo%40email.com:a%20secret@müller.de:1234/pa%20th?search=ab#anchorlink
        [scheme]   [  username  ] [password] [ host ][port][ path ] [ query ] [fragment]
                   [       userinfo        ] [   netloc   ][     target     ]

    Note that:

    * `url.scheme` is normalized to always be lowercased.

    * `url.host` is returned exactly as the parser stored it: lowercased,
      and IDNA encoded for internationalized domain names.

    * `url.port` is either None or an integer. URLs that include the default
      port for "http", "https", "ws", "wss", and "ftp" schemes have their port
      normalized to `None`.

        assert httpx.URL("http://example.com") == httpx.URL("http://example.com:80")
        assert httpx.URL("http://example.com").port is None
        assert httpx.URL("http://example.com:80").port is None

    * `url.userinfo` is raw bytes, with URL escaping left in place. Usually
      you'll want to work with `url.username` and `url.password` instead,
      which apply URL decoding.

    * `url.target` is the path and query together, with URL escaping left in
      place. It is the form used as the target of an HTTP request. Usually
      you'll want to work with `url.path` instead.

    * `url.query` is raw bytes, with URL escaping left in place. A query
      string can only be properly decoded once the parameter names and
      values have been split apart, which is what `url.params` does.
    """

    def __init__(self, url: "URL" | str = "", **kwargs: typing.Any) -> None:
        """
        Build a URL from a string or from another `URL`.

        Keyword arguments override individual components. Unknown keywords
        and values of the wrong type raise `TypeError`, as does a `url`
        argument that is neither `str` nor `URL`.
        """
        if kwargs:
            allowed = {
                "scheme": str,
                "username": str,
                "password": str,
                "userinfo": bytes,
                "host": str,
                "port": int,
                "netloc": str,
                "path": str,
                "query": bytes,
                "raw_path": bytes,
                "fragment": str,
                "params": object,
            }

            # Perform type checking for all supported keyword arguments.
            for key, value in kwargs.items():
                if key not in allowed:
                    message = f"{key!r} is an invalid keyword argument for URL()"
                    raise TypeError(message)
                if value is not None and not isinstance(value, allowed[key]):
                    expected = allowed[key].__name__
                    seen = type(value).__name__
                    message = f"Argument {key!r} must be {expected} but got {seen}"
                    raise TypeError(message)
                if isinstance(value, bytes):
                    # Re-assigning an existing key does not resize the
                    # dict, so this is safe while iterating over it.
                    kwargs[key] = value.decode("ascii")

            if "params" in kwargs:
                # Replace any "params" keyword with the raw "query" instead.
                #
                # Ensure that empty params use `kwargs["query"] = None` rather
                # than `kwargs["query"] = ""`, so that generated URLs do not
                # include an empty trailing "?".
                params = kwargs.pop("params")
                kwargs["query"] = None if not params else str(QueryParams(params))

        if isinstance(url, str):
            self._uri_reference = urlparse(url, **kwargs)
        elif isinstance(url, URL):
            self._uri_reference = url._uri_reference.copy_with(**kwargs)
        else:
            raise TypeError(
                "Invalid type for url. Expected str or httpx.URL,"
                f" got {type(url)}: {url!r}"
            )

    @property
    def scheme(self) -> str:
        """
        The URL scheme, such as "http", "https".
        Always normalised to lowercase.
        """
        return self._uri_reference.scheme

    @property
    def userinfo(self) -> bytes:
        """
        The URL userinfo as a raw bytestring.
        For example: b"jo%40email.com:a%20secret".
        """
        return self._uri_reference.userinfo.encode("ascii")

    @property
    def username(self) -> str:
        """
        The URL username as a string, with URL decoding applied.
        For example: "jo@email.com"
        """
        userinfo = self._uri_reference.userinfo
        return unquote(userinfo.partition(":")[0])

    @property
    def password(self) -> str:
        """
        The URL password as a string, with URL decoding applied.
        For example: "a secret"
        """
        userinfo = self._uri_reference.userinfo
        return unquote(userinfo.partition(":")[2])

    @property
    def host(self) -> str:
        """
        The URL host as a string.
        Always normalized to lowercase. Possibly IDNA encoded.

        Examples:

            url = httpx.URL("http://www.EXAMPLE.org")
            assert url.host == "www.example.org"

            url = httpx.URL("http://中国.icom.museum")
            assert url.host == "xn--fiqs8s.icom.museum"

            url = httpx.URL("http://xn--fiqs8s.icom.museum")
            assert url.host == "xn--fiqs8s.icom.museum"

            url = httpx.URL("https://[::ffff:192.168.0.1]")
            assert url.host == "::ffff:192.168.0.1"
        """
        return self._uri_reference.host

    @property
    def port(self) -> int | None:
        """
        The URL port as an integer.

        Note that the URL class performs port normalization as per the WHATWG spec.
        Default ports for "http", "https", "ws", "wss", and "ftp" schemes are always
        treated as `None`.

        For example:

            assert httpx.URL("http://www.example.com") == httpx.URL("http://www.example.com:80")
            assert httpx.URL("http://www.example.com:80").port is None
        """
        return self._uri_reference.port

    @property
    def netloc(self) -> str:
        """
        Either `<host>` or `<host>:<port>`, as a string.
        Always normalized to lowercase, and IDNA encoded.

        This property may be used for generating the value of a request
        "Host" header.
        """
        return self._uri_reference.netloc

    @property
    def path(self) -> str:
        """
        The URL path as a string. Excluding the query string, and URL decoded.
        An empty path is reported as "/".

        For example:

            url = httpx.URL("https://example.com/pa%20th")
            assert url.path == "/pa th"
        """
        path = self._uri_reference.path or "/"
        return unquote(path)

    @property
    def query(self) -> bytes:
        """
        The URL query string, as raw bytes, excluding the leading b"?".

        This is necessarily a bytewise interface, because we cannot
        perform URL decoding of this representation until we've parsed
        the keys and values into a QueryParams instance.

        For example:

            url = httpx.URL("https://example.com/?filter=some%20search%20terms")
            assert url.query == b"filter=some%20search%20terms"
        """
        query = self._uri_reference.query or ""
        return query.encode("ascii")

    @property
    def params(self) -> "QueryParams":
        """
        The URL query parameters, neatly parsed and packaged into an immutable
        multidict representation.
        """
        return QueryParams(self._uri_reference.query)

    @property
    def target(self) -> str:
        """
        The complete URL path and query string, as a string with URL
        escaping left in place. An empty path is reported as "/".
        Used as the target when constructing HTTP requests.

        For example:

            GET /users?search=some%20text HTTP/1.1
            Host: www.example.org
            Connection: close
        """
        target = self._uri_reference.path or "/"
        # `None` means "no query"; an empty string still emits the "?".
        if self._uri_reference.query is not None:
            target += "?" + self._uri_reference.query
        return target

    @property
    def fragment(self) -> str:
        """
        The URL fragments, as used in HTML anchors.
        As a string, without the leading '#'.
        """
        return unquote(self._uri_reference.fragment or "")

    @property
    def is_absolute_url(self) -> bool:
        """
        Return `True` for absolute URLs such as 'http://example.com/path',
        and `False` for relative URLs such as '/path'.
        """
        # We don't use `.is_absolute` from `rfc3986` because it treats
        # URLs with a fragment portion as not absolute.
        # What we actually care about is if the URL provides
        # a scheme and hostname to which connections should be made.
        return bool(self._uri_reference.scheme and self._uri_reference.host)

    @property
    def is_relative_url(self) -> bool:
        """
        Return `False` for absolute URLs such as 'http://example.com/path',
        and `True` for relative URLs such as '/path'.
        """
        return not self.is_absolute_url

    def copy_with(self, **kwargs: typing.Any) -> "URL":
        """
        Copy this URL, returning a new URL with some components altered.
        Accepts the same set of parameters as the components that are made
        available via properties on the `URL` class.

        For example:

            url = httpx.URL("https://www.example.com").copy_with(
                username="jo@email.com", password="a secret"
            )
            assert url == "https://jo%40email.com:a%20secret@www.example.com"
        """
        return URL(self, **kwargs)

    def copy_set_param(self, key: str, value: typing.Any = None) -> "URL":
        """Return a copy of this URL with query parameter `key` set to `value`."""
        return self.copy_with(params=self.params.copy_set(key, value))

    def copy_append_param(self, key: str, value: typing.Any = None) -> "URL":
        """Return a copy of this URL with `value` appended to query parameter `key`."""
        return self.copy_with(params=self.params.copy_append(key, value))

    def copy_remove_param(self, key: str) -> "URL":
        """Return a copy of this URL with query parameter `key` removed."""
        return self.copy_with(params=self.params.copy_remove(key))

    def copy_merge_params(
        self,
        params: "QueryParams" | dict[str, str | list[str]] | list[tuple[str, str]] | None,
    ) -> "URL":
        """Return a copy of this URL with `params` merged over its query parameters."""
        return self.copy_with(params=self.params.copy_update(params))

    def join(self, url: "URL" | str) -> "URL":
        """
        Return an absolute URL, using this URL as the base.

        Eg.

            url = httpx.URL("https://www.example.com/test")
            url = url.join("/new/path")
            assert url == "https://www.example.com/new/path"
        """
        from urllib.parse import urljoin

        return URL(urljoin(str(self), str(URL(url))))

    def __hash__(self) -> int:
        return hash(str(self))

    def __eq__(self, other: typing.Any) -> bool:
        # Strings are parsed first, so that comparison is made between
        # normalized forms.
        return isinstance(other, (URL, str)) and str(self) == str(URL(other))

    def __str__(self) -> str:
        return str(self._uri_reference)

    def __repr__(self) -> str:
        return f"<URL {str(self)!r}>"
class QueryParams(typing.Mapping[str, str]):
    """
    URL query parameters, as a multi-dict.

    Instances are treated as immutable: the `copy_*` methods return a new
    instance rather than changing the existing one. The mapping interface
    (`q[key]`, `.get()`, `.values()`, `.items()`) exposes only the first
    value for each key; use `.get_list()`, `.multi_items()` and
    `.multi_dict()` to see every value.
    """

    def __init__(
        self,
        params: (
            "QueryParams" | dict[str, str | list[str]] | list[tuple[str, str]] | str | None
        ) = None,
    ) -> None:
        """
        Build the parameters from a query string, another `QueryParams`,
        a dict whose values are strings or lists of strings, or a list of
        `(key, value)` pairs. `None` gives an empty instance.
        """
        d: dict[str, list[str]] = {}

        if params is None:
            d = {}
        elif isinstance(params, str):
            d = urldecode(params)
        elif isinstance(params, QueryParams):
            d = params.multi_dict()
        elif isinstance(params, dict):
            # Convert dict inputs like:
            #     {"a": "123", "b": ["456", "789"]}
            # To dict inputs where values are always lists, like:
            #     {"a": ["123"], "b": ["456", "789"]}
            d = {k: [v] if isinstance(v, str) else list(v) for k, v in params.items()}
        else:
            # Convert list inputs like:
            #     [("a", "123"), ("a", "456"), ("b", "789")]
            # To a dict representation, like:
            #     {"a": ["123", "456"], "b": ["789"]}
            for k, v in params:
                d.setdefault(k, []).append(v)

        self._dict = d

    def keys(self) -> typing.KeysView[str]:
        """
        Return all the keys in the query params.

        Usage:

            q = httpx.QueryParams("a=123&a=456&b=789")
            assert list(q.keys()) == ["a", "b"]
        """
        return self._dict.keys()

    def values(self) -> typing.ValuesView[str]:
        """
        Return all the values in the query params. If a key occurs more than once
        only the first item for that key is returned.

        Usage:

            q = httpx.QueryParams("a=123&a=456&b=789")
            assert list(q.values()) == ["123", "789"]
        """
        return {k: v[0] for k, v in self._dict.items()}.values()

    def items(self) -> typing.ItemsView[str, str]:
        """
        Return all items in the query params. If a key occurs more than once
        only the first item for that key is returned.

        Usage:

            q = httpx.QueryParams("a=123&a=456&b=789")
            assert list(q.items()) == [("a", "123"), ("b", "789")]
        """
        return {k: v[0] for k, v in self._dict.items()}.items()

    def multi_items(self) -> list[tuple[str, str]]:
        """
        Return all items in the query params. Allow duplicate keys to occur.

        Usage:

            q = httpx.QueryParams("a=123&a=456&b=789")
            assert list(q.multi_items()) == [("a", "123"), ("a", "456"), ("b", "789")]
        """
        multi_items: list[tuple[str, str]] = []
        for k, v in self._dict.items():
            multi_items.extend([(k, i) for i in v])
        return multi_items

    def multi_dict(self) -> dict[str, list[str]]:
        """
        Return a dict mapping each key to a new list of all its values.
        """
        return {k: list(v) for k, v in self._dict.items()}

    def get(self, key: str, default: typing.Any = None) -> typing.Any:
        """
        Get a value from the query param for a given key. If the key occurs
        more than once, then only the first value is returned.

        Usage:

            q = httpx.QueryParams("a=123&a=456&b=789")
            assert q.get("a") == "123"
        """
        if key in self._dict:
            return self._dict[key][0]
        return default

    def get_list(self, key: str) -> list[str]:
        """
        Get all values from the query param for a given key.

        Usage:

            q = httpx.QueryParams("a=123&a=456&b=789")
            assert q.get_list("a") == ["123", "456"]
        """
        return list(self._dict.get(key, []))

    def copy_set(self, key: str, value: str) -> "QueryParams":
        """
        Return a new QueryParams instance, setting the value of a key.

        Usage:

            q = httpx.QueryParams("a=123")
            q = q.copy_set("a", "456")
            assert q == httpx.QueryParams("a=456")
        """
        q = QueryParams()
        q._dict = dict(self._dict)
        q._dict[key] = [value]
        return q

    def copy_append(self, key: str, value: str) -> "QueryParams":
        """
        Return a new QueryParams instance, setting or appending the value of a key.

        Usage:

            q = httpx.QueryParams("a=123")
            q = q.copy_append("a", "456")
            assert q == httpx.QueryParams("a=123&a=456")
        """
        q = QueryParams()
        q._dict = dict(self._dict)
        # `get_list` returns a fresh list, so the list held by `self`
        # is never modified.
        q._dict[key] = q.get_list(key) + [value]
        return q

    def copy_remove(self, key: str) -> QueryParams:
        """
        Return a new QueryParams instance, removing the value of a key.

        Usage:

            q = httpx.QueryParams("a=123")
            q = q.copy_remove("a")
            assert q == httpx.QueryParams("")
        """
        q = QueryParams()
        q._dict = dict(self._dict)
        q._dict.pop(str(key), None)
        return q

    def copy_update(
        self,
        params: (
            "QueryParams" | dict[str, str | list[str]] | list[tuple[str, str]] | None
        ) = None,
    ) -> "QueryParams":
        """
        Return a new QueryParams instance, updated with the given params.
        Keys present in `params` replace all existing values for that key.

        Usage:

            q = httpx.QueryParams("a=123")
            q = q.copy_update({"b": "456"})
            assert q == httpx.QueryParams("a=123&b=456")

            q = httpx.QueryParams("a=123")
            q = q.copy_update({"a": "456", "b": "789"})
            assert q == httpx.QueryParams("a=456&b=789")
        """
        q = QueryParams(params)
        q._dict = {**self._dict, **q._dict}
        return q

    def __getitem__(self, key: str) -> str:
        return self._dict[key][0]

    def __contains__(self, key: typing.Any) -> bool:
        return key in self._dict

    def __iter__(self) -> typing.Iterator[str]:
        return iter(self.keys())

    def __len__(self) -> int:
        return len(self._dict)

    def __bool__(self) -> bool:
        return bool(self._dict)

    def __hash__(self) -> int:
        # Equality ignores the order of items, so the hash must too.
        # Hashing the encoded string would give equal instances different
        # hashes whenever their keys were inserted in a different order.
        return hash(tuple(sorted(self.multi_items())))

    def __eq__(self, other: typing.Any) -> bool:
        if not isinstance(other, self.__class__):
            return False
        return sorted(self.multi_items()) == sorted(other.multi_items())

    def __str__(self) -> str:
        return urlencode(self.multi_dict())

    def __repr__(self) -> str:
        return f"<QueryParams {str(self)!r}>"
request.body else None, + })) + return response + + +@pytest.fixture +def client(): + with httpx.Client() as client: + yield client + + +@pytest.fixture +def server(): + with httpx.serve_http(echo) as server: + yield server + + +def test_client(client): + assert repr(client) == "" + + +def test_get(client, server): + r = client.get(server.url) + assert r.status_code == 200 + assert r.body == b'{"method":"GET","query-params":{},"content-type":null,"json":null}' + assert r.text == '{"method":"GET","query-params":{},"content-type":null,"json":null}' + + +def test_post(client, server): + data = httpx.JSON({"data": 123}) + r = client.post(server.url, content=data) + assert r.status_code == 200 + assert json.loads(r.body) == { + 'method': 'POST', + 'query-params': {}, + 'content-type': 'application/json', + 'json': {"data": 123}, + } + + +def test_put(client, server): + data = httpx.JSON({"data": 123}) + r = client.put(server.url, content=data) + assert r.status_code == 200 + assert json.loads(r.body) == { + 'method': 'PUT', + 'query-params': {}, + 'content-type': 'application/json', + 'json': {"data": 123}, + } + + +def test_patch(client, server): + data = httpx.JSON({"data": 123}) + r = client.patch(server.url, content=data) + assert r.status_code == 200 + assert json.loads(r.body) == { + 'method': 'PATCH', + 'query-params': {}, + 'content-type': 'application/json', + 'json': {"data": 123}, + } + + +def test_delete(client, server): + r = client.delete(server.url) + assert r.status_code == 200 + assert json.loads(r.body) == { + 'method': 'DELETE', + 'query-params': {}, + 'content-type': None, + 'json': None, + } + + +def test_request(client, server): + r = client.request("GET", server.url) + assert r.status_code == 200 + assert json.loads(r.body) == { + 'method': 'GET', + 'query-params': {}, + 'content-type': None, + 'json': None, + } + + +def test_stream(client, server): + with client.stream("GET", server.url) as r: + assert r.status_code == 200 + r.read() + assert 
json.loads(r.body) == { + 'method': 'GET', + 'query-params': {}, + 'content-type': None, + 'json': None, + } + + +def test_get_with_invalid_scheme(client): + with pytest.raises(ValueError): + client.get("nope://www.example.com") diff --git a/tests/test_content.py b/tests/test_content.py new file mode 100644 index 00000000..ae3158e9 --- /dev/null +++ b/tests/test_content.py @@ -0,0 +1,285 @@ +import httpx +import os +import tempfile + + + +# HTML + +def test_html(): + html = httpx.HTML("Hello, world") + + stream = html.encode() + content_type = html.content_type() + + assert stream.read() == b'Hello, world' + assert content_type == "text/html; charset='utf-8'" + + +# Text + +def test_text(): + text = httpx.Text("Hello, world") + + stream = text.encode() + content_type = text.content_type() + + assert stream.read() == b'Hello, world' + assert content_type == "text/plain; charset='utf-8'" + + +# JSON + +def test_json(): + data = httpx.JSON({'data': 123}) + + stream = data.encode() + content_type = data.content_type() + + assert stream.read() == b'{"data":123}' + assert content_type == "application/json" + + +# Form + +def test_form(): + f = httpx.Form("a=123&a=456&b=789") + assert str(f) == "a=123&a=456&b=789" + assert repr(f) == "" + assert f.multi_dict() == { + "a": ["123", "456"], + "b": ["789"] + } + + +def test_form_from_dict(): + f = httpx.Form({ + "a": ["123", "456"], + "b": "789" + }) + assert str(f) == "a=123&a=456&b=789" + assert repr(f) == "" + assert f.multi_dict() == { + "a": ["123", "456"], + "b": ["789"] + } + + +def test_form_from_list(): + f = httpx.Form([("a", "123"), ("a", "456"), ("b", "789")]) + assert str(f) == "a=123&a=456&b=789" + assert repr(f) == "" + assert f.multi_dict() == { + "a": ["123", "456"], + "b": ["789"] + } + + +def test_empty_form(): + f = httpx.Form() + assert str(f) == '' + assert repr(f) == "" + assert f.multi_dict() == {} + + +def test_form_accessors(): + f = httpx.Form([("a", "123"), ("a", "456"), ("b", "789")]) + assert "a" 
in f + assert "A" not in f + assert "c" not in f + assert f["a"] == "123" + assert f.get("a") == "123" + assert f.get("nope", default=None) is None + + +def test_form_dict(): + f = httpx.Form([("a", "123"), ("a", "456"), ("b", "789")]) + assert list(f.keys()) == ["a", "b"] + assert list(f.values()) == ["123", "789"] + assert list(f.items()) == [("a", "123"), ("b", "789")] + assert list(f) == ["a", "b"] + assert dict(f) == {"a": "123", "b": "789"} + + +def test_form_multidict(): + f = httpx.Form([("a", "123"), ("a", "456"), ("b", "789")]) + assert f.get_list("a") == ["123", "456"] + assert f.multi_items() == [("a", "123"), ("a", "456"), ("b", "789")] + assert f.multi_dict() == {"a": ["123", "456"], "b": ["789"]} + + +def test_form_builtins(): + f = httpx.Form([("a", "123"), ("a", "456"), ("b", "789")]) + assert len(f) == 2 + assert bool(f) + assert hash(f) + assert f == httpx.Form([("a", "123"), ("a", "456"), ("b", "789")]) + + +def test_form_copy_operations(): + f = httpx.Form([("a", "123"), ("a", "456"), ("b", "789")]) + assert f.copy_set("a", "abc") == httpx.Form([("a", "abc"), ("b", "789")]) + assert f.copy_append("a", "abc") == httpx.Form([("a", "123"), ("a", "456"), ("a", "abc"), ("b", "789")]) + assert f.copy_remove("a") == httpx.Form([("b", "789")]) + + +def test_form_encode(): + form = httpx.Form({'email': 'address@example.com'}) + assert form['email'] == "address@example.com" + + stream = form.encode() + content_type = form.content_type() + + assert stream.read() == b"email=address%40example.com" + assert content_type == "application/x-www-form-urlencoded" + + +# Files + +def test_files(): + f = httpx.Files() + assert f.multi_dict() == {} + assert repr(f) == "" + + +def test_files_from_dict(): + f = httpx.Files({ + "a": [ + httpx.File("123.json"), + httpx.File("456.json"), + ], + "b": httpx.File("789.json") + }) + assert f.multi_dict() == { + "a": [ + httpx.File("123.json"), + httpx.File("456.json"), + ], + "b": [ + httpx.File("789.json"), + ] + } + assert 
repr(f) == ( + "), ('a', ), ('b', )]>" + ) + + + +def test_files_from_list(): + f = httpx.Files([ + ("a", httpx.File("123.json")), + ("a", httpx.File("456.json")), + ("b", httpx.File("789.json")) + ]) + assert f.multi_dict() == { + "a": [ + httpx.File("123.json"), + httpx.File("456.json"), + ], + "b": [ + httpx.File("789.json"), + ] + } + assert repr(f) == ( + "), ('a', ), ('b', )]>" + ) + + +def test_files_accessors(): + f = httpx.Files([ + ("a", httpx.File("123.json")), + ("a", httpx.File("456.json")), + ("b", httpx.File("789.json")) + ]) + assert "a" in f + assert "A" not in f + assert "c" not in f + assert f["a"] == httpx.File("123.json") + assert f.get("a") == httpx.File("123.json") + assert f.get("nope", default=None) is None + + +def test_files_dict(): + f = httpx.Files([ + ("a", httpx.File("123.json")), + ("a", httpx.File("456.json")), + ("b", httpx.File("789.json")) + ]) + assert list(f.keys()) == ["a", "b"] + assert list(f.values()) == [httpx.File("123.json"), httpx.File("789.json")] + assert list(f.items()) == [("a", httpx.File("123.json")), ("b", httpx.File("789.json"))] + assert list(f) == ["a", "b"] + assert dict(f) == {"a": httpx.File("123.json"), "b": httpx.File("789.json")} + + +def test_files_multidict(): + f = httpx.Files([ + ("a", httpx.File("123.json")), + ("a", httpx.File("456.json")), + ("b", httpx.File("789.json")) + ]) + assert f.get_list("a") == [ + httpx.File("123.json"), + httpx.File("456.json"), + ] + assert f.multi_items() == [ + ("a", httpx.File("123.json")), + ("a", httpx.File("456.json")), + ("b", httpx.File("789.json")), + ] + assert f.multi_dict() == { + "a": [ + httpx.File("123.json"), + httpx.File("456.json"), + ], + "b": [ + httpx.File("789.json"), + ] + } + + +def test_files_builtins(): + f = httpx.Files([ + ("a", httpx.File("123.json")), + ("a", httpx.File("456.json")), + ("b", httpx.File("789.json")) + ]) + assert len(f) == 2 + assert bool(f) + assert f == httpx.Files([ + ("a", httpx.File("123.json")), + ("a", 
def test_multipart():
    """
    A multipart body built from one form field and one file upload exposes
    both parts, reports the multipart content type with the given boundary,
    and encodes to the expected wire format.
    """
    with tempfile.NamedTemporaryFile() as f:
        f.write(b"Hello, world")
        f.seek(0)

        multipart = httpx.MultiPart(
            form={'email': 'me@example.com'},
            files={'upload': httpx.File(f.name)},
            boundary='BOUNDARY',
        )
        assert multipart.form['email'] == "me@example.com"
        assert multipart.files['upload'] == httpx.File(f.name)

        fname = os.path.basename(f.name).encode('utf-8')
        stream = multipart.encode()
        content_type = multipart.content_type()

        # This comparison used to be a bare expression, so the content
        # type was never actually checked.
        assert content_type == "multipart/form-data; boundary=BOUNDARY"
        content = stream.read()
        assert content == (
            b'--BOUNDARY\r\n'
            b'Content-Disposition: form-data; name="email"\r\n'
            b'\r\n'
            b'me@example.com\r\n'
            b'--BOUNDARY\r\n'
            b'Content-Disposition: form-data; name="upload"; filename="' + fname + b'"\r\n'
            b'\r\n'
            b'Hello, world\r\n'
            b'--BOUNDARY--\r\n'
        )
httpx.Headers({"Accept": "*/*", "User-Agent": "python/httpx"}) + assert list(h.items()) == [("Accept", "*/*"), ("User-Agent", "python/httpx")] + + +def test_header_get(): + h = httpx.Headers({"Accept": "*/*", "User-Agent": "python/httpx"}) + assert h.get("User-Agent") == "python/httpx" + assert h.get("user-agent") == "python/httpx" + assert h.get("missing") is None + + +def test_header_copy_set(): + h = httpx.Headers({"Expires": "0"}) + h = h.copy_set("Expires", "Wed, 21 Oct 2015 07:28:00 GMT") + assert h == httpx.Headers({"Expires": "Wed, 21 Oct 2015 07:28:00 GMT"}) + + h = httpx.Headers({"Expires": "0"}) + h = h.copy_set("expires", "Wed, 21 Oct 2015 07:28:00 GMT") + assert h == httpx.Headers({"Expires": "Wed, 21 Oct 2015 07:28:00 GMT"}) + + +def test_header_copy_remove(): + h = httpx.Headers({"Accept": "*/*"}) + h = h.copy_remove("Accept") + assert h == httpx.Headers({}) + + h = httpx.Headers({"Accept": "*/*"}) + h = h.copy_remove("accept") + assert h == httpx.Headers({}) + + +def test_header_getitem(): + h = httpx.Headers({"Accept": "*/*", "User-Agent": "python/httpx"}) + assert h["User-Agent"] == "python/httpx" + assert h["user-agent"] == "python/httpx" + with pytest.raises(KeyError): + h["missing"] + + +def test_header_contains(): + h = httpx.Headers({"Accept": "*/*", "User-Agent": "python/httpx"}) + assert "User-Agent" in h + assert "user-agent" in h + assert "missing" not in h + + +def test_header_bool(): + h = httpx.Headers({"Accept": "*/*", "User-Agent": "python/httpx"}) + assert bool(h) + h = httpx.Headers() + assert not bool(h) + + +def test_header_iter(): + h = httpx.Headers({"Accept": "*/*", "User-Agent": "python/httpx"}) + assert [k for k in h] == ["Accept", "User-Agent"] + + +def test_header_len(): + h = httpx.Headers({"Accept": "*/*", "User-Agent": "python/httpx"}) + assert len(h) == 2 + + +def test_header_repr(): + h = httpx.Headers({"Accept": "*/*", "User-Agent": "python/httpx"}) + assert repr(h) == "" + + +def test_header_invalid_name(): + with 
pytest.raises(ValueError): + httpx.Headers({"Accept\n": "*/*"}) + + +def test_header_invalid_value(): + with pytest.raises(ValueError): + httpx.Headers({"Accept": "*/*\n"}) diff --git a/tests/test_network.py b/tests/test_network.py new file mode 100644 index 00000000..e6ce9256 --- /dev/null +++ b/tests/test_network.py @@ -0,0 +1,101 @@ +import httpx +import pytest + + +def echo(stream): + while buffer := stream.read(): + stream.write(buffer) + + +@pytest.fixture +def server(): + net = httpx.NetworkBackend() + with net.serve("127.0.0.1", 8080, echo) as server: + yield server + + +def test_network_backend(): + net = httpx.NetworkBackend() + assert repr(net) == "" + + +def test_network_backend_connect(server): + net = httpx.NetworkBackend() + stream = net.connect(server.host, server.port) + try: + assert repr(stream) == f"" + stream.write(b"Hello, world.") + content = stream.read() + assert content == b"Hello, world." + finally: + stream.close() + + +def test_network_backend_context_managed(server): + net = httpx.NetworkBackend() + with net.connect(server.host, server.port) as stream: + stream.write(b"Hello, world.") + content = stream.read() + assert content == b"Hello, world." + assert repr(stream) == f"" + + +def test_network_backend_timeout(server): + net = httpx.NetworkBackend() + with httpx.timeout(0.0): + with pytest.raises(TimeoutError): + with net.connect(server.host, server.port) as stream: + pass + + with httpx.timeout(10.0): + with net.connect(server.host, server.port) as stream: + pass + + +# >>> net = httpx.NetworkBackend() +# >>> stream = net.connect("dev.encode.io", 80) +# >>> try: +# >>> ... 
+# >>> finally: +# >>> stream.close() +# >>> stream +# + +# import httpx +# import ssl +# import truststore + +# net = httpx.NetworkBackend() +# ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT) +# req = b'\r\n'.join([ +# b'GET / HTTP/1.1', +# b'Host: www.example.com', +# b'User-Agent: python/dev', +# b'Connection: close', +# b'', +# ]) + +# # Use a 10 second overall timeout for the entire request/response. +# with timeout(10.0): +# # Use a 3 second timeout for the initial connection. +# with timeout(3.0) as t: +# # Open the connection & establish SSL. +# with net.open_stream("www.example.com", 443) as stream: +# stream.start_tls(ctx, hostname="www.example.com") +# t.cancel() +# # Send the request & read the response. +# stream.write(req) +# buffer = [] +# while part := stream.read(): +# buffer.append(part) +# resp = b''.join(buffer) + + +# def test_fixture(tcp_echo_server): +# host, port = (tcp_echo_server.host, tcp_echo_server.port) + +# net = httpx.NetworkBackend() +# with net.connect(host, port) as stream: +# stream.write(b"123") +# buffer = stream.read() +# assert buffer == b"123" diff --git a/tests/test_parsers.py b/tests/test_parsers.py new file mode 100644 index 00000000..e2a321e2 --- /dev/null +++ b/tests/test_parsers.py @@ -0,0 +1,748 @@ +import httpx +import pytest + + +class TrickleIO(httpx.Stream): + def __init__(self, stream: httpx.Stream): + self._stream = stream + + def read(self, size) -> bytes: + return self._stream.read(1) + + def write(self, data: bytes) -> None: + self._stream.write(data) + + def close(self) -> None: + self._stream.close() + + +def test_parser(): + stream = httpx.DuplexStream( + b"HTTP/1.1 200 OK\r\n" + b"Content-Length: 12\r\n" + b"Content-Type: text/plain\r\n" + b"\r\n" + b"hello, world" + ) + + p = httpx.HTTPParser(stream, mode='CLIENT') + p.send_method_line(b"POST", b"/", b"HTTP/1.1") + p.send_headers([ + (b"Host", b"example.com"), + (b"Content-Type", b"application/json"), + (b"Content-Length", b"23"), + ]) + 
p.send_body(b'{"msg": "hello, world"}') + p.send_body(b'') + + assert stream.input_bytes() == ( + b"HTTP/1.1 200 OK\r\n" + b"Content-Length: 12\r\n" + b"Content-Type: text/plain\r\n" + b"\r\n" + b"hello, world" + ) + assert stream.output_bytes() == ( + b"POST / HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"Content-Type: application/json\r\n" + b"Content-Length: 23\r\n" + b"\r\n" + b'{"msg": "hello, world"}' + ) + + protocol, code, reason_phase = p.recv_status_line() + headers = p.recv_headers() + body = p.recv_body() + terminator = p.recv_body() + + assert protocol == b'HTTP/1.1' + assert code == 200 + assert reason_phase == b'OK' + assert headers == [ + (b'Content-Length', b'12'), + (b'Content-Type', b'text/plain'), + ] + assert body == b'hello, world' + assert terminator == b'' + + assert not p.is_idle() + p.complete() + assert p.is_idle() + + +def test_parser_server(): + stream = httpx.DuplexStream( + b"GET / HTTP/1.1\r\n" + b"Host: www.example.com\r\n" + b"\r\n" + ) + + p = httpx.HTTPParser(stream, mode='SERVER') + method, target, protocol = p.recv_method_line() + headers = p.recv_headers() + body = p.recv_body() + + assert method == b'GET' + assert target == b'/' + assert protocol == b'HTTP/1.1' + assert headers == [ + (b'Host', b'www.example.com'), + ] + assert body == b'' + + p.send_status_line(b"HTTP/1.1", 200, b"OK") + p.send_headers([ + (b"Content-Type", b"application/json"), + (b"Content-Length", b"23"), + ]) + p.send_body(b'{"msg": "hello, world"}') + p.send_body(b'') + + assert stream.input_bytes() == ( + b"GET / HTTP/1.1\r\n" + b"Host: www.example.com\r\n" + b"\r\n" + ) + assert stream.output_bytes() == ( + b"HTTP/1.1 200 OK\r\n" + b"Content-Type: application/json\r\n" + b"Content-Length: 23\r\n" + b"\r\n" + b'{"msg": "hello, world"}' + ) + + assert not p.is_idle() + p.complete() + assert p.is_idle() + + +def test_parser_trickle(): + stream = httpx.DuplexStream( + b"HTTP/1.1 200 OK\r\n" + b"Content-Length: 12\r\n" + b"Content-Type: text/plain\r\n" + 
b"\r\n" + b"hello, world" + ) + + p = httpx.HTTPParser(TrickleIO(stream), mode='CLIENT') + p.send_method_line(b"POST", b"/", b"HTTP/1.1") + p.send_headers([ + (b"Host", b"example.com"), + (b"Content-Type", b"application/json"), + (b"Content-Length", b"23"), + ]) + p.send_body(b'{"msg": "hello, world"}') + p.send_body(b'') + + assert stream.input_bytes() == ( + b"HTTP/1.1 200 OK\r\n" + b"Content-Length: 12\r\n" + b"Content-Type: text/plain\r\n" + b"\r\n" + b"hello, world" + ) + assert stream.output_bytes() == ( + b"POST / HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"Content-Type: application/json\r\n" + b"Content-Length: 23\r\n" + b"\r\n" + b'{"msg": "hello, world"}' + ) + + protocol, code, reason_phase = p.recv_status_line() + headers = p.recv_headers() + body = p.recv_body() + terminator = p.recv_body() + + assert protocol == b'HTTP/1.1' + assert code == 200 + assert reason_phase == b'OK' + assert headers == [ + (b'Content-Length', b'12'), + (b'Content-Type', b'text/plain'), + ] + assert body == b'hello, world' + assert terminator == b'' + + +def test_parser_transfer_encoding_chunked(): + stream = httpx.DuplexStream( + b"HTTP/1.1 200 OK\r\n" + b"Content-Type: text/plain\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + b"c\r\n" + b"hello, world\r\n" + b"0\r\n\r\n" + ) + + p = httpx.HTTPParser(stream, mode='CLIENT') + p.send_method_line(b"POST", b"/", b"HTTP/1.1") + p.send_headers([ + (b"Host", b"example.com"), + (b"Content-Type", b"application/json"), + (b"Transfer-Encoding", b"chunked"), + ]) + p.send_body(b'{"msg": "hello, world"}') + p.send_body(b'') + + assert stream.input_bytes() == ( + b"HTTP/1.1 200 OK\r\n" + b"Content-Type: text/plain\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + b"c\r\n" + b"hello, world\r\n" + b"0\r\n\r\n" + ) + assert stream.output_bytes() == ( + b"POST / HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"Content-Type: application/json\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + b'17\r\n' + b'{"msg": "hello, world"}\r\n' + 
b'0\r\n\r\n' + ) + + protocol, code, reason_phase = p.recv_status_line() + headers = p.recv_headers() + body = p.recv_body() + terminator = p.recv_body() + + assert protocol == b'HTTP/1.1' + assert code == 200 + assert reason_phase == b'OK' + assert headers == [ + (b'Content-Type', b'text/plain'), + (b'Transfer-Encoding', b'chunked'), + ] + assert body == b'hello, world' + assert terminator == b'' + + +def test_parser_transfer_encoding_chunked_trickle(): + stream = httpx.DuplexStream( + b"HTTP/1.1 200 OK\r\n" + b"Content-Type: text/plain\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + b"c\r\n" + b"hello, world\r\n" + b"0\r\n\r\n" + ) + + p = httpx.HTTPParser(TrickleIO(stream), mode='CLIENT') + p.send_method_line(b"POST", b"/", b"HTTP/1.1") + p.send_headers([ + (b"Host", b"example.com"), + (b"Content-Type", b"application/json"), + (b"Transfer-Encoding", b"chunked"), + ]) + p.send_body(b'{"msg": "hello, world"}') + p.send_body(b'') + + assert stream.input_bytes() == ( + b"HTTP/1.1 200 OK\r\n" + b"Content-Type: text/plain\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + b"c\r\n" + b"hello, world\r\n" + b"0\r\n\r\n" + ) + assert stream.output_bytes() == ( + b"POST / HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"Content-Type: application/json\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + b'17\r\n' + b'{"msg": "hello, world"}\r\n' + b'0\r\n\r\n' + ) + + protocol, code, reason_phase = p.recv_status_line() + headers = p.recv_headers() + body = p.recv_body() + terminator = p.recv_body() + + assert protocol == b'HTTP/1.1' + assert code == 200 + assert reason_phase == b'OK' + assert headers == [ + (b'Content-Type', b'text/plain'), + (b'Transfer-Encoding', b'chunked'), + ] + assert body == b'hello, world' + assert terminator == b'' + + +def test_parser_repr(): + stream = httpx.DuplexStream( + b"HTTP/1.1 200 OK\r\n" + b"Content-Type: application/json\r\n" + b"Content-Length: 23\r\n" + b"\r\n" + b'{"msg": "hello, world"}' + ) + + p = httpx.HTTPParser(stream, 
mode='CLIENT') + assert repr(p) == "" + + p.send_method_line(b"GET", b"/", b"HTTP/1.1") + assert repr(p) == "" + + p.send_headers([(b"Host", b"example.com")]) + assert repr(p) == "" + + p.send_body(b'') + assert repr(p) == "" + + p.recv_status_line() + assert repr(p) == "" + + p.recv_headers() + assert repr(p) == "" + + p.recv_body() + assert repr(p) == "" + + p.recv_body() + assert repr(p) == "" + + p.complete() + assert repr(p) == "" + + +def test_parser_invalid_transitions(): + stream = httpx.DuplexStream() + + with pytest.raises(httpx.ProtocolError): + p = httpx.HTTPParser(stream, mode='CLIENT') + p.send_method_line(b'GET', b'/', b'HTTP/1.1') + p.send_method_line(b'GET', b'/', b'HTTP/1.1') + + with pytest.raises(httpx.ProtocolError): + p = httpx.HTTPParser(stream, mode='CLIENT') + p.send_headers([]) + + with pytest.raises(httpx.ProtocolError): + p = httpx.HTTPParser(stream, mode='CLIENT') + p.send_body(b'') + + with pytest.raises(httpx.ProtocolError): + reader = httpx.ByteStream(b'HTTP/1.1 200 OK\r\n') + p = httpx.HTTPParser(stream, mode='CLIENT') + p.recv_status_line() + + with pytest.raises(httpx.ProtocolError): + p = httpx.HTTPParser(stream, mode='CLIENT') + p.recv_headers() + + with pytest.raises(httpx.ProtocolError): + p = httpx.HTTPParser(stream, mode='CLIENT') + p.recv_body() + + +def test_parser_invalid_status_line(): + # ... + stream = httpx.DuplexStream(b'...') + + p = httpx.HTTPParser(stream, mode='CLIENT') + p.send_method_line(b"GET", b"/", b"HTTP/1.1") + p.send_headers([(b"Host", b"example.com")]) + p.send_body(b'') + + msg = 'Stream closed early reading response status line' + with pytest.raises(httpx.ProtocolError, match=msg): + p.recv_status_line() + + # ... 
+ stream = httpx.DuplexStream(b'HTTP/1.1' + b'x' * 5000) + + p = httpx.HTTPParser(stream, mode='CLIENT') + p.send_method_line(b"GET", b"/", b"HTTP/1.1") + p.send_headers([(b"Host", b"example.com")]) + p.send_body(b'') + + msg = 'Exceeded maximum size reading response status line' + with pytest.raises(httpx.ProtocolError, match=msg): + p.recv_status_line() + + # ... + stream = httpx.DuplexStream(b'HTTP/1.1' + b'x' * 5000 + b'\r\n') + + p = httpx.HTTPParser(stream, mode='CLIENT') + p.send_method_line(b"GET", b"/", b"HTTP/1.1") + p.send_headers([(b"Host", b"example.com")]) + p.send_body(b'') + + msg = 'Exceeded maximum size reading response status line' + with pytest.raises(httpx.ProtocolError, match=msg): + p.recv_status_line() + + +def test_parser_sent_unsupported_protocol(): + # Currently only HTTP/1.1 is supported. + stream = httpx.DuplexStream() + + p = httpx.HTTPParser(stream, mode='CLIENT') + msg = 'Sent unsupported protocol version' + with pytest.raises(httpx.ProtocolError, match=msg): + p.send_method_line(b"GET", b"/", b"HTTP/1.0") + + +def test_parser_recv_unsupported_protocol(): + # Currently only HTTP/1.1 is supported. + stream = httpx.DuplexStream(b"HTTP/1.0 200 OK\r\n") + + p = httpx.HTTPParser(stream, mode='CLIENT') + p.send_method_line(b"GET", b"/", b"HTTP/1.1") + msg = 'Received unsupported protocol version' + with pytest.raises(httpx.ProtocolError, match=msg): + p.recv_status_line() + + +def test_parser_large_body(): + body = b"x" * 6988 + + stream = httpx.DuplexStream( + b"HTTP/1.1 200 OK\r\n" + b"Content-Length: 6988\r\n" + b"Content-Type: text/plain\r\n" + b"\r\n" + body + ) + + p = httpx.HTTPParser(stream, mode='CLIENT') + p.send_method_line(b"GET", b"/", b"HTTP/1.1") + p.send_headers([(b"Host", b"example.com")]) + p.send_body(b'') + + # Checkout our buffer sizes. 
+ p.recv_status_line() + p.recv_headers() + assert len(p.recv_body()) == 4096 + assert len(p.recv_body()) == 2892 + assert len(p.recv_body()) == 0 + + +def test_parser_stream_large_body(): + body = b"x" * 6956 + + stream = httpx.DuplexStream( + b"HTTP/1.1 200 OK\r\n" + b"Transfer-Encoding: chunked\r\n" + b"Content-Type: text/plain\r\n" + b"\r\n" + b"1b2c\r\n" + body + b'\r\n0\r\n\r\n' + ) + + p = httpx.HTTPParser(stream, mode='CLIENT') + p.send_method_line(b"GET", b"/", b"HTTP/1.1") + p.send_headers([(b"Host", b"example.com")]) + p.send_body(b'') + + # Checkout our buffer sizes. + p.recv_status_line() + p.recv_headers() + # assert len(p.recv_body()) == 4096 + # assert len(p.recv_body()) == 2860 + assert len(p.recv_body()) == 6956 + assert len(p.recv_body()) == 0 + + +def test_parser_not_enough_data_received(): + stream = httpx.DuplexStream( + b"HTTP/1.1 200 OK\r\n" + b"Content-Length: 188\r\n" + b"Content-Type: text/plain\r\n" + b"\r\n" + b"truncated" + ) + + p = httpx.HTTPParser(stream, mode='CLIENT') + p.send_method_line(b"GET", b"/", b"HTTP/1.1") + p.send_headers([(b"Host", b"example.com")]) + p.send_body(b'') + + # Checkout our buffer sizes. 
+ p.recv_status_line() + p.recv_headers() + p.recv_body() + msg = 'Not enough data received for declared Content-Length' + with pytest.raises(httpx.ProtocolError, match=msg): + p.recv_body() + + +def test_parser_not_enough_data_sent(): + stream = httpx.DuplexStream() + + p = httpx.HTTPParser(stream, mode='CLIENT') + p.send_method_line(b"POST", b"/", b"HTTP/1.1") + p.send_headers([ + (b"Host", b"example.com"), + (b"Content-Type", b"application/json"), + (b"Content-Length", b"23"), + ]) + p.send_body(b'{"msg": "too smol"}') + msg = 'Not enough data sent for declared Content-Length' + with pytest.raises(httpx.ProtocolError, match=msg): + p.send_body(b'') + + +def test_parser_too_much_data_sent(): + stream = httpx.DuplexStream() + + p = httpx.HTTPParser(stream, mode='CLIENT') + p.send_method_line(b"POST", b"/", b"HTTP/1.1") + p.send_headers([ + (b"Host", b"example.com"), + (b"Content-Type", b"application/json"), + (b"Content-Length", b"19"), + ]) + msg = 'Too much data sent for declared Content-Length' + with pytest.raises(httpx.ProtocolError, match=msg): + p.send_body(b'{"msg": "too chonky"}') + + +def test_parser_missing_host_header(): + stream = httpx.DuplexStream() + + p = httpx.HTTPParser(stream, mode='CLIENT') + p.send_method_line(b"GET", b"/", b"HTTP/1.1") + msg = "Request missing 'Host' header" + with pytest.raises(httpx.ProtocolError, match=msg): + p.send_headers([]) + + +def test_client_connection_close(): + stream = httpx.DuplexStream( + b"HTTP/1.1 200 OK\r\n" + b"Content-Length: 12\r\n" + b"Content-Type: text/plain\r\n" + b"\r\n" + b"hello, world" + ) + + p = httpx.HTTPParser(stream, mode='CLIENT') + p.send_method_line(b"GET", b"/", b"HTTP/1.1") + p.send_headers([ + (b"Host", b"example.com"), + (b"Connection", b"close"), + ]) + p.send_body(b'') + + protocol, code, reason_phase = p.recv_status_line() + headers = p.recv_headers() + body = p.recv_body() + terminator = p.recv_body() + + assert protocol == b'HTTP/1.1' + assert code == 200 + assert reason_phase 
== b"OK" + assert headers == [ + (b'Content-Length', b'12'), + (b'Content-Type', b'text/plain'), + ] + assert body == b"hello, world" + assert terminator == b"" + + assert repr(p) == "" + + p.complete() + assert repr(p) == "" + assert p.is_closed() + + +def test_server_connection_close(): + stream = httpx.DuplexStream( + b"HTTP/1.1 200 OK\r\n" + b"Content-Length: 12\r\n" + b"Content-Type: text/plain\r\n" + b"Connection: close\r\n" + b"\r\n" + b"hello, world" + ) + + p = httpx.HTTPParser(stream, mode='CLIENT') + p.send_method_line(b"GET", b"/", b"HTTP/1.1") + p.send_headers([(b"Host", b"example.com")]) + p.send_body(b'') + + protocol, code, reason_phase = p.recv_status_line() + headers = p.recv_headers() + body = p.recv_body() + terminator = p.recv_body() + + assert protocol == b'HTTP/1.1' + assert code == 200 + assert reason_phase == b"OK" + assert headers == [ + (b'Content-Length', b'12'), + (b'Content-Type', b'text/plain'), + (b'Connection', b'close'), + ] + assert body == b"hello, world" + assert terminator == b"" + + assert repr(p) == "" + p.complete() + assert repr(p) == "" + + +def test_invalid_status_code(): + stream = httpx.DuplexStream( + b"HTTP/1.1 99 OK\r\n" + b"Content-Length: 12\r\n" + b"Content-Type: text/plain\r\n" + b"\r\n" + b"hello, world" + ) + + p = httpx.HTTPParser(stream, mode='CLIENT') + p.send_method_line(b"GET", b"/", b"HTTP/1.1") + p.send_headers([ + (b"Host", b"example.com"), + (b"Connection", b"close"), + ]) + p.send_body(b'') + + msg = "Received invalid status code" + with pytest.raises(httpx.ProtocolError, match=msg): + p.recv_status_line() + + +def test_1xx_status_code(): + stream = httpx.DuplexStream( + b"HTTP/1.1 103 Early Hints\r\n" + b"Link: ; rel=preload; as=style\r\n" + b"Link: ; rel=preload; as=script\r\n" + b"\r\n" + b"HTTP/1.1 200 OK\r\n" + b"Content-Length: 12\r\n" + b"Content-Type: text/plain\r\n" + b"\r\n" + b"hello, world" + ) + + p = httpx.HTTPParser(stream, mode='CLIENT') + p.send_method_line(b"GET", b"/", b"HTTP/1.1") 
+ p.send_headers([(b"Host", b"example.com")]) + p.send_body(b'') + + protocol, code, reason_phase = p.recv_status_line() + headers = p.recv_headers() + + assert protocol == b'HTTP/1.1' + assert code == 103 + assert reason_phase == b'Early Hints' + assert headers == [ + (b'Link', b'; rel=preload; as=style'), + (b'Link', b'; rel=preload; as=script'), + ] + + protocol, code, reason_phase = p.recv_status_line() + headers = p.recv_headers() + body = p.recv_body() + terminator = p.recv_body() + + assert protocol == b'HTTP/1.1' + assert code == 200 + assert reason_phase == b"OK" + assert headers == [ + (b'Content-Length', b'12'), + (b'Content-Type', b'text/plain'), + ] + assert body == b"hello, world" + assert terminator == b"" + + +def test_received_invalid_content_length(): + stream = httpx.DuplexStream( + b"HTTP/1.1 200 OK\r\n" + b"Content-Length: -999\r\n" + b"Content-Type: text/plain\r\n" + b"\r\n" + b"hello, world" + ) + + p = httpx.HTTPParser(stream, mode='CLIENT') + p.send_method_line(b"GET", b"/", b"HTTP/1.1") + p.send_headers([ + (b"Host", b"example.com"), + (b"Connection", b"close"), + ]) + p.send_body(b'') + + p.recv_status_line() + msg = "Received invalid Content-Length" + with pytest.raises(httpx.ProtocolError, match=msg): + p.recv_headers() + + +def test_sent_invalid_content_length(): + stream = httpx.DuplexStream() + + p = httpx.HTTPParser(stream, mode='CLIENT') + p.send_method_line(b"GET", b"/", b"HTTP/1.1") + msg = "Sent invalid Content-Length" + with pytest.raises(httpx.ProtocolError, match=msg): + # Limited to 20 digits. + # 100 million terabytes should be enough for anyone. + p.send_headers([ + (b"Host", b"example.com"), + (b"Content-Length", b"100000000000000000000"), + ]) + + +def test_received_invalid_characters_in_chunk_size(): + stream = httpx.DuplexStream( + b"HTTP/1.1 200 OK\r\n" + b"Transfer-Encoding: chunked\r\n" + b"Content-Type: text/plain\r\n" + b"\r\n" + b"0xFF\r\n..." 
+ ) + + p = httpx.HTTPParser(stream, mode='CLIENT') + p.send_method_line(b"GET", b"/", b"HTTP/1.1") + p.send_headers([ + (b"Host", b"example.com"), + (b"Connection", b"close"), + ]) + p.send_body(b'') + + p.recv_status_line() + p.recv_headers() + msg = "Received invalid chunk size" + with pytest.raises(httpx.ProtocolError, match=msg): + p.recv_body() + + +def test_received_oversized_chunk(): + stream = httpx.DuplexStream( + b"HTTP/1.1 200 OK\r\n" + b"Transfer-Encoding: chunked\r\n" + b"Content-Type: text/plain\r\n" + b"\r\n" + b"FFFFFFFFFF\r\n..." + ) + + p = httpx.HTTPParser(stream, mode='CLIENT') + p.send_method_line(b"GET", b"/", b"HTTP/1.1") + p.send_headers([ + (b"Host", b"example.com"), + (b"Connection", b"close"), + ]) + p.send_body(b'') + + p.recv_status_line() + p.recv_headers() + msg = "Received invalid chunk size" + with pytest.raises(httpx.ProtocolError, match=msg): + p.recv_body() diff --git a/tests/test_pool.py b/tests/test_pool.py new file mode 100644 index 00000000..04cd0246 --- /dev/null +++ b/tests/test_pool.py @@ -0,0 +1,126 @@ +import httpx +import pytest + + +def hello_world(request): + content = httpx.Text('Hello, world.') + return httpx.Response(200, content=content) + + +@pytest.fixture +def server(): + with httpx.serve_http(hello_world) as server: + yield server + + +def test_connection_pool_request(server): + with httpx.ConnectionPool() as pool: + assert repr(pool) == "" + assert len(pool.connections) == 0 + + r = pool.request("GET", server.url) + + assert r.status_code == 200 + assert repr(pool) == "" + assert len(pool.connections) == 1 + + +def test_connection_pool_connection_close(server): + with httpx.ConnectionPool() as pool: + assert repr(pool) == "" + assert len(pool.connections) == 0 + + r = pool.request("GET", server.url, headers={"Connection": "close"}) + + # TODO: Really we want closed connections proactively removed from the pool, + assert r.status_code == 200 + assert repr(pool) == "" + assert len(pool.connections) == 1 + + 
+def test_connection_pool_stream(server): + with httpx.ConnectionPool() as pool: + assert repr(pool) == "" + assert len(pool.connections) == 0 + + with pool.stream("GET", server.url) as r: + assert r.status_code == 200 + assert repr(pool) == "" + assert len(pool.connections) == 1 + r.read() + + assert repr(pool) == "" + assert len(pool.connections) == 1 + + +def test_connection_pool_cannot_request_after_closed(server): + with httpx.ConnectionPool() as pool: + pool + + with pytest.raises(RuntimeError): + pool.request("GET", server.url) + + +def test_connection_pool_should_have_managed_lifespan(server): + pool = httpx.ConnectionPool() + with pytest.warns(UserWarning): + del pool + + +def test_connection_request(server): + with httpx.open_connection(server.url) as conn: + assert repr(conn) == f"" + + r = conn.request("GET", "/") + + assert r.status_code == 200 + assert repr(conn) == f"" + + +def test_connection_stream(server): + with httpx.open_connection(server.url) as conn: + assert repr(conn) == f"" + with conn.stream("GET", "/") as r: + assert r.status_code == 200 + assert repr(conn) == f"" + r.read() + assert repr(conn) == f"" + + +# # with httpx.open_connection("https://www.example.com/") as conn: +# # r = conn.request("GET", "/") + +# # >>> pool = httpx.ConnectionPool() +# # >>> pool +# # + +# # >>> with httpx.open_connection_pool() as pool: +# # >>> res = pool.request("GET", "https://www.example.com") +# # >>> res, pool +# # , + +# # >>> with httpx.open_connection_pool() as pool: +# # >>> with pool.stream("GET", "https://www.example.com") as res: +# # >>> res, pool +# # , + +# # >>> with httpx.open_connection_pool() as pool: +# # >>> req = httpx.Request("GET", "https://www.example.com") +# # >>> with pool.send(req) as res: +# # >>> res.body() +# # >>> res, pool +# # , + +# # >>> with httpx.open_connection_pool() as pool: +# # >>> pool.close() +# # + +# # with httpx.open_connection("https://www.example.com/") as conn: +# # with conn.upgrade("GET", "/feed", 
{"Upgrade": "WebSocket") as stream: +# # ... + +# # with httpx.open_connection("http://127.0.0.1:8080") as conn: +# # with conn.upgrade("CONNECT", "www.encode.io:443") as stream: +# # stream.start_tls(ctx, hostname="www.encode.io") +# # ... + diff --git a/tests/test_quickstart.py b/tests/test_quickstart.py new file mode 100644 index 00000000..55c34b1b --- /dev/null +++ b/tests/test_quickstart.py @@ -0,0 +1,78 @@ +import json +import httpx +import pytest + + +def echo(request): + request.read() + response = httpx.Response(200, content=httpx.JSON({ + 'method': request.method, + 'query-params': dict(request.url.params.items()), + 'content-type': request.headers.get('Content-Type'), + 'json': json.loads(request.body) if request.body else None, + })) + return response + + +@pytest.fixture +def server(): + with httpx.serve_http(echo) as server: + yield server + + +def test_get(server): + r = httpx.get(server.url) + assert r.status_code == 200 + assert json.loads(r.body) == { + 'method': 'GET', + 'query-params': {}, + 'content-type': None, + 'json': None, + } + + +def test_post(server): + data = httpx.JSON({"data": 123}) + r = httpx.post(server.url, content=data) + assert r.status_code == 200 + assert json.loads(r.body) == { + 'method': 'POST', + 'query-params': {}, + 'content-type': 'application/json', + 'json': {"data": 123}, + } + + +def test_put(server): + data = httpx.JSON({"data": 123}) + r = httpx.put(server.url, content=data) + assert r.status_code == 200 + assert json.loads(r.body) == { + 'method': 'PUT', + 'query-params': {}, + 'content-type': 'application/json', + 'json': {"data": 123}, + } + + +def test_patch(server): + data = httpx.JSON({"data": 123}) + r = httpx.patch(server.url, content=data) + assert r.status_code == 200 + assert json.loads(r.body) == { + 'method': 'PATCH', + 'query-params': {}, + 'content-type': 'application/json', + 'json': {"data": 123}, + } + + +def test_delete(server): + r = httpx.delete(server.url) + assert r.status_code == 200 + 
assert json.loads(r.body) == { + 'method': 'DELETE', + 'query-params': {}, + 'content-type': None, + 'json': None, + } diff --git a/tests/test_request.py b/tests/test_request.py new file mode 100644 index 00000000..a69e1d13 --- /dev/null +++ b/tests/test_request.py @@ -0,0 +1,79 @@ +import httpx + + +class ByteIterator: + def __init__(self, buffer=b""): + self._buffer = buffer + + def next(self) -> bytes: + buffer = self._buffer + self._buffer = b'' + return buffer + + +def test_request(): + r = httpx.Request("GET", "https://example.com") + + assert repr(r) == "" + assert r.method == "GET" + assert r.url == "https://example.com" + assert r.headers == { + "Host": "example.com" + } + assert r.read() == b"" + +def test_request_bytes(): + content = b"Hello, world" + r = httpx.Request("POST", "https://example.com", content=content) + + assert repr(r) == "" + assert r.method == "POST" + assert r.url == "https://example.com" + assert r.headers == { + "Host": "example.com", + "Content-Length": "12", + } + assert r.read() == b"Hello, world" + + +def test_request_stream(): + i = ByteIterator(b"Hello, world") + stream = httpx.HTTPStream(i.next, None) + r = httpx.Request("POST", "https://example.com", content=stream) + + assert repr(r) == "" + assert r.method == "POST" + assert r.url == "https://example.com" + assert r.headers == { + "Host": "example.com", + "Transfer-Encoding": "chunked", + } + assert r.read() == b"Hello, world" + + +def test_request_json(): + data = httpx.JSON({"msg": "Hello, world"}) + r = httpx.Request("POST", "https://example.com", content=data) + + assert repr(r) == "" + assert r.method == "POST" + assert r.url == "https://example.com" + assert r.headers == { + "Host": "example.com", + "Content-Length": "22", + "Content-Type": "application/json", + } + assert r.read() == b'{"msg":"Hello, world"}' + + +def test_request_empty_post(): + r = httpx.Request("POST", "https://example.com") + + assert repr(r) == "" + assert r.method == "POST" + assert r.url == 
"https://example.com" + assert r.headers == { + "Host": "example.com", + "Content-Length": "0", + } + assert r.read() == b'' diff --git a/tests/test_response.py b/tests/test_response.py new file mode 100644 index 00000000..d25ebeb2 --- /dev/null +++ b/tests/test_response.py @@ -0,0 +1,64 @@ +import httpx + + +class ByteIterator: + def __init__(self, buffer=b""): + self._buffer = buffer + + def next(self) -> bytes: + buffer = self._buffer + self._buffer = b'' + return buffer + + +def test_response(): + r = httpx.Response(200) + + assert repr(r) == "" + assert r.status_code == 200 + assert r.headers == {'Content-Length': '0'} + assert r.read() == b"" + + +def test_response_204(): + r = httpx.Response(204) + + assert repr(r) == "" + assert r.status_code == 204 + assert r.headers == {} + assert r.read() == b"" + + +def test_response_bytes(): + content = b"Hello, world" + r = httpx.Response(200, content=content) + + assert repr(r) == "" + assert r.headers == { + "Content-Length": "12", + } + assert r.read() == b"Hello, world" + + +def test_response_stream(): + i = ByteIterator(b"Hello, world") + stream = httpx.HTTPStream(i.next, None) + r = httpx.Response(200, content=stream) + + assert repr(r) == "" + assert r.headers == { + "Transfer-Encoding": "chunked", + } + assert r.read() == b"Hello, world" + + +def test_response_json(): + data = httpx.JSON({"msg": "Hello, world"}) + r = httpx.Response(200, content=data) + + assert repr(r) == "" + assert r.headers == { + "Content-Length": "22", + "Content-Type": "application/json", + } + assert r.read() == b'{"msg":"Hello, world"}' diff --git a/tests/test_streams.py b/tests/test_streams.py new file mode 100644 index 00000000..80537610 --- /dev/null +++ b/tests/test_streams.py @@ -0,0 +1,82 @@ +import pytest +import httpx + + +def test_stream(): + i = httpx.Stream() + with pytest.raises(NotImplementedError): + i.read() + + with pytest.raises(NotImplementedError): + i.close() + + i.size == None + + +def test_bytestream(): + data = 
b'abc' + s = httpx.ByteStream(data) + assert s.size == 3 + assert s.read() == b'abc' + + s = httpx.ByteStream(data) + assert s.read(1) == b'a' + assert s.read(1) == b'b' + assert s.read(1) == b'c' + assert s.read(1) == b'' + + +def test_filestream(tmp_path): + path = tmp_path / "example.txt" + path.write_bytes(b"hello world") + + with httpx.FileStream(path) as s: + assert s.size == 11 + assert s.read() == b'hello world' + + with httpx.FileStream(path) as s: + assert s.read(5) == b'hello' + assert s.read(5) == b' worl' + assert s.read(5) == b'd' + assert s.read(5) == b'' + + with httpx.FileStream(path) as s: + assert s.read(5) == b'hello' + + + +def test_multipartstream(tmp_path): + path = tmp_path / 'example.txt' + path.write_bytes(b'hello world' + b'x' * 50) + + expected = b''.join([ + b'--boundary\r\n', + b'Content-Disposition: form-data; name="email"\r\n', + b'\r\n', + b'heya@example.com\r\n', + b'--boundary\r\n', + b'Content-Disposition: form-data; name="upload"; filename="example.txt"\r\n', + b'\r\n', + b'hello world' + ( b'x' * 50) + b'\r\n', + b'--boundary--\r\n', + ]) + + form = [('email', 'heya@example.com')] + files = [('upload', str(path))] + with httpx.MultiPartStream(form, files, boundary='boundary') as s: + assert s.size is None + assert s.read() == expected + + with httpx.MultiPartStream(form, files, boundary='boundary') as s: + assert s.read(50) == expected[:50] + assert s.read(50) == expected[50:100] + assert s.read(50) == expected[100:150] + assert s.read(50) == expected[150:200] + assert s.read(50) == expected[200:250] + + with httpx.MultiPartStream(form, files, boundary='boundary') as s: + assert s.read(50) == expected[:50] + assert s.read(50) == expected[50:100] + assert s.read(50) == expected[100:150] + assert s.read(50) == expected[150:200] + s.close() # test close during open file diff --git a/tests/test_urlencode.py b/tests/test_urlencode.py new file mode 100644 index 00000000..42ba45ac --- /dev/null +++ b/tests/test_urlencode.py @@ -0,0 
+1,33 @@
+import httpx
+
+
+def test_urldecode():  # named for the function under test: httpx.urldecode
+    qs = "a=name%40example.com&a=456&b=7+8+9&c"
+    d = httpx.urldecode(qs)
+    assert d == {
+        "a": ["name@example.com", "456"],
+        "b": ["7 8 9"],
+        "c": [""]
+    }
+
+
+def test_urlencode():  # named for the function under test: httpx.urlencode
+    d = {
+        "a": ["name@example.com", "456"],
+        "b": ["7 8 9"],
+        "c": [""]
+    }
+    qs = httpx.urlencode(d)
+    assert qs == "a=name%40example.com&a=456&b=7+8+9&c="
+
+
+def test_urldecode_empty():
+    qs = ""
+    d = httpx.urldecode(qs)
+    assert d == {}
+
+
+def test_urlencode_empty():
+    d = {}
+    qs = httpx.urlencode(d)
+    assert qs == ""
diff --git a/tests/test_urls.py b/tests/test_urls.py
new file mode 100644
index 00000000..ad729352
--- /dev/null
+++ b/tests/test_urls.py
@@ -0,0 +1,164 @@
+import httpx
+import pytest
+
+
+def test_url():
+    url = httpx.URL('https://www.example.com/')
+    assert str(url) == "https://www.example.com/"
+
+
+def test_url_repr():
+    url = httpx.URL('https://www.example.com/')
+    assert repr(url) == "<URL 'https://www.example.com/'>"
+
+
+def test_url_params():
+    url = httpx.URL('https://www.example.com/', params={"a": "b", "c": "d"})
+    assert str(url) == "https://www.example.com/?a=b&c=d"
+
+
+def test_url_normalisation():
+    url = httpx.URL('https://www.EXAMPLE.com:443/path/../main')
+    assert str(url) == 'https://www.example.com/main'
+
+
+def test_url_relative():
+    url = httpx.URL('/README.md')
+    assert str(url) == '/README.md'
+
+
+def test_url_escaping():
+    url = httpx.URL('https://example.com/path to here?search=🦋')
+    assert str(url) == 'https://example.com/path%20to%20here?search=%F0%9F%A6%8B'
+
+
+def test_url_components():
+    url = httpx.URL(scheme="https", host="example.com", path="/")
+    assert str(url) == 'https://example.com/'
+
+
+# QueryParams
+
+def test_queryparams():
+    params = httpx.QueryParams({"color": "black", "size": "medium"})
+    assert str(params) == 'color=black&size=medium'
+
+
+def test_queryparams_repr():
+    params = httpx.QueryParams({"color": "black", "size": "medium"})
+    assert repr(params) == "<QueryParams 'color=black&size=medium'>"
+
+
+def test_queryparams_list_of_values():
+    params = httpx.QueryParams({"filter": ["60GHz", "75GHz", "100GHz"]})
+    assert str(params) == 'filter=60GHz&filter=75GHz&filter=100GHz'
+
+
+def test_queryparams_from_str():
+    params = httpx.QueryParams("color=black&size=medium")
+    assert str(params) == 'color=black&size=medium'
+
+
+def test_queryparams_access():
+    params = httpx.QueryParams("sort_by=published&author=natalie")
+    assert params["sort_by"] == 'published'
+
+
+def test_queryparams_escaping():
+    params = httpx.QueryParams({"email": "user@example.com", "search": "How HTTP works!"})
+    assert str(params) == 'email=user%40example.com&search=How+HTTP+works%21'
+
+
+def test_queryparams_empty():
+    q = httpx.QueryParams({"a": ""})
+    assert str(q) == "a="
+
+    q = httpx.QueryParams("a=")
+    assert str(q) == "a="
+
+    q = httpx.QueryParams("a")
+    assert str(q) == "a="
+
+
+def test_queryparams_set():
+    q = httpx.QueryParams("a=123")
+    q = q.copy_set("a", "456")
+    assert q == httpx.QueryParams("a=456")
+
+
+def test_queryparams_append():
+    q = httpx.QueryParams("a=123")
+    q = q.copy_append("a", "456")
+    assert q == httpx.QueryParams("a=123&a=456")
+
+
+def test_queryparams_remove():
+    q = httpx.QueryParams("a=123")
+    q = q.copy_remove("a")
+    assert q == httpx.QueryParams("")
+
+
+def test_queryparams_merge():
+    q = httpx.QueryParams("a=123")
+    q = q.copy_update({"b": "456"})
+    assert q == httpx.QueryParams("a=123&b=456")
+    q = q.copy_update({"a": "000", "c": "789"})
+    assert q == httpx.QueryParams("a=000&b=456&c=789")
+
+
+def test_queryparams_are_hashable():
+    params = (
+        httpx.QueryParams("a=123"),
+        httpx.QueryParams({"a": "123"}),
+        httpx.QueryParams("b=456"),
+        httpx.QueryParams({"b": "456"}),
+    )
+
+    assert len(set(params)) == 2  # equal params from str and dict collapse to one entry
+
+
+@pytest.mark.parametrize(
+    "source",
+    [
+        "a=123&a=456&b=789",
+        {"a": ["123", "456"], "b": "789"},
+        {"a": ("123", "456"), "b": "789"},
+        [("a", "123"), ("a", "456"), ("b", "789")],
+        (("a", "123"), ("a", "456"), ("b", "789")),
+    ],
+)
+def test_queryparams_misc(source):
+    q = httpx.QueryParams(source)
+    assert "a" in q
+    assert "A" not in q
+    assert "c" not in q
+    assert q["a"] == "123"
+    assert q.get("a") == "123"
+    assert q.get("nope", default=None) is None
+    assert q.get_list("a") == ["123", "456"]
+    assert bool(q)
+
+    assert list(q.keys()) == ["a", "b"]
+    assert list(q.values()) == ["123", "789"]
+    assert list(q.items()) == [("a", "123"), ("b", "789")]
+    assert len(q) == 2
+    assert list(q) == ["a", "b"]
+    assert dict(q) == {"a": "123", "b": "789"}
+    assert str(q) == "a=123&a=456&b=789"
+    assert httpx.QueryParams({"a": "123", "b": "456"}) == httpx.QueryParams(
+        [("a", "123"), ("b", "456")]
+    )
+    assert httpx.QueryParams({"a": "123", "b": "456"}) == httpx.QueryParams(
+        "a=123&b=456"
+    )
+    assert httpx.QueryParams({"a": "123", "b": "456"}) == httpx.QueryParams(
+        {"b": "456", "a": "123"}
+    )
+    assert httpx.QueryParams() == httpx.QueryParams({})
+    assert httpx.QueryParams([("a", "123"), ("a", "456")]) == httpx.QueryParams(
+        "a=123&a=456"
+    )
+    assert httpx.QueryParams({"a": "123", "b": "456"}) != "invalid"
+
+    q = httpx.QueryParams([("a", "123"), ("a", "456")])
+    assert httpx.QueryParams(q) == q