// Copyright 2021 The Pigweed Authors // // Licensed under the Apache License, Version 2.0 (the "License"); you may not // use this file except in compliance with the License. You may obtain a copy of // the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the // License for the specific language governing permissions and limitations under // the License. #include "pw_protobuf/encoder.h" #include #include #include #include #include "pw_assert/check.h" #include "pw_bytes/span.h" #include "pw_protobuf/internal/codegen.h" #include "pw_protobuf/serialized_size.h" #include "pw_protobuf/stream_decoder.h" #include "pw_protobuf/wire_format.h" #include "pw_span/span.h" #include "pw_status/status.h" #include "pw_status/try.h" #include "pw_stream/memory_stream.h" #include "pw_stream/stream.h" #include "pw_string/string.h" #include "pw_varint/varint.h" namespace pw::protobuf { using internal::VarintType; StreamEncoder StreamEncoder::GetNestedEncoder(uint32_t field_number, bool write_when_empty) { PW_CHECK(!nested_encoder_open()); PW_CHECK(ValidFieldNumber(field_number)); nested_field_number_ = field_number; // Pass the unused space of the scratch buffer to the nested encoder to use // as their scratch buffer. size_t key_size = varint::EncodedSize(FieldKey(field_number, WireType::kDelimited)); size_t reserved_size = key_size + config::kMaxVarintSize; size_t max_size = std::min(memory_writer_.ConservativeWriteLimit(), writer_.ConservativeWriteLimit()); // Account for reserved bytes. max_size = max_size > reserved_size ? max_size - reserved_size : 0; // Cap based on max varint size. max_size = std::min(varint::MaxValueInBytes(config::kMaxVarintSize), static_cast(max_size)); ByteSpan nested_buffer; if (max_size > 0) { nested_buffer = ByteSpan( memory_writer_.data() + reserved_size + memory_writer_.bytes_written(), max_size); } else { nested_buffer = ByteSpan(); } return StreamEncoder(*this, nested_buffer, write_when_empty); } void StreamEncoder::CloseEncoder() { // If this was an invalidated StreamEncoder which cannot be used, permit the // object to be cleanly destructed by doing nothing. if (nested_field_number_ == kFirstReservedNumber) { return; } PW_CHECK( !nested_encoder_open(), "Tried to destruct a proto encoder with an active submessage encoder"); if (parent_ != nullptr) { parent_->CloseNestedMessage(*this); } } void StreamEncoder::CloseNestedMessage(StreamEncoder& nested) { PW_DCHECK_PTR_EQ(nested.parent_, this, "CloseNestedMessage() called on the wrong Encoder parent"); // Make the nested encoder look like it has an open child to block writes for // the remainder of the object's life. nested.nested_field_number_ = kFirstReservedNumber; nested.parent_ = nullptr; // Temporarily cache the field number of the child so we can re-enable // writing to this encoder. uint32_t temp_field_number = nested_field_number_; nested_field_number_ = 0; // TODO(amontanez): If a submessage fails, we could optionally discard // it and continue happily. For now, we'll always invalidate the entire // encoder if a single submessage fails. status_.Update(nested.status_); if (!status_.ok()) { return; } if (varint::EncodedSize(nested.memory_writer_.bytes_written()) > config::kMaxVarintSize) { status_ = Status::OutOfRange(); return; } if (!nested.memory_writer_.bytes_written() && !nested.write_when_empty_) { return; } status_ = WriteLengthDelimitedField(temp_field_number, nested.memory_writer_.WrittenData()); } Status StreamEncoder::WriteVarintField(uint32_t field_number, uint64_t value) { PW_TRY(UpdateStatusForWrite( field_number, WireType::kVarint, varint::EncodedSize(value))); WriteVarint(FieldKey(field_number, WireType::kVarint)) .IgnoreError(); // TODO(b/242598609): Handle Status properly return WriteVarint(value); } Status StreamEncoder::WriteLengthDelimitedField(uint32_t field_number, ConstByteSpan data) { PW_TRY(UpdateStatusForWrite(field_number, WireType::kDelimited, data.size())); status_.Update(WriteLengthDelimitedKeyAndLengthPrefix( field_number, data.size(), writer_)); PW_TRY(status_); if (Status status = writer_.Write(data); !status.ok()) { status_ = status; } return status_; } Status StreamEncoder::WriteLengthDelimitedFieldFromStream( uint32_t field_number, stream::Reader& bytes_reader, size_t num_bytes, ByteSpan stream_pipe_buffer) { PW_CHECK_UINT_GT( stream_pipe_buffer.size(), 0, "Transfer buffer cannot be 0 size"); PW_TRY(UpdateStatusForWrite(field_number, WireType::kDelimited, num_bytes)); status_.Update( WriteLengthDelimitedKeyAndLengthPrefix(field_number, num_bytes, writer_)); PW_TRY(status_); // Stream data from `bytes_reader` to `writer_`. // TODO(pwbug/468): move the following logic to pw_stream/copy.h at a later // time. for (size_t bytes_written = 0; bytes_written < num_bytes;) { const size_t chunk_size_bytes = std::min(num_bytes - bytes_written, stream_pipe_buffer.size_bytes()); const Result read_result = bytes_reader.Read(stream_pipe_buffer.data(), chunk_size_bytes); status_.Update(read_result.status()); PW_TRY(status_); status_.Update(writer_.Write(read_result.value())); PW_TRY(status_); bytes_written += read_result.value().size(); } return OkStatus(); } Status StreamEncoder::WriteFixed(uint32_t field_number, ConstByteSpan data) { WireType type = data.size() == sizeof(uint32_t) ? WireType::kFixed32 : WireType::kFixed64; PW_TRY(UpdateStatusForWrite(field_number, type, data.size())); WriteVarint(FieldKey(field_number, type)) .IgnoreError(); // TODO(b/242598609): Handle Status properly if (Status status = writer_.Write(data); !status.ok()) { status_ = status; } return status_; } Status StreamEncoder::WritePackedFixed(uint32_t field_number, span values, size_t elem_size) { if (values.empty()) { return status_; } PW_CHECK_NOTNULL(values.data()); PW_DCHECK(elem_size == sizeof(uint32_t) || elem_size == sizeof(uint64_t)); PW_TRY(UpdateStatusForWrite( field_number, WireType::kDelimited, values.size_bytes())); WriteVarint(FieldKey(field_number, WireType::kDelimited)) .IgnoreError(); // TODO(b/242598609): Handle Status properly WriteVarint(values.size_bytes()) .IgnoreError(); // TODO(b/242598609): Handle Status properly for (auto val_start = values.begin(); val_start != values.end(); val_start += elem_size) { // Allocates 8 bytes so both 4-byte and 8-byte types can be encoded as // little-endian for serialization. std::array data; if (endian::native == endian::little) { std::copy(val_start, val_start + elem_size, std::begin(data)); } else { std::reverse_copy(val_start, val_start + elem_size, std::begin(data)); } status_.Update(writer_.Write(span(data).first(elem_size))); PW_TRY(status_); } return status_; } Status StreamEncoder::UpdateStatusForWrite(uint32_t field_number, WireType type, size_t data_size) { PW_CHECK(!nested_encoder_open()); PW_TRY(status_); if (!ValidFieldNumber(field_number)) { return status_ = Status::InvalidArgument(); } const Result field_size = SizeOfField(field_number, type, data_size); status_.Update(field_size.status()); PW_TRY(status_); if (field_size.value() > writer_.ConservativeWriteLimit()) { status_ = Status::ResourceExhausted(); } return status_; } Status StreamEncoder::Write(span message, span table) { PW_CHECK(!nested_encoder_open()); PW_TRY(status_); for (const auto& field : table) { // Calculate the span of bytes corresponding to the structure field to // read from. const auto values = message.subspan(field.field_offset(), field.field_size()); PW_CHECK(values.begin() >= message.begin() && values.end() <= message.end()); // If the field is using callbacks, interpret the input field accordingly // and allow the caller to provide custom handling. if (field.use_callback()) { const Callback* callback = reinterpret_cast*>( values.data()); PW_TRY(callback->Encode(*this)); continue; } switch (field.wire_type()) { case WireType::kFixed64: case WireType::kFixed32: { // Fixed fields call WriteFixed() for singular case and // WritePackedFixed() for repeated fields. PW_CHECK(field.elem_size() == (field.wire_type() == WireType::kFixed32 ? sizeof(uint32_t) : sizeof(uint64_t)), "Mismatched message field type and size"); if (field.is_fixed_size()) { PW_CHECK(field.is_repeated(), "Non-repeated fixed size field"); if (static_cast( std::count(values.begin(), values.end(), std::byte{0})) < values.size()) { PW_TRY(WritePackedFixed( field.field_number(), values, field.elem_size())); } } else if (field.is_repeated()) { // The struct member for this field is a vector of a type // corresponding to the field element size. Cast to the correct // vector type so we're not performing type aliasing (except for // unsigned vs signed which is explicitly allowed). if (field.elem_size() == sizeof(uint64_t)) { const auto* vector = reinterpret_cast*>( values.data()); if (!vector->empty()) { PW_TRY(WritePackedFixed( field.field_number(), as_bytes(span(vector->data(), vector->size())), field.elem_size())); } } else if (field.elem_size() == sizeof(uint32_t)) { const auto* vector = reinterpret_cast*>( values.data()); if (!vector->empty()) { PW_TRY(WritePackedFixed( field.field_number(), as_bytes(span(vector->data(), vector->size())), field.elem_size())); } } } else if (field.is_optional()) { // The struct member for this field is a std::optional of a type // corresponding to the field element size. Cast to the correct // optional type so we're not performing type aliasing (except for // unsigned vs signed which is explicitly allowed), and write from // a temporary. if (field.elem_size() == sizeof(uint64_t)) { const auto* optional = reinterpret_cast*>(values.data()); if (optional->has_value()) { uint64_t value = optional->value(); PW_TRY( WriteFixed(field.field_number(), as_bytes(span(&value, 1)))); } } else if (field.elem_size() == sizeof(uint32_t)) { const auto* optional = reinterpret_cast*>(values.data()); if (optional->has_value()) { uint32_t value = optional->value(); PW_TRY( WriteFixed(field.field_number(), as_bytes(span(&value, 1)))); } } } else { PW_CHECK(values.size() == field.elem_size(), "Mismatched message field type and size"); if (static_cast( std::count(values.begin(), values.end(), std::byte{0})) < values.size()) { PW_TRY(WriteFixed(field.field_number(), values)); } } break; } case WireType::kVarint: { // Varint fields call WriteVarintField() for singular case and // WritePackedVarints() for repeated fields. PW_CHECK(field.elem_size() == sizeof(uint64_t) || field.elem_size() == sizeof(uint32_t) || field.elem_size() == sizeof(bool), "Mismatched message field type and size"); if (field.is_fixed_size()) { // The struct member for this field is an array of type corresponding // to the field element size. Cast to a span of the correct type over // the array so we're not performing type aliasing (except for // unsigned vs signed which is explicitly allowed). PW_CHECK(field.is_repeated(), "Non-repeated fixed size field"); if (static_cast( std::count(values.begin(), values.end(), std::byte{0})) == values.size()) { continue; } if (field.elem_size() == sizeof(uint64_t)) { PW_TRY(WritePackedVarints( field.field_number(), span(reinterpret_cast(values.data()), values.size() / field.elem_size()), field.varint_type())); } else if (field.elem_size() == sizeof(uint32_t)) { PW_TRY(WritePackedVarints( field.field_number(), span(reinterpret_cast(values.data()), values.size() / field.elem_size()), field.varint_type())); } else if (field.elem_size() == sizeof(bool)) { static_assert(sizeof(bool) == sizeof(uint8_t), "bool must be same size as uint8_t"); PW_TRY(WritePackedVarints( field.field_number(), span(reinterpret_cast(values.data()), values.size() / field.elem_size()), field.varint_type())); } } else if (field.is_repeated()) { // The struct member for this field is a vector of a type // corresponding to the field element size. Cast to the correct // vector type so we're not performing type aliasing (except for // unsigned vs signed which is explicitly allowed). if (field.elem_size() == sizeof(uint64_t)) { const auto* vector = reinterpret_cast*>( values.data()); if (!vector->empty()) { PW_TRY(WritePackedVarints(field.field_number(), span(vector->data(), vector->size()), field.varint_type())); } } else if (field.elem_size() == sizeof(uint32_t)) { const auto* vector = reinterpret_cast*>( values.data()); if (!vector->empty()) { PW_TRY(WritePackedVarints(field.field_number(), span(vector->data(), vector->size()), field.varint_type())); } } else if (field.elem_size() == sizeof(bool)) { static_assert(sizeof(bool) == sizeof(uint8_t), "bool must be same size as uint8_t"); const auto* vector = reinterpret_cast*>( values.data()); if (!vector->empty()) { PW_TRY(WritePackedVarints(field.field_number(), span(vector->data(), vector->size()), field.varint_type())); } } } else if (field.is_optional()) { // The struct member for this field is a std::optional of a type // corresponding to the field element size. Cast to the correct // optional type so we're not performing type aliasing (except for // unsigned vs signed which is explicitly allowed), and write from // a temporary. uint64_t value = 0; if (field.elem_size() == sizeof(uint64_t)) { if (field.varint_type() == VarintType::kUnsigned) { const auto* optional = reinterpret_cast*>( values.data()); if (!optional->has_value()) { continue; } value = optional->value(); } else { const auto* optional = reinterpret_cast*>( values.data()); if (!optional->has_value()) { continue; } value = field.varint_type() == VarintType::kZigZag ? varint::ZigZagEncode(optional->value()) : optional->value(); } } else if (field.elem_size() == sizeof(uint32_t)) { if (field.varint_type() == VarintType::kUnsigned) { const auto* optional = reinterpret_cast*>( values.data()); if (!optional->has_value()) { continue; } value = optional->value(); } else { const auto* optional = reinterpret_cast*>( values.data()); if (!optional->has_value()) { continue; } value = field.varint_type() == VarintType::kZigZag ? varint::ZigZagEncode(optional->value()) : optional->value(); } } else if (field.elem_size() == sizeof(bool)) { const auto* optional = reinterpret_cast*>(values.data()); if (!optional->has_value()) { continue; } value = optional->value(); } PW_TRY(WriteVarintField(field.field_number(), value)); } else { // The struct member for this field is a scalar of a type // corresponding to the field element size. Cast to the correct // type to retrieve the value before passing to WriteVarintField() // so we're not performing type aliasing (except for unsigned vs // signed which is explicitly allowed). PW_CHECK(values.size() == field.elem_size(), "Mismatched message field type and size"); uint64_t value = 0; if (field.elem_size() == sizeof(uint64_t)) { if (field.varint_type() == VarintType::kZigZag) { value = varint::ZigZagEncode( *reinterpret_cast(values.data())); } else if (field.varint_type() == VarintType::kNormal) { value = *reinterpret_cast(values.data()); } else { value = *reinterpret_cast(values.data()); } if (!value) { continue; } } else if (field.elem_size() == sizeof(uint32_t)) { if (field.varint_type() == VarintType::kZigZag) { value = varint::ZigZagEncode( *reinterpret_cast(values.data())); } else if (field.varint_type() == VarintType::kNormal) { value = *reinterpret_cast(values.data()); } else { value = *reinterpret_cast(values.data()); } if (!value) { continue; } } else if (field.elem_size() == sizeof(bool)) { value = *reinterpret_cast(values.data()); if (!value) { continue; } } PW_TRY(WriteVarintField(field.field_number(), value)); } break; } case WireType::kDelimited: { // Delimited fields are always a singular case because of the // inability to cast to a generic vector with an element of a certain // size (we always need a type). PW_CHECK(!field.is_repeated(), "Repeated delimited messages always require a callback"); if (field.nested_message_fields()) { // Nested Message. Struct member is an embedded struct for the // nested field. Obtain a nested encoder and recursively call Write() // using the fields table pointer from this field. auto nested_encoder = GetNestedEncoder(field.field_number(), /*write_when_empty=*/false); PW_TRY(nested_encoder.Write(values, *field.nested_message_fields())); } else if (field.is_fixed_size()) { // Fixed-length bytes field. Struct member is a std::array. // Call WriteLengthDelimitedField() to output it to the stream. PW_CHECK(field.elem_size() == sizeof(std::byte), "Mismatched message field type and size"); if (static_cast( std::count(values.begin(), values.end(), std::byte{0})) < values.size()) { PW_TRY(WriteLengthDelimitedField(field.field_number(), values)); } } else { // bytes or string field with a maximum size. Struct member is // pw::Vector for bytes or pw::InlineString<> for string. // Use the contents as a span and call WriteLengthDelimitedField() to // output it to the stream. PW_CHECK(field.elem_size() == sizeof(std::byte), "Mismatched message field type and size"); if (field.is_string()) { PW_TRY(WriteStringOrBytes>( field.field_number(), values.data())); } else { PW_TRY(WriteStringOrBytes>( field.field_number(), values.data())); } } break; } } } return status_; } } // namespace pw::protobuf