125 lines
4.4 KiB
C++
125 lines
4.4 KiB
C++
/*
|
|
* Copyright (C) 2022 The Android Open Source Project
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#define LOG_TAG "FlatbufferModelBuilder"
|
|
|
|
#include "FlatbufferModelBuilder.h"
|
|
|
|
#include <LegacyUtils.h>
|
|
|
|
#include "FlatbufferModelBuilderUtils.h"
|
|
#include "operation_converters/OperationConverterResolver.h"
|
|
|
|
namespace android {
|
|
namespace nn {
|
|
|
|
// Aborts (via CHECK) unless `model` is non-null and passes flatbuffers verification
// against the bytes currently held by mBuilder.
void FlatbufferModelBuilder::verifyModel(const tflite::Model* model) {
    // The verifier is bounded by the builder's finished buffer, so `model` is expected to
    // point into that same buffer.
    const auto* bufferStart = mBuilder.GetBufferPointer();
    const auto bufferSize = mBuilder.GetSize();
    flatbuffers::Verifier verifier(bufferStart, bufferSize);

    CHECK(model != nullptr);
    const bool isValid = model->Verify(verifier);
    CHECK(isValid);
}
|
|
|
|
void FlatbufferModelBuilder::initializeBufferVector() {
|
|
mBufferVector.clear();
|
|
|
|
std::vector<uint8_t> emptyData;
|
|
auto emptyBuffer = tflite::CreateBufferDirect(mBuilder, &emptyData);
|
|
mBufferVector.push_back(emptyBuffer);
|
|
}
|
|
|
|
// Replaces the contents of mOpCodeIndexForOperationType with kNumberOfOperationTypes
// entries, each set to -1 (presumably "no operator code registered yet" for that type).
void FlatbufferModelBuilder::initializeOpCodeIndexForOperationType() {
    mOpCodeIndexForOperationType.assign(kNumberOfOperationTypes, -1);
}
|
|
|
|
std::vector<MetadataFlatbuffer> FlatbufferModelBuilder::createMetadataVector() {
|
|
std::vector<MetadataFlatbuffer> metadataVector;
|
|
for (uint32_t i = 0; i < mBufferVector.size(); i++) {
|
|
auto metadata = tflite::CreateMetadataDirect(mBuilder, std::to_string(i).c_str() /* name */,
|
|
i /* buffer */);
|
|
metadataVector.push_back(metadata);
|
|
}
|
|
return metadataVector;
|
|
}
|
|
|
|
// Serializes the model returned by makeModel() into mBuilder as a TFLite flatbuffer.
// Returns a pointer to the root tflite::Model inside mBuilder's buffer. That pointer is only
// valid while mBuilder still owns that buffer. Returns an error if subgraph conversion
// fails. Aborts (via verifyModel's CHECKs) if the finished buffer fails verification.
Result<const tflite::Model*> FlatbufferModelBuilder::createTfliteModel() {
    mModel = makeModel();

    // Reset per-model bookkeeping before anything is serialized.
    initializeBufferVector();
    mOpCodesVector.clear();
    initializeOpCodeIndexForOperationType();

    // Order matters. The subgraph contexts receive mOpCodesVector and mBufferVector, so
    // subgraphs are serialized first. Metadata (one entry per buffer) comes next, and only
    // then the root Model table that references all of them.
    std::vector<SubGraphFlatbuffer> subgraphs = NN_TRY(createSubGraphs());
    std::vector<MetadataFlatbuffer> metadata = createMetadataVector();

    constexpr uint32_t kTfliteSchemaVersion = 3;
    const ModelFlatbuffer root = tflite::CreateModelDirect(
            mBuilder, kTfliteSchemaVersion, &mOpCodesVector /* operator_codes */,
            &subgraphs /* subgraphs */, nullptr /* description */,
            &mBufferVector /* buffers */, nullptr /* metadata_buffer */,
            &metadata /* metadata */);
    mBuilder.Finish(root);

    const tflite::Model* const model = tflite::GetModel(mBuilder.GetBufferPointer());
    verifyModel(model);
    return model;
}
|
|
|
|
// Converts one NNAPI subgraph into a tflite SubGraph table serialized into mBuilder.
// Returns an error in any of these cases:
//   - an operand check rejects the subgraph (unspecified rank, or a dynamically shaped
//     subgraph output);
//   - no IOperationConverter is registered for an operation's type;
//   - a converter itself fails.
Result<SubGraphFlatbuffer> FlatbufferModelBuilder::createSubGraphFlatbuffer(
        const Model::Subgraph& subgraph) {
    // TFLite cannot represent operands of unspecified rank.
    NN_TRY(checkAllTensorOperandsHaveSpecifiedRank(subgraph.operands));
    // TFLite cannot represent dynamically shaped subgraph outputs.
    NN_TRY(checkNoSubgraphOutputOperandsHaveDynamicShape(subgraph.operands));

    SubGraphContext context(&mModel, &subgraph, &mBuilder, &mOpCodesVector,
                            &mOpCodeIndexForOperationType, &mBufferVector);

    // Convert operations in their original order. Each converter writes into `context`.
    for (const Operation& op : subgraph.operations) {
        const IOperationConverter* const opConverter =
                OperationConverterResolver::get()->findOperationConverter(op.type);
        NN_RET_CHECK(opConverter != nullptr)
                << "IOperationConverter not implemented for OperationType: " << op.type;
        NN_TRY(opConverter->convert(op, &context));
    }

    // Register the subgraph's inputs, then its outputs, by operand index.
    for (const uint32_t inputIndex : subgraph.inputIndexes) {
        context.addSubGraphInput(inputIndex);
    }
    for (const uint32_t outputIndex : subgraph.outputIndexes) {
        context.addSubGraphOutput(outputIndex);
    }

    return context.finish();
}
|
|
|
|
// Serializes the subgraphs of mModel. Only the main subgraph is supported. If mModel has any
// referenced subgraphs (control flow), this returns an error via NN_RET_CHECK.
// The result holds exactly one entry: the main subgraph.
Result<std::vector<SubGraphFlatbuffer>> FlatbufferModelBuilder::createSubGraphs() {
    NN_RET_CHECK(mModel.referenced.empty()) << "Control flow for multiple subgraphs not supported";

    const SubGraphFlatbuffer mainSubGraph = NN_TRY(createSubGraphFlatbuffer(mModel.main));
    return std::vector<SubGraphFlatbuffer>{mainSubGraph};
}
|
|
|
|
} // namespace nn
|
|
} // namespace android
|