638 lines
21 KiB
Python
638 lines
21 KiB
Python
#!/usr/bin/env python3
|
|
# Copyright 2022 The Pigweed Authors
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
|
|
# use this file except in compliance with the License. You may obtain a copy of
|
|
# the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
# License for the specific language governing permissions and limitations under
|
|
# the License.
|
|
"""Proxy for transfer integration testing.
|
|
|
|
This module contains a proxy for transfer integration testing. It is capable
|
|
of introducing various link failures into the connection between the client and
|
|
server.
|
|
"""
|
|
|
|
import abc
|
|
import argparse
|
|
import asyncio
|
|
from enum import Enum
|
|
import logging
|
|
import random
|
|
import socket
|
|
import sys
|
|
import time
|
|
from typing import Any, Awaitable, Callable, Iterable, List, Optional
|
|
|
|
from google.protobuf import text_format
|
|
|
|
from pigweed.pw_rpc.internal import packet_pb2
|
|
from pigweed.pw_transfer import transfer_pb2
|
|
from pigweed.pw_transfer.integration_test import config_pb2
|
|
from pw_hdlc import decode
|
|
from pw_transfer.chunk import Chunk
|
|
|
|
# Logger name kept verbatim (spelling included); presumably external log
# configuration may key on it.
_LOG = logging.getLogger('pw_transfer_intergration_test_proxy')

# Requested size, in bytes, of the proxy's socket receive buffers.
#
# Keeping this as small as allowed minimizes buffering between the proxy and
# its clients: when rate limiting slows the proxy down, the client blocks and
# waits for the proxy to drain instead of letting OS buffers accumulate large
# amounts of data. That better reflects typical inter-device communication.
#
# The OS may choose not to honor the requested size exactly, so this is a
# hint. Even so, a small value shrinks buffering noticeably. Clients should
# also configure their sockets with a smaller send buffer for this to be
# effective.
_RECEIVE_BUFFER_SIZE = 2 * 1024  # 2 KiB
|
|
|
|
|
|
class Event(Enum):
    """Notifications exchanged between the two directions of a proxied link.

    An ``EventFilter`` in one direction's filter stack emits these when it
    observes the matching transfer chunk type; filters in the opposite
    direction react to them through their ``handle_event`` methods.
    """

    # A START chunk was observed.
    TRANSFER_START = 1
    # A PARAMETERS_RETRANSMIT chunk was observed.
    PARAMETERS_RETRANSMIT = 2
    # A PARAMETERS_CONTINUE chunk was observed.
    PARAMETERS_CONTINUE = 3
    # A START_ACK_CONFIRMATION chunk was observed.
    START_ACK_CONFIRMATION = 4
|
|
|
|
|
|
class Filter(abc.ABC):
    """Abstract stage in a pipeline that manipulates a stream of data.

    Concrete ``Filter``s model real-world link behavior such as data
    corruption, packet loss, packet reordering, rate limiting, or latency.

    Subclasses implement ``process``; whenever they want to emit data they
    await ``self.send_data(...)``. Instances are themselves awaitable
    callables, so one filter can serve as another filter's ``send_data``.
    """

    def __init__(self, send_data: Callable[[bytes], Awaitable[None]]):
        # Downstream sink that receives whatever this filter emits.
        self.send_data = send_data

    @abc.abstractmethod
    async def process(self, data: bytes) -> None:
        """Handles one piece of incoming data.

        An implementation may emit any amount of data (including none) by
        awaiting ``self.send_data()``.
        """

    async def __call__(self, data: bytes) -> None:
        # Calling a filter is shorthand for feeding it data.
        return await self.process(data)
|
|
|
|
|
|
class HdlcPacketizer(Filter):
    """Regroups a raw byte stream into whole HDLC frames.

    The proxy's ``SOCK_STREAM`` transport carries no framing, so a single
    read may hold part of a frame or several frames. Placing this filter
    ahead of frame-oriented filters means each downstream ``process`` call
    receives one whole frame.
    """

    def __init__(self, send_data: Callable[[bytes], Awaitable[None]]):
        super().__init__(send_data)
        # Kept on the instance so a frame split across reads can be
        # aggregated across successive process() calls.
        self.decoder = decode.FrameDecoder()

    async def process(self, data: bytes) -> None:
        frames = self.decoder.process(data)
        for decoded in frames:
            # Forward the frame's original encoded bytes, not just its payload.
            await self.send_data(decoded.raw_encoded)
|
|
|
|
|
|
class DataDropper(Filter):
    """A filter which randomly drops data.

    Each call to ``process()`` drops its entire ``data`` with probability
    ``rate`` and otherwise forwards it unchanged.

    Args:
        send_data: Downstream coroutine that receives forwarded data.
        name: Label used in log messages.
        rate: Drop probability. Values >= 1.0 drop everything; values
            <= 0.0 drop nothing.
        seed: Seed for the random number generator. If ``None``, the current
            time in nanoseconds is used. The chosen seed is logged so a run
            can be reproduced.
    """

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        rate: float,
        seed: Optional[int] = None,
    ):
        super().__init__(send_data)
        self._rate = rate
        self._name = name
        # Identity check for the "no seed given" sentinel (PEP 8).
        if seed is None:
            seed = time.time_ns()
        self._rng = random.Random(seed)
        _LOG.info(f'{name} DataDropper initialized with seed {seed}')

    async def process(self, data: bytes) -> None:
        if self._rng.uniform(0.0, 1.0) < self._rate:
            _LOG.info(f'{self._name} dropped {len(data)} bytes of data')
        else:
            await self.send_data(data)
|
|
|
|
|
|
class KeepDropQueue(Filter):
    """A filter which alternates between sending packets and dropping packets.

    A KeepDropQueue filter alternates between keeping packets and dropping
    packets, based on a keep/drop queue provided at creation. The queue is
    looped over unless a negative element is found. A negative number is
    effectively the same as a value of infinity.

    A zero element simply skips that phase. However, the queue must contain
    at least one non-zero element. Otherwise ``ValueError`` is raised, since
    the filter could never settle on keeping or dropping a packet.

    This filter is typically most practical when used with a packetizer so
    data can be dropped as distinct packets.

    Examples:

        keep_drop_queue = [3, 2]:
            Keeps 3 packets,
            Drops 2 packets,
            Keeps 3 packets,
            Drops 2 packets,
            ... [loops indefinitely]

        keep_drop_queue = [5, 99, 1, -1]:
            Keeps 5 packets,
            Drops 99 packets,
            Keeps 1 packet,
            Drops all further packets.

    Raises:
        ValueError: if ``keep_drop_queue`` is empty or contains only zeros.
    """

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        keep_drop_queue: Iterable[int],
    ):
        super().__init__(send_data)
        self._keep_drop_queue = list(keep_drop_queue)
        # Without this check, an empty queue fails below with an opaque
        # IndexError. An all-zero queue makes process() spin forever looking
        # for a non-zero count, blocking the event loop.
        if not any(self._keep_drop_queue):
            raise ValueError(
                f'{name} KeepDropQueue needs at least one non-zero entry, '
                f'got {self._keep_drop_queue}'
            )
        self._loop_idx = 0
        self._current_count = self._keep_drop_queue[0]
        # The first phase (index 0) is always a "keep" phase.
        self._keep = True
        self._name = name

    async def process(self, data: bytes) -> None:
        # Advance to the next non-zero phase, flipping keep/drop at each step.
        # This terminates because __init__ guarantees a non-zero entry.
        while self._current_count == 0:
            self._loop_idx += 1
            self._current_count = self._keep_drop_queue[
                self._loop_idx % len(self._keep_drop_queue)
            ]
            self._keep = not self._keep

        # Negative counts are never decremented, so that phase lasts forever.
        if self._current_count > 0:
            self._current_count -= 1

        if self._keep:
            await self.send_data(data)
            _LOG.info(f'{self._name} forwarded {len(data)} bytes of data')
        else:
            _LOG.info(f'{self._name} dropped {len(data)} bytes of data')
|
|
|
|
|
|
class RateLimiter(Filter):
    """A filter which limits transmission rate.

    Each ``process()`` call sleeps for ``len(data) / rate`` seconds before
    forwarding ``data``, so ``rate`` is expressed in bytes per second.

    Raises:
        ValueError: if ``rate`` is not a positive number. Otherwise zero would
            only fail later with ZeroDivisionError on the first packet, and a
            negative rate would silently disable limiting.
    """

    def __init__(
        self, send_data: Callable[[bytes], Awaitable[None]], rate: float
    ):
        super().__init__(send_data)
        # Written as `not (rate > 0)` so that NaN is rejected too.
        if not (rate > 0):
            raise ValueError(f'RateLimiter rate must be positive, got {rate}')
        self._rate = rate

    async def process(self, data: bytes) -> None:
        delay = len(data) / self._rate
        await asyncio.sleep(delay)
        await self.send_data(data)
|
|
|
|
|
|
class DataTransposer(Filter):
    """A filter which occasionally transposes two chunks of data.

    This filter transposes data at the specified rate. It does this by
    holding a chunk to transpose until another chunk arrives. The filter
    will not hold a chunk longer than ``timeout`` seconds.

    ``process()`` only enqueues data on an unbounded queue. A background task,
    created at construction time, does the forwarding. As a result,
    ``process()`` returns before the data has actually been sent.
    """

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        rate: float,
        timeout: float,
        seed: int,
    ):
        """Creates the filter and starts its background transpose task.

        Args:
            send_data: Downstream coroutine that receives forwarded data.
            name: Label used in log messages.
            rate: Probability that a chunk is held back and swapped with the
                chunk that follows it.
            timeout: Maximum seconds a held chunk waits for a successor
                before being sent in order.
            seed: Seed for the random number generator (logged for
                reproducibility).
        """
        super().__init__(send_data)
        self._name = name
        self._rate = rate
        self._timeout = timeout
        # Hand-off from process() to the background task; no maxsize, so
        # put() never waits.
        self._data_queue = asyncio.Queue()
        self._rng = random.Random(seed)
        # asyncio.create_task() requires a running event loop, so this filter
        # must be constructed from within a coroutine.
        self._transpose_task = asyncio.create_task(self._transpose_handler())

        _LOG.info(f'{name} DataTranspose initialized with seed {seed}')

    def __del__(self) -> None:
        # NOTE(review): the task's coroutine holds a reference to ``self``, so
        # this finalizer is unlikely to run while the task is still pending.
        # The task may therefore outlive the connection that created it.
        # Confirm whether an explicit shutdown hook is needed.
        _LOG.info(f'{self._name} cleaning up transpose task.')
        self._transpose_task.cancel()

    async def _transpose_handler(self) -> None:
        """Async task that handles the packet transposition and timeouts"""
        # Chunk waiting to be swapped with the next arrival, if any.
        held_data: Optional[bytes] = None
        while True:
            # Only use timeout if we have data held for transposition
            timeout = None if held_data is None else self._timeout
            try:
                data = await asyncio.wait_for(
                    self._data_queue.get(), timeout=timeout
                )

                if held_data is not None:
                    # If we have held data, send it out of order.
                    await self.send_data(data)
                    await self.send_data(held_data)
                    held_data = None
                else:
                    # Otherwise decide if we should transpose the current data.
                    if self._rng.uniform(0.0, 1.0) < self._rate:
                        _LOG.info(
                            f'{self._name} transposing {len(data)} bytes of data'
                        )
                        held_data = data
                    else:
                        await self.send_data(data)

            # NOTE(review): only timeouts are handled here. Any other exception
            # (e.g. from send_data) ends this task, after which data queued by
            # process() is never forwarded.
            except asyncio.TimeoutError:
                # No successor arrived in time; release the held chunk as-is.
                _LOG.info(f'{self._name} sending data in order due to timeout')
                await self.send_data(held_data)
                held_data = None

    async def process(self, data: bytes) -> None:
        # Queue data for processing by the transpose task.
        await self._data_queue.put(data)
|
|
|
|
|
|
class ServerFailure(Filter):
    """A filter to simulate the server stopping sending packets.

    ServerFailure takes a list of packet budgets. For each budget, it sends
    that many packets and then drops all subsequent packets until a
    TRANSFER_START event is received via ``handle_event``. This repeats for
    each element of ``packets_before_failure_list``. Once the list is
    exhausted, ServerFailure sends all packets.

    This filter should be instantiated in the same filter stack as an
    HdlcPacketizer so that EventFilter can decode complete packets.

    Args:
        send_data: Downstream coroutine that receives forwarded data.
        name: Label for this filter.
        packets_before_failure_list: Packet budgets, consumed in order. The
            sequence is copied, so the caller's object is never modified.
    """

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        packets_before_failure_list: Iterable[int],
    ):
        super().__init__(send_data)
        self._name = name
        # Snapshot the budgets. Consuming them must not mutate the caller's
        # sequence: the proxy passes the repeated field of a config object
        # shared by every connection, so mutating it would leave later
        # connections with a partially consumed list.
        self._remaining_budgets = iter(list(packets_before_failure_list))
        self._packets_before_failure: Optional[int] = None
        # NOTE(review): the first budget is consumed here, and the first
        # TRANSFER_START event then advances past it. Confirm that the first
        # entry is meant to cover only traffic seen before the first START.
        self.advance_packets_before_failure()

    def advance_packets_before_failure(self) -> None:
        """Switches to the next budget, or to "never fail" once none remain."""
        self._packets_before_failure = next(self._remaining_budgets, None)

    async def process(self, data: bytes) -> None:
        if self._packets_before_failure is None:
            # Budgets exhausted: relay everything.
            await self.send_data(data)
        elif self._packets_before_failure > 0:
            self._packets_before_failure -= 1
            await self.send_data(data)
        # Otherwise the current budget is spent: drop the packet.

    def handle_event(self, event: Event) -> None:
        """Starts the next budget when a new transfer begins."""
        if event is Event.TRANSFER_START:
            self.advance_packets_before_failure()
|
|
|
|
|
|
class WindowPacketDropper(Filter):
    """Drops the same DATA-chunk position in every transfer window.

    The filter counts transfer DATA chunks within the current window. It drops
    the chunk whose zero-based position equals ``window_packet_to_drop`` and
    forwards the rest. The count restarts at zero whenever a
    PARAMETERS_RETRANSMIT, PARAMETERS_CONTINUE, or START_ACK_CONFIRMATION
    event arrives, so the drop repeats in every window indefinitely.

    Anything that is not a transfer DATA chunk is always forwarded and not
    counted. This includes data that fails to decode, such as text logs.

    This filter should be instantiated in the same filter stack as an
    HdlcPacketizer so that EventFilter can decode complete packets.
    """

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        window_packet_to_drop: int,
    ):
        super().__init__(send_data)
        self._name = name
        self._relay_packets = True
        self._window_packet_to_drop = window_packet_to_drop
        # Zero-based index of the next DATA chunk within the current window.
        self._window_packet = 0

    @staticmethod
    def _is_data_chunk(data: bytes) -> bool:
        """True if ``data`` decodes as a transfer DATA chunk."""
        try:
            return _extract_transfer_chunk(data).type is Chunk.Type.DATA
        except Exception:
            # Invalid / non-chunk data (e.g. text logs) is never counted.
            return False

    async def process(self, data: bytes) -> None:
        if not self._is_data_chunk(data):
            await self.send_data(data)
            return

        if self._window_packet != self._window_packet_to_drop:
            await self.send_data(data)
        # Advance only after the send completes: a window reset delivered
        # while the send is in flight must see the pre-increment count.
        self._window_packet += 1

    def handle_event(self, event: Event) -> None:
        window_boundaries = (
            Event.PARAMETERS_RETRANSMIT,
            Event.PARAMETERS_CONTINUE,
            Event.START_ACK_CONFIRMATION,
        )
        if event in window_boundaries:
            self._window_packet = 0
|
|
|
|
|
|
class EventFilter(Filter):
    """A filter that inspects packets and sends events to other filters.

    Every piece of data is forwarded unchanged. If the data decodes as a
    transfer chunk of type START, START_ACK_CONFIRMATION,
    PARAMETERS_RETRANSMIT, or PARAMETERS_CONTINUE, the corresponding
    ``Event`` is first put on ``event_queue``. Data that does not decode as a
    transfer chunk is forwarded without raising an event.

    This filter should be instantiated in the same filter stack as an
    HdlcPacketizer so that it can decode complete packets.
    """

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        event_queue: asyncio.Queue,
    ):
        super().__init__(send_data)
        self._name = name
        self._queue = event_queue

    @staticmethod
    def _event_for_chunk_type(chunk_type) -> Optional[Event]:
        """Maps a transfer chunk type to the Event it signals, if any."""
        if chunk_type is Chunk.Type.START:
            return Event.TRANSFER_START
        if chunk_type is Chunk.Type.START_ACK_CONFIRMATION:
            return Event.START_ACK_CONFIRMATION
        if chunk_type is Chunk.Type.PARAMETERS_RETRANSMIT:
            return Event.PARAMETERS_RETRANSMIT
        if chunk_type is Chunk.Type.PARAMETERS_CONTINUE:
            return Event.PARAMETERS_CONTINUE
        return None

    async def process(self, data: bytes) -> None:
        # Only decoding is best-effort. Errors from the queue or from sending,
        # including task cancellation, must propagate rather than be swallowed
        # by a catch-all handler.
        try:
            chunk_type = _extract_transfer_chunk(data).type
        except Exception:
            # Invalid / non-chunk data (e.g. text logs): no event.
            chunk_type = None

        event = self._event_for_chunk_type(chunk_type)
        if event is not None:
            await self._queue.put(event)

        await self.send_data(data)
|
|
|
|
|
|
def _extract_transfer_chunk(data: bytes) -> Chunk:
    """Decodes the first HDLC frame in ``data`` into a transfer ``Chunk``.

    The frame's payload is parsed as an RPC packet, whose payload is in turn
    parsed as a transfer chunk. Only the first frame decoded from ``data`` is
    examined.

    NOTE(review): decoded frames are not checked for validity before parsing;
    confirm whether corrupt frames should be rejected here.

    Raises:
        ValueError: if no frame could be decoded from ``data``. Errors raised
            while parsing the protobuf messages propagate unchanged.
    """
    first_frame = next(iter(decode.FrameDecoder().process(data)), None)
    if first_frame is None:
        raise ValueError("Invalid transfer frame")

    rpc_packet = packet_pb2.RpcPacket()
    rpc_packet.ParseFromString(first_frame.data)

    chunk_message = transfer_pb2.Chunk()
    chunk_message.ParseFromString(rpc_packet.payload)
    return Chunk.from_message(chunk_message)
|
|
|
|
|
|
async def _handle_simplex_events(
    event_queue: asyncio.Queue, handlers: List[Callable[[Event], None]]
):
    """Dispatches every event from ``event_queue`` to each handler in order.

    Handlers are synchronous callables. This coroutine never returns
    normally; it ends only when cancelled or when a handler raises.
    """
    while True:
        next_event = await event_queue.get()
        for dispatch in handlers:
            dispatch(next_event)
|
|
|
|
|
|
async def _handle_simplex_connection(
    name: str,
    filter_stack_config: List[config_pb2.FilterConfig],
    reader: asyncio.StreamReader,
    writer: asyncio.StreamWriter,
    inbound_event_queue: asyncio.Queue,
    outbound_event_queue: asyncio.Queue,
) -> None:
    """Handle a single direction of a bidirectional connection between
    server and client.

    Data read from ``reader`` passes through a filter stack built from
    ``filter_stack_config`` and is then written to ``writer``.

    - The first config entry is the first filter to see incoming data.
    - An EventFilter always sits at the bottom of the stack, just before the
      writer, and publishes transfer events to ``outbound_event_queue``.
    - Events arriving on ``inbound_event_queue`` are dispatched to the
      ``handle_event`` methods of any ServerFailure / WindowPacketDropper
      filters in this stack.

    Returns when ``reader`` reaches end of stream. Exits the process if the
    config names an unknown filter.
    """

    async def send(data: bytes):
        writer.write(data)
        await writer.drain()

    filter_stack = EventFilter(send, name, outbound_event_queue)

    event_handlers: List[Callable[[Event], None]] = []

    # Build the filter stack from the bottom up, so the first config entry
    # ends up outermost.
    for config in reversed(filter_stack_config):
        filter_name = config.WhichOneof("filter")
        if filter_name == "hdlc_packetizer":
            filter_stack = HdlcPacketizer(filter_stack)
        elif filter_name == "data_dropper":
            data_dropper = config.data_dropper
            filter_stack = DataDropper(
                filter_stack, name, data_dropper.rate, data_dropper.seed
            )
        elif filter_name == "rate_limiter":
            filter_stack = RateLimiter(filter_stack, config.rate_limiter.rate)
        elif filter_name == "data_transposer":
            transposer = config.data_transposer
            filter_stack = DataTransposer(
                filter_stack,
                name,
                transposer.rate,
                transposer.timeout,
                transposer.seed,
            )
        elif filter_name == "server_failure":
            server_failure = config.server_failure
            filter_stack = ServerFailure(
                filter_stack, name, server_failure.packets_before_failure
            )
            event_handlers.append(filter_stack.handle_event)
        elif filter_name == "keep_drop_queue":
            keep_drop_queue = config.keep_drop_queue
            filter_stack = KeepDropQueue(
                filter_stack, name, keep_drop_queue.keep_drop_queue
            )
        elif filter_name == "window_packet_dropper":
            window_packet_dropper = config.window_packet_dropper
            filter_stack = WindowPacketDropper(
                filter_stack, name, window_packet_dropper.window_packet_to_drop
            )
            event_handlers.append(filter_stack.handle_event)
        else:
            sys.exit(f'Unknown filter {filter_name}')

    event_task = asyncio.create_task(
        _handle_simplex_events(inbound_event_queue, event_handlers)
    )

    # The event task waits on the inbound queue forever. Cancel it whenever
    # this direction ends (EOF, error, or cancellation) so it doesn't leak.
    try:
        while True:
            # Arbitrarily chosen "page sized" read.
            data = await reader.read(4096)

            # An empty data indicates that the connection is closed.
            if not data:
                _LOG.info(f'{name} connection closed.')
                return

            await filter_stack.process(data)
    finally:
        event_task.cancel()
|
|
|
|
|
|
async def _handle_connection(
    server_port: int,
    config: config_pb2.ProxyConfig,
    client_reader: asyncio.StreamReader,
    client_writer: asyncio.StreamWriter,
) -> None:
    """Handle a connection between server and client.

    Opens a new connection to ``localhost:server_port`` for this client and
    runs two simplex handlers, one per direction, each with its own filter
    stack from ``config``. Two queues carry events between the directions.
    When either direction finishes, the other is cancelled and both
    connections are closed.

    If the server cannot be reached, the error is logged and the client
    connection is closed.
    """

    client_addr = client_writer.get_extra_info('peername')
    _LOG.info(f'New client connection from {client_addr}')

    # Open a new connection to the server for each client connection.
    try:
        server_reader, server_writer = await asyncio.open_connection(
            'localhost', server_port
        )
    except OSError as err:
        _LOG.error(
            f'Failed to connect to server on port {server_port}: {err}'
        )
        client_writer.close()
        return
    _LOG.info('New connection opened to server')

    # Queues for the simplex connections to pass events to each other.
    server_event_queue = asyncio.Queue()
    client_event_queue = asyncio.Queue()

    # Instantiate two simplex handlers, one for each direction of the
    # connection.
    tasks = [
        asyncio.create_task(
            _handle_simplex_connection(
                "client",
                config.client_filter_stack,
                client_reader,
                server_writer,
                server_event_queue,
                client_event_queue,
            )
        ),
        asyncio.create_task(
            _handle_simplex_connection(
                "server",
                config.server_filter_stack,
                server_reader,
                client_writer,
                client_event_queue,
                server_event_queue,
            )
        ),
    ]

    try:
        done, _ = await asyncio.wait(
            tasks, return_when=asyncio.FIRST_COMPLETED
        )
        # Retrieve results so a failed handler is logged here, instead of
        # surfacing later as a "Task exception was never retrieved" warning.
        for task in done:
            if task.cancelled():
                continue
            exc = task.exception()
            if exc is not None:
                _LOG.error('Connection handler failed', exc_info=exc)
    finally:
        # When one side terminates the connection (or this handler is itself
        # cancelled), also terminate the other side. Cancelling a finished
        # task is a no-op.
        for task in tasks:
            task.cancel()

        for stream in [client_writer, server_writer]:
            stream.close()
|
|
|
|
|
|
def _parse_args() -> argparse.Namespace:
    """Parses the proxy's command-line flags from ``sys.argv``.

    Returns:
        A namespace with integer ``server_port`` and ``client_port``.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    port_flags = (
        (
            '--server-port',
            'Port of the integration test server. '
            'The proxy will forward connections to this port',
        ),
        (
            '--client-port',
            'Port on which to listen for connections from '
            'integration test client.',
        ),
    )
    for flag, help_text in port_flags:
        parser.add_argument(flag, type=int, required=True, help=help_text)

    return parser.parse_args()
|
|
|
|
|
|
def _init_logging(level: int) -> None:
    """Sends this module's log records to stderr, filtered at ``level``.

    The logger itself is opened up to DEBUG, so the stderr handler's
    ``level`` is the effective threshold. Each call attaches another handler.
    """
    formatter = logging.Formatter(
        fmt='%(asctime)s.%(msecs)03d-%(levelname)s: %(message)s',
        datefmt='%H:%M:%S',
    )

    # logging.StreamHandler() with no arguments writes to sys.stderr.
    stderr_handler = logging.StreamHandler()
    stderr_handler.setLevel(level)
    stderr_handler.setFormatter(formatter)

    _LOG.setLevel(logging.DEBUG)
    _LOG.addHandler(stderr_handler)
|
|
|
|
|
|
async def _main(server_port: int, client_port: int) -> None:
    """Runs the proxy until the process is terminated.

    Reads a text-format ``ProxyConfig`` from stdin, listens on
    ``localhost:client_port``, and forwards every accepted client connection
    to ``localhost:server_port`` through the configured filter stacks.
    """
    _init_logging(logging.DEBUG)

    # The config arrives on stdin and is read synchronously before serving.
    config = text_format.Parse(
        sys.stdin.buffer.read(), config_pb2.ProxyConfig()
    )

    # Create the listening socket by hand so its receive buffer can be shrunk
    # before it is handed to asyncio.
    listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    listen_sock.setsockopt(
        socket.SOL_SOCKET, socket.SO_RCVBUF, _RECEIVE_BUFFER_SIZE
    )
    listen_sock.bind(('localhost', client_port))

    def on_client(reader, writer):
        # Returns a coroutine; asyncio.start_server schedules it as a task.
        return _handle_connection(server_port, config, reader, writer)

    server = await asyncio.start_server(
        on_client,
        limit=_RECEIVE_BUFFER_SIZE,
        sock=listen_sock,
    )

    listening_on = ', '.join(
        str(bound.getsockname()) for bound in server.sockets
    )
    _LOG.info(f'Listening for client connection on {listening_on}')

    async with server:
        await server.serve_forever()
|
|
|
|
|
|
if __name__ == '__main__':
    # Flag names (server_port, client_port) map one-to-one onto _main's
    # keyword parameters.
    _cli_args = vars(_parse_args())
    asyncio.run(_main(**_cli_args))
|