import h2.events
from .config import DEFAULT_SSL_CONFIG, DEFAULT_TIMEOUT_CONFIG, SSLConfig, TimeoutConfig
-from .datastructures import Client, Origin, Request, Response
from .exceptions import ConnectTimeout, ReadTimeout
+from .models import Client, Origin, Request, Response
class HTTP2Connection(Client):
self.timeout = timeout
self.on_release = on_release
self.h2_state = h2.connection.H2Connection()
- self.events = [] # type: typing.List[h2.events.Event]
+ self.events = {} # type: typing.Dict[int, typing.List[h2.events.Event]]
+ self.initialized = False
@property
def is_closed(self) -> bool:
if timeout is None:
timeout = self.timeout
+ if not self.initialized:
+ self.initiate_connection()
+
# Start sending the request.
- await self._initiate_connection()
- await self._send_headers(request)
+ stream_id = self.h2_state.get_next_available_stream_id()
+ self.events[stream_id] = []
+ await self.send_headers(stream_id, request)
# Send the request body.
- if request.body:
- await self._send_data(request.body)
+ async for data in request.stream():
+ await self.send_data(stream_id, data)
# Finalize sending the request.
- await self._end_stream()
+ await self.end_stream(stream_id)
# Start getting the response.
while True:
- event = await self._receive_event(timeout)
+ event = await self.receive_event(stream_id, timeout)
if isinstance(event, h2.events.ResponseReceived):
break
elif not k.startswith(b":"):
headers.append((k, v))
- body = self._body_iter(timeout)
+ body = self.body_iter(stream_id, timeout)
return Response(
status_code=status_code,
protocol="HTTP/2",
headers=headers,
body=body,
- on_close=self._release,
+ on_close=self.release,
)
- async def _initiate_connection(self) -> None:
+ def initiate_connection(self) -> None:
self.h2_state.initiate_connection()
data_to_send = self.h2_state.data_to_send()
self.writer.write(data_to_send)
+ self.initialized = True
- async def _send_headers(self, request: Request) -> None:
+ async def send_headers(self, stream_id: int, request: Request) -> None:
headers = [
(b":method", request.method.encode()),
(b":authority", request.url.hostname.encode()),
(b":scheme", request.url.scheme.encode()),
(b":path", request.url.full_path.encode()),
] + request.headers
- self.h2_state.send_headers(1, headers)
+ self.h2_state.send_headers(stream_id, headers)
data_to_send = self.h2_state.data_to_send()
self.writer.write(data_to_send)
- async def _send_data(self, data: bytes) -> None:
- self.h2_state.send_data(1, data)
+ async def send_data(self, stream_id: int, data: bytes) -> None:
+ self.h2_state.send_data(stream_id, data)
data_to_send = self.h2_state.data_to_send()
self.writer.write(data_to_send)
- async def _end_stream(self) -> None:
- self.h2_state.end_stream(1)
+ async def end_stream(self, stream_id: int) -> None:
+ self.h2_state.end_stream(stream_id)
data_to_send = self.h2_state.data_to_send()
self.writer.write(data_to_send)
- async def _body_iter(self, timeout: TimeoutConfig) -> typing.AsyncIterator[bytes]:
+ async def body_iter(
+ self, stream_id: int, timeout: TimeoutConfig
+ ) -> typing.AsyncIterator[bytes]:
while True:
- event = await self._receive_event(timeout)
+ event = await self.receive_event(stream_id, timeout)
if isinstance(event, h2.events.DataReceived):
yield event.data
elif isinstance(event, h2.events.StreamEnded):
+ del self.events[stream_id]
break
- async def _receive_event(self, timeout: TimeoutConfig) -> h2.events.Event:
- while not self.events:
+ async def receive_event(
+ self, stream_id: int, timeout: TimeoutConfig
+ ) -> h2.events.Event:
+ while not self.events[stream_id]:
try:
data = await asyncio.wait_for(
self.reader.read(2048), timeout.read_timeout
raise ReadTimeout()
events = self.h2_state.receive_data(data)
- self.events.extend(events)
+ for event in events:
+ if getattr(event, "stream_id", 0):
+ self.events[event.stream_id].append(event)
data_to_send = self.h2_state.data_to_send()
if data_to_send:
self.writer.write(data_to_send)
- return self.events.pop(0)
-
- async def _release(self) -> None:
- # if (
- # self.h11_state.our_state is h11.DONE
- # and self.h11_state.their_state is h11.DONE
- # ):
- # self.h11_state.start_next_cycle()
- # else:
- # await self.close()
+ return self.events[stream_id].pop(0)
+ async def release(self) -> None:
if self.on_release is not None:
await self.on_release(self)
async def close(self) -> None:
- # event = h11.ConnectionClosed()
- # try:
- # # If we're in h11.MUST_CLOSE then we'll end up in h11.CLOSED.
- # self.h11_state.send(event)
- # except h11.ProtocolError:
- # # If we're in some other state then it's a premature close,
- # # and we'll end up in h11.ERROR.
- # pass
-
- if self.writer is not None:
- self.writer.close()
+ self.writer.close()