469 lines
15 KiB
C++
469 lines
15 KiB
C++
// Copyright 2021 The Chromium Authors
|
|
// Use of this source code is governed by a BSD-style license that can be
|
|
// found in the LICENSE file.
|
|
|
|
#include "net/test/embedded_test_server/http2_connection.h"
|
|
|
|
#include <memory>
|
|
|
|
#include "base/functional/bind.h"
|
|
#include "base/functional/callback_helpers.h"
|
|
#include "base/memory/raw_ptr.h"
|
|
#include "base/memory/raw_ref.h"
|
|
#include "base/strings/strcat.h"
|
|
#include "base/strings/string_piece.h"
|
|
#include "base/task/sequenced_task_runner.h"
|
|
#include "net/http/http_response_headers.h"
|
|
#include "net/http/http_status_code.h"
|
|
#include "net/socket/stream_socket.h"
|
|
#include "net/ssl/ssl_info.h"
|
|
#include "net/test/embedded_test_server/embedded_test_server.h"
|
|
#include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
|
|
|
|
namespace net {
|
|
|
|
namespace {
|
|
|
|
std::vector<http2::adapter::Header> GenerateHeaders(HttpStatusCode status,
|
|
base::StringPairs headers) {
|
|
std::vector<http2::adapter::Header> response_vector;
|
|
response_vector.emplace_back(
|
|
http2::adapter::HeaderRep(std::string(":status")),
|
|
http2::adapter::HeaderRep(base::NumberToString(status)));
|
|
for (const auto& header : headers) {
|
|
// Connection (and related) headers are considered malformed and will
|
|
// result in a client error
|
|
if (base::EqualsCaseInsensitiveASCII(header.first, "connection"))
|
|
continue;
|
|
response_vector.emplace_back(
|
|
http2::adapter::HeaderRep(base::ToLowerASCII(header.first)),
|
|
http2::adapter::HeaderRep(header.second));
|
|
}
|
|
|
|
return response_vector;
|
|
}
|
|
|
|
} // namespace
|
|
|
|
namespace test_server {
|
|
|
|
class Http2Connection::DataFrameSource
|
|
: public http2::adapter::DataFrameSource {
|
|
public:
|
|
explicit DataFrameSource(Http2Connection* connection,
|
|
const StreamId& stream_id)
|
|
: connection_(connection), stream_id_(stream_id) {}
|
|
~DataFrameSource() override = default;
|
|
DataFrameSource(const DataFrameSource&) = delete;
|
|
DataFrameSource& operator=(const DataFrameSource&) = delete;
|
|
|
|
std::pair<int64_t, bool> SelectPayloadLength(size_t max_length) override {
|
|
if (chunks_.empty())
|
|
return {kBlocked, last_frame_};
|
|
|
|
bool finished = (chunks_.size() <= 1) &&
|
|
(chunks_.front().size() <= max_length) && last_frame_;
|
|
|
|
return {std::min(chunks_.front().size(), max_length), finished};
|
|
}
|
|
|
|
bool Send(absl::string_view frame_header, size_t payload_length) override {
|
|
std::string concatenated =
|
|
base::StrCat({frame_header, chunks_.front().substr(0, payload_length)});
|
|
const int64_t result = connection_->OnReadyToSend(concatenated);
|
|
// Write encountered error.
|
|
if (result < 0) {
|
|
connection_->OnConnectionError(ConnectionError::kSendError);
|
|
return false;
|
|
}
|
|
|
|
// Write blocked.
|
|
if (result == 0) {
|
|
connection_->blocked_streams_.insert(*stream_id_);
|
|
return false;
|
|
}
|
|
|
|
if (static_cast<const size_t>(result) < concatenated.size()) {
|
|
// Probably need to handle this better within this test class.
|
|
QUICHE_LOG(DFATAL)
|
|
<< "DATA frame not fully flushed. Connection will be corrupt!";
|
|
connection_->OnConnectionError(ConnectionError::kSendError);
|
|
return false;
|
|
}
|
|
|
|
chunks_.front().erase(0, payload_length);
|
|
|
|
if (chunks_.front().empty())
|
|
chunks_.pop();
|
|
|
|
if (chunks_.empty() && send_completion_callback_) {
|
|
std::move(send_completion_callback_).Run();
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
bool send_fin() const override { return true; }
|
|
|
|
void AddChunk(std::string chunk) { chunks_.push(std::move(chunk)); }
|
|
void set_last_frame(bool last_frame) { last_frame_ = last_frame; }
|
|
void SetSendCompletionCallback(base::OnceClosure callback) {
|
|
send_completion_callback_ = std::move(callback);
|
|
}
|
|
|
|
private:
|
|
const raw_ptr<Http2Connection> connection_;
|
|
const raw_ref<const StreamId, DanglingUntriaged> stream_id_;
|
|
std::queue<std::string> chunks_;
|
|
bool last_frame_ = false;
|
|
base::OnceClosure send_completion_callback_;
|
|
};
|
|
|
|
// Corresponds to an HTTP/2 stream
|
|
class Http2Connection::ResponseDelegate : public HttpResponseDelegate {
|
|
public:
|
|
ResponseDelegate(Http2Connection* connection, StreamId stream_id)
|
|
: stream_id_(stream_id), connection_(connection) {}
|
|
~ResponseDelegate() override = default;
|
|
ResponseDelegate(const ResponseDelegate&) = delete;
|
|
ResponseDelegate& operator=(const ResponseDelegate&) = delete;
|
|
|
|
void AddResponse(std::unique_ptr<HttpResponse> response) override {
|
|
responses_.push_back(std::move(response));
|
|
}
|
|
|
|
void SendResponseHeaders(HttpStatusCode status,
|
|
const std::string& status_reason,
|
|
const base::StringPairs& headers) override {
|
|
std::unique_ptr<DataFrameSource> data_frame =
|
|
std::make_unique<DataFrameSource>(connection_, stream_id_);
|
|
data_frame_ = data_frame.get();
|
|
connection_->adapter()->SubmitResponse(
|
|
stream_id_, GenerateHeaders(status, headers), std::move(data_frame));
|
|
connection_->SendIfNotProcessing();
|
|
}
|
|
|
|
void SendRawResponseHeaders(const std::string& headers) override {
|
|
scoped_refptr<HttpResponseHeaders> parsed_headers =
|
|
HttpResponseHeaders::TryToCreate(headers);
|
|
if (parsed_headers->response_code() == 0) {
|
|
connection_->OnConnectionError(ConnectionError::kParseError);
|
|
LOG(ERROR) << "raw headers could not be parsed";
|
|
}
|
|
base::StringPairs header_pairs;
|
|
size_t iter = 0;
|
|
std::string key, value;
|
|
while (parsed_headers->EnumerateHeaderLines(&iter, &key, &value))
|
|
header_pairs.emplace_back(key, value);
|
|
SendResponseHeaders(
|
|
static_cast<HttpStatusCode>(parsed_headers->response_code()),
|
|
/*status_reason=*/"", header_pairs);
|
|
}
|
|
|
|
void SendContents(const std::string& contents,
|
|
base::OnceClosure callback) override {
|
|
DCHECK(data_frame_);
|
|
data_frame_->AddChunk(contents);
|
|
data_frame_->SetSendCompletionCallback(std::move(callback));
|
|
connection_->adapter()->ResumeStream(stream_id_);
|
|
connection_->SendIfNotProcessing();
|
|
}
|
|
|
|
void FinishResponse() override {
|
|
data_frame_->set_last_frame(true);
|
|
connection_->adapter()->ResumeStream(stream_id_);
|
|
connection_->SendIfNotProcessing();
|
|
}
|
|
|
|
void SendContentsAndFinish(const std::string& contents) override {
|
|
data_frame_->set_last_frame(true);
|
|
SendContents(contents, base::DoNothing());
|
|
}
|
|
|
|
void SendHeadersContentAndFinish(HttpStatusCode status,
|
|
const std::string& status_reason,
|
|
const base::StringPairs& headers,
|
|
const std::string& contents) override {
|
|
std::unique_ptr<DataFrameSource> data_frame =
|
|
std::make_unique<DataFrameSource>(connection_, stream_id_);
|
|
data_frame->AddChunk(contents);
|
|
data_frame->set_last_frame(true);
|
|
connection_->adapter()->SubmitResponse(
|
|
stream_id_, GenerateHeaders(status, headers), std::move(data_frame));
|
|
connection_->SendIfNotProcessing();
|
|
}
|
|
base::WeakPtr<ResponseDelegate> GetWeakPtr() {
|
|
return weak_factory_.GetWeakPtr();
|
|
}
|
|
|
|
private:
|
|
std::vector<std::unique_ptr<HttpResponse>> responses_;
|
|
StreamId stream_id_;
|
|
const raw_ptr<Http2Connection> connection_;
|
|
raw_ptr<DataFrameSource> data_frame_;
|
|
base::WeakPtrFactory<ResponseDelegate> weak_factory_{this};
|
|
};
|
|
|
|
// Takes ownership of |socket| and creates a server-side oghttp2 adapter that
// reports its callbacks (and its outgoing bytes) back to this object.
// |connection_listener| may be null (HandleData() checks it before use).
Http2Connection::Http2Connection(
    std::unique_ptr<StreamSocket> socket,
    EmbeddedTestServerConnectionListener* connection_listener,
    EmbeddedTestServer* embedded_test_server)
    : socket_(std::move(socket)),
      connection_listener_(connection_listener),
      embedded_test_server_(embedded_test_server),
      read_buf_(base::MakeRefCounted<IOBufferWithSize>(4096)) {
  http2::adapter::OgHttp2Adapter::Options server_options;
  server_options.perspective = http2::adapter::Perspective::kServer;
  adapter_ = http2::adapter::OgHttp2Adapter::Create(*this, server_options);
}
|
|
|
|
// All cleanup is performed by the members' own destructors.
Http2Connection::~Http2Connection() = default;
|
|
|
|
// Starts consuming data from the socket by entering the ReadData() loop.
void Http2Connection::OnSocketReady() {
  ReadData();
}
|
|
|
|
void Http2Connection::ReadData() {
|
|
while (true) {
|
|
int rv = socket_->Read(
|
|
read_buf_.get(), read_buf_->size(),
|
|
base::BindOnce(&Http2Connection::OnDataRead, base::Unretained(this)));
|
|
if (rv == ERR_IO_PENDING)
|
|
return;
|
|
if (!HandleData(rv))
|
|
return;
|
|
}
|
|
}
|
|
|
|
void Http2Connection::OnDataRead(int rv) {
|
|
if (HandleData(rv))
|
|
ReadData();
|
|
}
|
|
|
|
bool Http2Connection::HandleData(int rv) {
|
|
if (rv <= 0) {
|
|
embedded_test_server_->RemoveConnection(this);
|
|
return false;
|
|
}
|
|
|
|
if (connection_listener_)
|
|
connection_listener_->ReadFromSocket(*socket_, rv);
|
|
|
|
absl::string_view remaining_buffer(read_buf_->data(), rv);
|
|
while (!remaining_buffer.empty()) {
|
|
int result = adapter_->ProcessBytes(remaining_buffer);
|
|
if (result < 0)
|
|
return false;
|
|
remaining_buffer = remaining_buffer.substr(result);
|
|
}
|
|
|
|
// Any frames and data sources will be queued up and sent all at once below
|
|
DCHECK(!processing_responses_);
|
|
processing_responses_ = true;
|
|
while (!ready_streams_.empty()) {
|
|
StreamId stream_id = ready_streams_.front();
|
|
ready_streams_.pop();
|
|
auto delegate = std::make_unique<ResponseDelegate>(this, stream_id);
|
|
ResponseDelegate* delegate_ptr = delegate.get();
|
|
response_map_[stream_id] = std::move(delegate);
|
|
embedded_test_server_->HandleRequest(delegate_ptr->GetWeakPtr(),
|
|
std::move(request_map_[stream_id]));
|
|
request_map_.erase(stream_id);
|
|
}
|
|
adapter_->Send();
|
|
processing_responses_ = false;
|
|
return true;
|
|
}
|
|
|
|
// Returns the underlying socket without transferring ownership. Returns null
// once TakeSocket() has moved it out.
StreamSocket* Http2Connection::Socket() {
  return socket_.get();
}
|
|
|
|
// Transfers ownership of the socket to the caller. Socket() returns null
// afterwards.
std::unique_ptr<StreamSocket> Http2Connection::TakeSocket() {
  return std::move(socket_);
}
|
|
|
|
// Returns a weak reference to this connection through its HttpConnection
// interface. The reference is invalidated when this object is destroyed.
base::WeakPtr<HttpConnection> Http2Connection::GetWeakPtr() {
  return weak_factory_.GetWeakPtr();
}
|
|
|
|
// Adapter callback carrying serialized frames to write.
// Only one write buffer may be outstanding. While |write_buf_| is still being
// drained this reports kSendBlocked, and OnSendInternalDone() re-triggers
// sending once the buffer is empty. Otherwise the bytes are copied into a new
// buffer and reported as fully accepted, even if the socket write is still
// pending.
// NOTE(review): SendInternal() can remove (destroy) this connection on a
// synchronous write error while the adapter is still mid-Send().
int64_t Http2Connection::OnReadyToSend(absl::string_view serialized) {
  if (!write_buf_) {
    write_buf_ = base::MakeRefCounted<DrainableIOBuffer>(
        base::MakeRefCounted<StringIOBuffer>(std::string(serialized)),
        serialized.size());
    SendInternal();
    return serialized.size();
  }
  return kSendBlocked;
}
|
|
|
|
// Adapter callback for a closed stream. Releases all per-stream state: the
// response delegate, the accumulated request headers and any pending
// "blocked" marker. Without this, a long-lived connection accumulates an entry
// per stream. Always returns true.
bool Http2Connection::OnCloseStream(StreamId stream_id,
                                    http2::adapter::Http2ErrorCode error_code) {
  response_map_.erase(stream_id);
  header_map_.erase(stream_id);
  // A closed stream never needs to be resumed.
  blocked_streams_.erase(stream_id);
  return true;
}
|
|
|
|
void Http2Connection::SendInternal() {
|
|
DCHECK(socket_);
|
|
DCHECK(write_buf_);
|
|
while (write_buf_->BytesRemaining() > 0) {
|
|
int rv = socket_->Write(write_buf_.get(), write_buf_->BytesRemaining(),
|
|
base::BindOnce(&Http2Connection::OnSendInternalDone,
|
|
base::Unretained(this)),
|
|
TRAFFIC_ANNOTATION_FOR_TESTS);
|
|
if (rv == ERR_IO_PENDING)
|
|
return;
|
|
|
|
if (rv < 0) {
|
|
embedded_test_server_->RemoveConnection(this);
|
|
break;
|
|
}
|
|
|
|
write_buf_->DidConsume(rv);
|
|
}
|
|
write_buf_ = nullptr;
|
|
}
|
|
|
|
void Http2Connection::OnSendInternalDone(int rv) {
|
|
DCHECK(write_buf_);
|
|
if (rv < 0) {
|
|
embedded_test_server_->RemoveConnection(this);
|
|
write_buf_ = nullptr;
|
|
return;
|
|
}
|
|
write_buf_->DidConsume(rv);
|
|
|
|
SendInternal();
|
|
|
|
if (!write_buf_) {
|
|
// Now that writing is no longer blocked, any blocked streams can be
|
|
// resumed.
|
|
for (const auto& stream_id : blocked_streams_)
|
|
adapter_->ResumeStream(stream_id);
|
|
|
|
if (adapter_->want_write()) {
|
|
base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
|
|
FROM_HERE, base::BindOnce(&Http2Connection::SendIfNotProcessing,
|
|
weak_factory_.GetWeakPtr()));
|
|
}
|
|
}
|
|
}
|
|
|
|
void Http2Connection::SendIfNotProcessing() {
|
|
if (!processing_responses_) {
|
|
processing_responses_ = true;
|
|
adapter_->Send();
|
|
processing_responses_ = false;
|
|
}
|
|
}
|
|
|
|
// Adapter callback for one decoded header field. The field is accumulated
// into this stream's entry in |header_map_|. Always accepts the header.
http2::adapter::Http2VisitorInterface::OnHeaderResult
Http2Connection::OnHeaderForStream(http2::adapter::Http2StreamId stream_id,
                                   absl::string_view key,
                                   absl::string_view value) {
  // |key| and |value| are views that are not guaranteed to be NUL-terminated.
  // Copy them by length; converting data() to std::string would read up to
  // the next NUL byte, past the end of the view.
  header_map_[stream_id][std::string(key)] = std::string(value);
  return http2::adapter::Http2VisitorInterface::HEADER_OK;
}
|
|
|
|
// Adapter callback fired when a stream's header block is complete. Builds an
// HttpRequest from the accumulated headers and stores it in |request_map_|.
// Body data, if any, is appended later by OnDataForStream(). Returns true.
bool Http2Connection::OnEndHeadersForStream(
    http2::adapter::Http2StreamId stream_id) {
  // Work on a copy; the |header_map_| entry itself is left in place here.
  HttpRequest::HeaderMap header_map = header_map_[stream_id];
  auto request = std::make_unique<HttpRequest>();
  // TODO(crbug.com/1375303): Handle proxy cases.
  // operator[] on the local copy inserts empty values for missing
  // pseudo-headers; those entries then become part of request->headers.
  request->relative_url = header_map[":path"];
  // NOTE(review): ":authority" carries no scheme, so this GURL is likely
  // invalid; confirm whether any caller relies on base_url for HTTP/2.
  request->base_url = GURL(header_map[":authority"]);
  request->method_string = header_map[":method"];
  request->method = HttpRequestParser::GetMethodType(request->method_string);
  request->headers = std::move(header_map);

  request->has_content = false;

  // GetSSLInfo() must be called outside DCHECK(). DCHECK does not evaluate
  // its argument when DCHECKs are disabled, so release builds would never
  // populate |ssl_info|.
  SSLInfo ssl_info;
  const bool has_ssl_info = socket_->GetSSLInfo(&ssl_info);
  DCHECK(has_ssl_info);
  request->ssl_info = ssl_info;
  request_map_[stream_id] = std::move(request);

  return true;
}
|
|
|
|
// Adapter callback fired when the peer has finished sending on a stream.
// The stream is queued; HandleData() dispatches queued streams once the
// current read has been fully processed.
bool Http2Connection::OnEndStream(http2::adapter::Http2StreamId stream_id) {
  ready_streams_.emplace(stream_id);
  return true;
}
|
|
|
|
// Frame-start notifications need no handling here; each is accepted
// unconditionally.
bool Http2Connection::OnFrameHeader(StreamId /*stream_id*/,
                                    size_t /*length*/,
                                    uint8_t /*type*/,
                                    uint8_t /*flags*/) {
  return true;
}

bool Http2Connection::OnBeginHeadersForStream(StreamId /*stream_id*/) {
  return true;
}

bool Http2Connection::OnBeginDataForStream(StreamId /*stream_id*/,
                                           size_t /*payload_length*/) {
  return true;
}
|
|
|
|
// Adapter callback carrying request body bytes for |stream_id|. The bytes are
// appended to the pending HttpRequest and immediately marked as consumed on
// the adapter. Returns false (failure) if no request exists for the stream,
// i.e. DATA arrived before the headers.
bool Http2Connection::OnDataForStream(StreamId stream_id,
                                      absl::string_view data) {
  auto pending = request_map_.find(stream_id);
  if (pending == request_map_.end())
    return false;

  HttpRequest* const http_request = pending->second.get();
  http_request->has_content = true;
  http_request->content.append(data.data(), data.size());
  adapter_->MarkDataConsumedForStream(stream_id, data.size());
  return true;
}
|
|
|
|
// Adapter callback reporting the padding length of a DATA frame. Padding
// carries no payload but is included in HTTP/2 flow control, so it is marked
// as consumed right away. Always returns true.
bool Http2Connection::OnDataPaddingLength(StreamId stream_id,
                                          size_t padding_length) {
  adapter_->MarkDataConsumedForStream(stream_id, padding_length);
  return true;
}
|
|
|
|
bool Http2Connection::OnGoAway(StreamId last_accepted_stream_id,
|
|
http2::adapter::Http2ErrorCode error_code,
|
|
absl::string_view opaque_data) {
|
|
return true;
|
|
}
|
|
|
|
int Http2Connection::OnBeforeFrameSent(uint8_t frame_type,
|
|
StreamId stream_id,
|
|
size_t length,
|
|
uint8_t flags) {
|
|
return 0;
|
|
}
|
|
|
|
int Http2Connection::OnFrameSent(uint8_t frame_type,
|
|
StreamId stream_id,
|
|
size_t length,
|
|
uint8_t flags,
|
|
uint32_t error_code) {
|
|
return 0;
|
|
}
|
|
|
|
bool Http2Connection::OnInvalidFrame(StreamId stream_id,
|
|
InvalidFrameError error) {
|
|
return true;
|
|
}
|
|
|
|
bool Http2Connection::OnMetadataForStream(StreamId stream_id,
|
|
absl::string_view metadata) {
|
|
return true;
|
|
}
|
|
|
|
bool Http2Connection::OnMetadataEndForStream(StreamId stream_id) {
|
|
return true;
|
|
}
|
|
|
|
} // namespace test_server
|
|
|
|
} // namespace net
|