// unplugged-system/packages/modules/OnDevicePersonalization/federatedcompute/proto/plan.proto
//
// 1378 lines
// 56 KiB
// Protocol Buffer

/*
* Copyright (C) 2023 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
syntax = "proto3";
package com.android.federatedcompute.proto;
import "google/protobuf/any.proto";
import "tensorflow/core/framework/tensor.proto";
import "tensorflow/core/framework/tensor_shape.proto";
import "tensorflow/core/framework/types.proto";
import "tensorflow/core/protobuf/saver.proto";
import "tensorflow/core/protobuf/struct.proto";
option java_package = "com.android.federatedcompute.proto";
option java_multiple_files = true;
option java_outer_classname = "PlanProto";
// Primitives
// ===========
// Represents an operation to save or restore from a checkpoint. Some
// instances of this message may only be used for restore or only for save,
// while others are used in both directions; this is documented alongside
// each usage.
//
// This op has four essential uses:
//   1. Read and apply a checkpoint.
//   2. Write a checkpoint.
//   3. Read and apply from an aggregated side channel.
//   4. Write to a side channel (grouped with writing a checkpoint).
// We should consider splitting this into four separate messages.
message CheckpointOp {
  // An optional standard saver def. If not provided, only the op(s) below
  // will be executed. This must be a version 1 SaverDef.
  tensorflow.SaverDef saver_def = 1;

  // An optional operation to run before the saver_def is executed for
  // restore.
  string before_restore_op = 2;

  // An optional operation to run after the saver_def has been executed for
  // restore. If side_channel_tensors are provided, then they should be
  // provided in a feed_dict to this op.
  //
  // NOTE(review): the `side_channel_tensors` comment below (and the
  // `SideChannel.restore_name` comment) say side-channel values are fed to
  // `before_restore_op` instead; confirm which op actually receives them.
  string after_restore_op = 3;

  // An optional operation to run before the saver_def will be executed for
  // save.
  string before_save_op = 4;

  // An optional operation to run after the saver_def has been executed for
  // save. If there are side_channel_tensors, this op should be run after the
  // side_channel_tensors have been fetched.
  string after_save_op = 5;

  // In addition to being saved to and restored from a checkpoint, tensors
  // can also be saved and restored via a side channel. The keys in this map
  // are the names of the tensors transmitted by the side channel. These (key)
  // tensors should be read just before saving via the SaverDef and used by
  // the code that handles the side channel. Any variables provided this way
  // should NOT be saved by the SaverDef.
  //
  // For restoring, variables provided by the side channel are restored
  // differently from those in a checkpoint: they should be restored by
  // calling the before_restore_op with a feed dict whose keys are the
  // restore_names in the SideChannel and whose values are the values to be
  // restored.
  map<string, SideChannel> side_channel_tensors = 6;

  // An optional name of a tensor into which a unique token for the current
  // session should be written.
  //
  // This session identifier allows TensorFlow ops such as `ServeSlices` or
  // `ExternalDataset` to refer to callbacks and other session-global objects
  // registered before running the session.
  string session_token_tensor_name = 7;
}
// A channel, separate from the checkpoint, through which a tensor is
// transmitted and later restored (see `CheckpointOp.side_channel_tensors`).
message SideChannel {
  // A side channel whose variables are processed via SecureAggregation.
  // This side channel implements aggregation via a sum over a set of
  // clients, so the restored tensor will be the sum of multiple clients'
  // inputs into the side channel. Hence it is restored during the
  // read_aggregate_update restore, not the per-client read_update restore.
  message SecureAggregand {
    // A single dimension of the aggregand.
    message Dimension {
      // The size of this dimension.
      int64 size = 1;
    }

    // Dimensions of the aggregand. These are used by the secure aggregation
    // protocol in its early rounds; they are not redundant info that could
    // simply be obtained by reading the dimensions of the tensor itself.
    repeated Dimension dimension = 3;

    // The data type anticipated by the server-side graph.
    tensorflow.DataType dtype = 4;

    // SecureAggregation will compute the sum modulo this modulus.
    message FixedModulus {
      uint64 modulus = 1;
    }

    // SecureAggregation will, for each shard, compute the sum modulo m, with
    // m at least (1 + shard_size * (base_modulus - 1)), and then aggregate
    // the shard results with non-modular addition. Here, shard_size is the
    // number of clients in the shard.
    //
    // Note that the modulus for each shard will be greater than the largest
    // possible (non-modular) sum of the inputs to that shard. That is,
    // assuming each client's input lies in the range [0, base_modulus), the
    // result will be identical to non-modular addition (i.e. federated_sum).
    //
    // While any m >= (1 + shard_size * (base_modulus - 1)) would be valid,
    // the current implementation takes
    //   m = 2**ceil(log_2(1 + shard_size * (base_modulus - 1))),
    // which is the smallest valid m that is also a power of 2. This choice is
    // made because (a) with the current on-the-wire encoding scheme it uses
    // the same number of bits per vector entry as any valid smaller m, and
    // (b) it enables the underlying mask-generation PRNG to run in its most
    // computationally efficient mode, which can be up to 2x faster.
    message ModulusTimesShardSize {
      uint64 base_modulus = 1;
    }

    // How the SecureAggregation modulus is chosen.
    oneof modulus_scheme {
      // Bitwidth of the aggregand.
      //
      // This is the bitwidth of an input value (i.e. the bitwidth that
      // quantization should target). The Secure Aggregation bitwidth (i.e.
      // the bitwidth of the *sum* of the input values) will be a function of
      // this bitwidth and the number of participating clients, as negotiated
      // with the server when the protocol is initiated.
      //
      // Deprecated; prefer fixed_modulus instead.
      int32 quantized_input_bitwidth = 2 [deprecated = true];

      // Sum modulo a fixed modulus; see `FixedModulus`.
      FixedModulus fixed_modulus = 5;

      // Sum modulo a shard-size-dependent modulus; see
      // `ModulusTimesShardSize`.
      ModulusTimesShardSize modulus_times_shard_size = 6;
    }

    reserved 1;
  }

  // What type of side channel is used.
  oneof type {
    SecureAggregand secure_aggregand = 1;
  }

  // When restoring, the name of the tensor to restore to. This is the name
  // (key) supplied in the feed_dict of the before_restore_op in order to
  // restore the tensor provided by the side channel (which will be the
  // corresponding value in the feed_dict).
  string restore_name = 2;
}
// Container for a metric used by the internal toolkit.
message Metric {
  // Name of an Op to run to read the value.
  string variable_name = 1;

  // A human-readable name for the statistic. Metric names are usually
  // CamelCase by convention, e.g. 'Loss', 'AbsLoss', or 'Accuracy'.
  // Must be 7-bit ASCII and under 122 characters.
  string stat_name = 2;

  // The human-readable name of another metric by which this metric should
  // be normalized, if any. If empty, this Metric is aggregated with simple
  // summation. Otherwise it is aggregated according to:
  //   weighted_metric_sum  = sum_i (metric_i * weight_i)
  //   weight_sum           = sum_i weight_i
  //   average_metric_value = weighted_metric_sum / weight_sum
  string weight_name = 3;
}
// Controls the format of the output metrics users receive: instructions for
// how a metric is to be output to users, which determines the end format of
// the metric they receive.
message OutputMetric {
  // Metric name.
  string name = 1;

  // How the metric value is produced and aggregated.
  oneof value_source {
    // A metric representing one stat with aggregation type sum.
    SumOptions sum = 2;

    // A metric representing a ratio between metrics with aggregation type
    // average.
    AverageOptions average = 3;

    // A metric that is not aggregated by the MetricReportAggregator or
    // metrics_loader. This includes metrics like 'num_server_updates' that
    // are aggregated in TensorFlow.
    NoneOptions none = 4;

    // A metric representing one stat with aggregation type only sample.
    // Samples at most 101 clients' values.
    OnlySampleOptions only_sample = 5;
  }

  // How the metric should be visualized.
  oneof visualization_info {
    // Iff True, the metric will be plotted automatically in the default view
    // of the task-level Colab.
    //
    // Deprecated. NOTE(review): presumably superseded by `plot_spec`;
    // confirm.
    bool auto_plot = 6 [deprecated = true];

    // A specification of how the metric should be plotted.
    VisualizationSpec plot_spec = 7;
  }
}
// Describes how a metric should be visualized downstream.
message VisualizationSpec {
  // The allowed plot types.
  //
  // NOTE: unlike the usual convention, these values are not prefixed with
  // the enum name (they are scoped to VisualizationSpec). Renaming them would
  // break generated code, so they are kept as-is.
  enum VisualizationType {
    NONE = 0;
    DEFAULT_PLOT_FOR_TASK_TYPE = 1;
    LINE_PLOT = 2;
    LINE_PLOT_WITH_PERCENTILES = 3;
    HISTOGRAM = 4;
  }

  // The plot type to provide downstream.
  VisualizationType plot_type = 1;

  // The x-axis to use for the given metric. Must be the name of a metric or
  // counter. Recommended x_axis options are source_round, round, or time.
  string x_axis = 2;

  // Iff True, the metric will be displayed on a population-level dashboard.
  bool plot_on_population_dashboard = 3;
}
// A metric representing one stat with aggregation type sum.
message SumOptions {
  // Name of the corresponding Metric's stat_name field.
  string stat_name = 1;

  // Iff True, a cumulative sum over rounds will be provided in addition to
  // the per-round sum for the value metric.
  bool include_cumulative_sum = 2;

  // Iff True, a sample of at most 101 clients' values is included.
  // Used to calculate quantiles in the downstream visualization pipeline.
  bool include_client_samples = 3;
}
// A metric representing a ratio between metrics with aggregation type
// average. Represents: numerator stat / denominator stat.
message AverageOptions {
  // Numerator stat name, pointing to the corresponding Metric stat_name.
  string numerator_stat_name = 1;

  // Denominator stat name, pointing to the corresponding Metric stat_name.
  string denominator_stat_name = 2;

  // Name of the corresponding Metric stat_name that holds the ratio
  // numerator stat / denominator stat.
  string average_stat_name = 3;

  // Iff True, a sample of at most 101 clients' values is included.
  // Used to calculate quantiles in the downstream visualization pipeline.
  bool include_client_samples = 4;
}
// A metric representing one stat with aggregation type none.
message NoneOptions {
  // Name of the corresponding Metric's stat_name field.
  string stat_name = 1;
}
// A metric representing one stat with aggregation type only sample.
message OnlySampleOptions {
  // Name of the corresponding Metric's stat_name field.
  string stat_name = 1;
}
// Represents a data set. This is used for testing.
message Dataset {
  // Represents the data set for one client.
  message ClientDataset {
    // A string identifying the client.
    string client_id = 1;

    // A list of serialized tf.Example protos.
    repeated bytes example = 2;

    // Represents a dataset whose examples are selected by an
    // ExampleSelector.
    message SelectedExample {
      // The selector that chose these examples.
      ExampleSelector selector = 1;
      // The examples selected by `selector`.
      repeated bytes example = 2;
    }

    // A list of (selector, dataset) pairs. Used in testing some *TFF-based
    // tasks* that require multiple datasets as client input, e.g. a
    // TFF-based personalization eval task requires each client to provide at
    // least two datasets: one for train and the other for test.
    repeated SelectedExample selected_example = 3;
  }

  // A list of client data.
  repeated ClientDataset client_data = 1;
}
// Represents predicates over metrics, i.e. expectations. This is used in
// training/eval tests to encode the metric names and values expected to be
// reported by a client execution.
message MetricTestPredicates {
  // The value must lie in [lower_bound, upper_bound]. Can also be used for
  // approximate matching (lower_bound = value - epsilon;
  // upper_bound = value + epsilon).
  message Interval {
    double lower_bound = 1;
    double upper_bound = 2;
  }

  // The value must be a real value as long as the value of the weight_name
  // metric is non-zero. If the weight metric is zero, then it is acceptable
  // for the value to be non-real.
  message RealIfNonzeroWeight {
    string weight_name = 1;
  }

  // An expectation about a single metric in a single training round.
  message MetricCriterion {
    // Name of the metric.
    string name = 1;

    // FL training round this metric is expected to appear in.
    int32 training_round_index = 2;

    // If none of the following is set, no matching is performed, but the
    // metric is still expected to be present (with whatever value).
    //
    // NOTE: this oneof name is PascalCase, unlike the other oneofs in this
    // file; renaming it may change generated accessors, so it is kept as-is.
    oneof Criterion {
      // The reported metric must be < lt.
      float lt = 3;
      // The reported metric must be > gt.
      float gt = 4;
      // The reported metric must be <= le.
      float le = 5;
      // The reported metric must be >= ge.
      float ge = 6;
      // The reported metric must be == eq.
      float eq = 7;
      // The reported metric must lie in the interval.
      Interval interval = 8;
      // The reported metric is not NaN or +/- infinity.
      bool real = 9;
      // The reported metric is real (i.e. not NaN or +/- infinity) if the
      // value of an associated weight is not 0.
      RealIfNonzeroWeight real_if_nonzero_weight = 10;
    }
  }

  // The expectations to check.
  repeated MetricCriterion metric_criterion = 1;

  reserved 2;
}
// Client Phase
// ============
// A `TensorflowSpec` that is executed on the client in a single `tf.Session`.
// In federated optimization, this will correspond to one `ServerPhase`.
message ClientPhase {
  // A short CamelCase name for the ClientPhase.
  string name = 2;

  // Minimum number of clients in aggregation.
  //
  // In secure aggregation mode this is used to configure the protocol
  // instance so that the server can't learn aggregated values from fewer
  // participants than this number. Without secure aggregation the server
  // still respects this parameter, ensuring that aggregated values never
  // leave server RAM unless they include data from at least the specified
  // number of participants.
  int32 minimum_number_of_participants = 3;

  // The client logic to run. If populated, `io_router` must be specified.
  oneof spec {
    // A functional interface for the TensorFlow logic the client should
    // perform.
    TensorflowSpec tensorflow_spec = 4 [lazy = true];

    // Spec for client plans that issue example queries and send the query
    // results directly to an aggregator with little or no additional
    // processing.
    ExampleQuerySpec example_query_spec = 9 [lazy = true];
  }

  // The specification of the inputs coming either from customer apps
  // (Local Compute) or the federated protocol (Federated Compute).
  oneof io_router {
    FederatedComputeIORouter federated_compute = 5 [lazy = true];
    LocalComputeIORouter local_compute = 6 [lazy = true];
    FederatedComputeEligibilityIORouter federated_compute_eligibility = 7
        [lazy = true];
    FederatedExampleQueryIORouter federated_example_query = 8 [lazy = true];
  }

  reserved 1;
}
// TensorflowSpec describes a single call into TensorFlow: the input tensors
// that must be fed when making that call, the output tensors to be fetched,
// and any operations that have no output but must be run. The TensorFlow
// session then uses the input tensors to do some computation, generally
// reading from one or more datasets, and provides some outputs.
//
// Conceptually, client or server code uses this proto along with an IORouter
// to build maps of names to input tensors, vectors of output tensor names,
// and vectors of target nodes:
//
//   CreateTensorflowArguments(
//       TensorflowSpec& spec,
//       IORouter& io_router,
//       const vector<pair<string, Tensor>>* input_tensors,
//       const vector<string>* output_tensor_names,
//       const vector<string>* target_node_names);
//
// Here `input_tensors`, `output_tensor_names` and `target_node_names`
// correspond to the arguments of the TensorFlow C++ API
// `tensorflow::Session::Run()`, and the client executes only a single
// invocation.
//
// Note: the execution engine never sees any concepts related to the
// federated protocol, e.g. input checkpoints or aggregation protocols. This
// is a "tensors in, tensors out" interface. New aggregation methods can be
// added without modifying the execution engine / TensorflowSpec message;
// they should modify the IORouter messages instead.
//
// Note: both `input_tensor_specs` and `output_tensor_specs` are full
// `tensorflow.TensorSpecProto` messages, though TensorFlow technically only
// requires the names to feed the values into the session. The additional
// dtype/shape information must always be included in case the runtime
// executing this TensorflowSpec wants to perform additional, optional static
// assertions. Runtimes are, however, free to ignore the dtypes/shapes and
// rely only on the names if so desired.
//
// Assertions:
//   - All names in `input_tensor_specs`, `output_tensor_specs`, and
//     `target_node_names` must appear in the serialized GraphDef where the TF
//     execution will be invoked.
//   - `output_tensor_specs` or `target_node_names` must be non-empty,
//     otherwise there is nothing to execute in the graph.
message TensorflowSpec {
  // The name of a tensor into which a unique token for the current session
  // should be written. The corresponding tensor is a scalar string tensor
  // and is kept separate from `input_tensor_specs` as there is only one.
  //
  // A session token allows TensorFlow ops such as `ServeSlices` or
  // `ExternalDataset` to refer to callbacks and other session-global objects
  // registered before running the session. In the `ExternalDataset` case, a
  // single dataset_token is valid for multiple `tf.data.Dataset` objects, as
  // the token can be thought of as a handle to a dataset factory.
  string dataset_token_tensor_name = 1;

  // TensorSpecs of the inputs which will be passed to TF.
  //
  // Corresponds to the `feed_dict` parameter of `tf.Session.run()` in
  // TensorFlow's Python API, excluding the dataset_token listed above.
  //
  // Assertions:
  //   - All tensor names designated as inputs in the corresponding IORouter
  //     must be listed (otherwise the IORouter input work is unused).
  //   - All placeholders in the TF graph must be listed here, except the
  //     dataset_token, which is set explicitly above (otherwise TensorFlow
  //     will fail to execute).
  repeated tensorflow.TensorSpecProto input_tensor_specs = 2;

  // TensorSpecs that should be fetched from TF after execution.
  //
  // Corresponds to the `fetches` parameter of `tf.Session.run()` in
  // TensorFlow's Python API, and to `output_tensor_names` in TensorFlow's C++
  // API.
  //
  // Assertions:
  //   - The set of tensor names here must strictly match the tensor names
  //     designated as outputs in the corresponding IORouter (if any exist).
  repeated tensorflow.TensorSpecProto output_tensor_specs = 3;

  // Node names in the graph that should be executed, but whose output is
  // not returned.
  //
  // Corresponds to the `fetches` parameter of `tf.Session.run()` in
  // TensorFlow's Python API, and to `target_node_names` in TensorFlow's C++
  // API.
  //
  // This is intended for operations that do not produce tensors but are
  // nonetheless required to run (e.g. serializing checkpoints).
  repeated string target_node_names = 4;

  // Map of tensor names to constant inputs.
  // Note: tensors specified via this map should not also be included in
  // input_tensor_specs.
  map<string, tensorflow.TensorProto> constant_inputs = 5;
}
// ExampleQuerySpec describes client execution that issues example queries
// and sends the query results directly to an aggregator with little or no
// additional processing.
//
// This message describes one or more example store queries that perform the
// client-side analytics computation in C++. The corresponding output vectors
// will be converted into the expected federated protocol output format.
// This must be used in conjunction with the `FederatedExampleQueryIORouter`.
message ExampleQuerySpec {
  // Describes one output vector produced by a query.
  message OutputVectorSpec {
    // The output vector name.
    string vector_name = 1;

    // Supported data types for the vector of information.
    enum DataType {
      UNSPECIFIED = 0;
      INT32 = 1;
      INT64 = 2;
      BOOL = 3;
      FLOAT = 4;
      DOUBLE = 5;
      BYTES = 6;
      STRING = 7;
    }

    // The data type of each entry in the vector.
    DataType data_type = 2;
  }

  // A single example store query and its expected outputs.
  message ExampleQuery {
    // The `ExampleSelector` to issue the query with.
    ExampleSelector example_selector = 1;

    // Indicates that the query returns vector data and must return a single
    // ExampleQueryResult containing a VectorData entry matching each
    // OutputVectorSpec.vector_name.
    //
    // If the query instead returns no result, this is treated as if an
    // error was returned. In that case, or if the query explicitly returns
    // an error, the client will abort its session.
    //
    // The keys in the map are the names the vectors should be aggregated
    // under, and must match the keys in
    // FederatedExampleQueryIORouter.aggregations.
    map<string, OutputVectorSpec> output_vector_specs = 2;
  }

  // The queries to run.
  repeated ExampleQuery example_queries = 1;
}
// The input and output router for Federated Compute plans.
//
// This proto is the glue between the federated protocol and the TensorFlow
// execution engine. It describes how to prepare data coming from the
// incoming `CheckinResponse` (defined in fcp/protos/federated_api.proto) for
// the `TensorflowSpec`, and what to do with the outputs of the
// `TensorflowSpec` (e.g. how to aggregate them back on the server).
//
// TODO(team) we could replace `input_checkpoint_file_tensor_name` with an
// `input_tensors` field, which would then be a tensor that contains the
// input TensorProtos directly, skipping disk I/O, rather than referring to a
// checkpoint file path.
message FederatedComputeIORouter {
  // ===========================================================================
  // Inputs
  // ===========================================================================

  // The name of the scalar string tensor that is fed the file path to the
  // initial checkpoint (e.g. as provided via AcceptanceInfo.init_checkpoint).
  //
  // The federated protocol code copies the `CheckinResponse`'s initial
  // checkpoint to a temporary file and then passes that file path through
  // this tensor.
  //
  // Ops may be added to the client graph that take this tensor as input and
  // read the path.
  //
  // This field is optional. It may be omitted if the client graph does not
  // use an initial checkpoint.
  string input_filepath_tensor_name = 1;

  // The name of the scalar string tensor that is fed the file path to which
  // client work should serialize the bytes to send back to the server.
  //
  // The federated protocol code generates a temporary file and passes the
  // file path through this tensor.
  //
  // Ops may be added to the client graph that use this tensor as an argument
  // to write files (e.g. writing checkpoints to disk).
  //
  // This field is optional. It must be omitted if the client graph does not
  // generate any output files (e.g. when all output tensors of
  // `TensorflowSpec` use Secure Aggregation). If this field is not set, then
  // the `ReportRequest` message in the federated protocol will not have the
  // `Report.update_checkpoint` field set. The absence of a value here can be
  // used to validate that the plan only uses Secure Aggregation.
  //
  // Conversely, if this field is set and executing the associated
  // TensorflowSpec does not write to the path, that indicates an internal
  // framework error. The runtime should notify the caller that the
  // computation was set up incorrectly.
  string output_filepath_tensor_name = 2;

  // ===========================================================================
  // Outputs
  // ===========================================================================

  // Describes which output tensors should be aggregated using an aggregation
  // protocol, and the configuration for those protocols.
  //
  // Assertions:
  //   - All keys must exist in the associated `TensorflowSpec` as
  //     `output_tensor_specs.name` values.
  map<string, AggregationConfig> aggregations = 3;
}
// The input and output router for client plans that do not use TensorFlow.
//
// This proto is the glue between the federated protocol and the example
// query execution engine, describing how the query results should
// ultimately be aggregated.
message FederatedExampleQueryIORouter {
  // Describes how each output vector should be aggregated using an
  // aggregation protocol, and the configuration for those protocols.
  // Keys must match the keys in ExampleQuerySpec.output_vector_specs.
  // Note that currently only the TFV1CheckpointAggregation config is
  // supported.
  map<string, AggregationConfig> aggregations = 1;
}
// The specification for how to aggregate the associated tensor across
// clients on the server.
message AggregationConfig {
  // The aggregation protocol to use; exactly one may be set.
  oneof protocol_config {
    // Indicates that the given output tensor should be processed using
    // Secure Aggregation, using the specified config options.
    SecureAggregationConfig secure_aggregation = 2;

    // Note: in the future we could add a `SimpleAggregationConfig` to
    // support simple aggregation without writing to an intermediate
    // checkpoint file first.

    // Indicates that the given output tensor or vector (e.g. as produced by
    // an ExampleQuerySpec) should be placed in an output TF v1 checkpoint.
    //
    // Currently only ExampleQuerySpec output vectors are supported by this
    // aggregation type (i.e. it cannot be used with TensorflowSpec output
    // tensors). Each vector will be stored in the checkpoint as a 1-D Tensor
    // of its corresponding data type.
    TFV1CheckpointAggregation tf_v1_checkpoint_aggregation = 3;
  }
}
// Parameters for the SecAgg protocol (go/secagg).
//
// Currently only the server uses the SecAgg parameters, so this message is
// used only to signal that SecAgg is in use; it intentionally has no fields.
message SecureAggregationConfig {}
// Parameters for the TFV1 Checkpoint Aggregation protocol (currently none).
//
// Currently only ExampleQuerySpec output vectors are supported by this
// aggregation type (i.e. it cannot be used with TensorflowSpec output
// tensors). Each vector will be stored in the checkpoint as a 1-D Tensor of
// its corresponding data type.
message TFV1CheckpointAggregation {}
// The input and output router for eligibility-computing plans. These plans
// compute which other plans a client is eligible to run, and are delivered
// to clients via an `EligibilityEvalCheckinResponse` (defined in
// fcp/protos/federated_api.proto).
message FederatedComputeEligibilityIORouter {
  // The name of the scalar string tensor that is fed the file path to the
  // initial checkpoint (e.g. as provided via
  // `EligibilityEvalPayload.init_checkpoint`).
  //
  // For more detail see
  // `FederatedComputeIORouter.input_filepath_tensor_name`, which has the same
  // semantics.
  //
  // This field is optional. It may be omitted if the client graph does not
  // use an initial checkpoint.
  //
  // This tensor name must exist in the associated
  // `TensorflowSpec.input_tensor_specs` list.
  string input_filepath_tensor_name = 1;

  // Name of the output tensor (a string scalar) containing the serialized
  // `google.internal.federatedml.v2.TaskEligibilityInfo` proto output. The
  // client code will parse this proto and place it in the
  // `task_eligibility_info` field of the subsequent `CheckinRequest`.
  //
  // This tensor name must exist in the associated
  // `TensorflowSpec.output_tensor_specs` list.
  string task_eligibility_info_tensor_name = 2;
}
// The input and output router for Local Compute plans.
//
// This proto is the glue between the customer's app and the TensorFlow
// execution engine. It describes how to prepare data coming from the
// customer app (e.g. the input directory the app set up), and the temporary
// scratch output directory that will be reported to the customer app upon
// completion of the `TensorflowSpec`.
message LocalComputeIORouter {
  // ===========================================================================
  // Inputs
  // ===========================================================================

  // The name of the placeholder tensor representing the input resource
  // path(s). It can be a single input directory or file path (in which case
  // `input_dir_tensor_name` is populated), or multiple input resources
  // represented as a map from names to input directories or file paths (in
  // which case `multiple_input_resources` is populated).
  //
  // In the multiple input resources case, the placeholder tensors are
  // represented as a map: the keys are the input resource names defined by
  // the users when constructing the `LocalComputation` Python object, and
  // the values are the corresponding placeholder tensor names created by the
  // local computation plan builder.
  //
  // Apps can create contracts between their Android code and
  // `LocalComputation` toolkit code to place files inside the input resource
  // paths with known names (Android code) and create graphs with ops that
  // read from these paths (file names can be specified in toolkit code).
  oneof input_resource {
    string input_dir_tensor_name = 1;

    // A `map` field cannot be used directly inside a `oneof`, so it is
    // wrapped in a separate message.
    MultipleInputResources multiple_input_resources = 3;
  }

  // Scalar string tensor name that will contain the output directory path.
  //
  // The provided directory should be considered temporary scratch space that
  // will be deleted, not persisted. It is the responsibility of the calling
  // app to move the desired files to a permanent location once the client
  // returns this directory to the calling app.
  string output_dir_tensor_name = 2;

  // ===========================================================================
  // Outputs
  // ===========================================================================

  // NOTE: LocalCompute has no outputs other than what the client graph
  // writes to the `output_dir` specified above.
}
// Describes the multiple input resources in `LocalComputeIORouter`.
message MultipleInputResources {
  // Maps each input resource name (defined by the users when constructing
  // the `LocalComputation` Python object) to the corresponding placeholder
  // tensor name created by the local computation plan builder.
  map<string, string> input_resource_tensor_name_map = 1;
}
// Describes a queue to which input is fed.
message AsyncInputFeed {
  // The op that enqueues an example input.
  string enqueue_op = 1;

  // The input placeholders of the enqueue op.
  repeated string enqueue_params = 2;

  // The op that closes the input queue.
  string close_op = 3;

  // Whether the work to be fed asynchronously is the data itself or a
  // description of where that data lives.
  bool feed_values_are_data = 4;
}
// Describes input data provided through a tf.data.Dataset.
message DatasetInput {
  // Name of the op in the graph that initializes the iterator of the
  // tf.data.Dataset object handling the input data.
  string initializer = 1;

  // Placeholders necessary to initialize the dataset.
  DatasetInputPlaceholders placeholders = 2;

  // Batch size to be used in the tf.data.Dataset.
  int32 batch_size = 3;
}
// Names of the placeholders used to initialize a `DatasetInput`.
message DatasetInputPlaceholders {
  // Name of the placeholder holding the filename(s) of the SSTable(s) to
  // read data from.
  string filename = 1;

  // Name of the placeholder holding the key_prefix that initializes the
  // SSTableDataset. Note the value fed should be the unique user id, not a
  // prefix.
  string key_prefix = 2;

  // Name of the placeholder holding the number of rounds the local training
  // should be run for.
  string num_epochs = 3;

  // Name of the placeholder holding the batch size.
  string batch_size = 4;
}
// Specifies an example selection procedure.
message ExampleSelector {
  // Selection criteria following a contract agreed upon between client and
  // model designers.
  google.protobuf.Any criteria = 1;

  // A URI identifying the example collection to read from. The format should
  // be "${COLLECTION}://${APP_NAME}${COLLECTION_NAME}", where:
  //   - The scheme ${COLLECTION} should be one of:
  //       - "app" for app-hosted examples
  //       - "simulation" for collections not connected to an app (e.g. if
  //         used purely for simulation)
  //   - The authority ${APP_NAME} identifies the owner of the example
  //     collection and should be either the app's package name or left
  //     empty (which means "the current app package name").
  //   - The path ${COLLECTION_NAME} can be any valid URI path. NB: it starts
  //     with a forward slash ("/").
  //   - The query and fragment are currently unused, but they may be used
  //     for something in the future. To keep that possibility open they must
  //     currently be left empty.
  //
  // Example: "app://com.google.some.app/someCollection/name" identifies the
  // collection "/someCollection/name" owned and hosted by the app with
  // package name "com.google.some.app".
  //
  // Example: "app:/someCollection/name" or "app:///someCollection/name" both
  // identify the collection "/someCollection/name" owned and hosted by the
  // app associated with the training job in which this URI appears.
  //
  // The path is not interpreted by the runtime; it is passed to the example
  // collection implementation for interpretation. Thus, for app-hosted
  // example stores, the path segment's interpretation is a contract between
  // the app's example store developers and the app's model designers.
  //
  // If an `app://` URI is set, then the `TrainerOptions` collection name must
  // not be set.
  string collection_uri = 2;

  // Resumption token following a contract agreed upon between client and
  // model designers.
  google.protobuf.Any resumption_token = 3;
}
// Identifies the slices to fetch as part of a `federated_select` operation.
message SlicesSelector {
  // The string ID under which the slices are served. It must be a value
  // returned by an earlier call to the `serve_slices` op, run during the
  // `write_client_init` operation.
  string served_at_id = 1;

  // Indices of the slices to fetch.
  repeated int32 keys = 2;
}
// Slice data served as part of a `federated_select` operation. Used for
// testing.
message SlicesTestDataset {
  // The test data to use, keyed by `SlicesSelector.served_at_id`. For
  // example, the test slice data for a slice with `served_at_id`="foo" and
  // `keys`=2 would be stored in `dataset["foo"].slice_data[2]`.
  map<string, SlicesTestData> dataset = 1;
}
// Test slice data for a single `served_at_id` (a value of
// `SlicesTestDataset.dataset`).
message SlicesTestData {
  // The test slice data to serve. The index of each entry is the slice key
  // that the entry provides test data for.
  repeated bytes slice_data = 2;
}
// Server Phase V2
// ===============
// Represents a server phase made of three distinct components: pre-broadcast,
// aggregation, and post-aggregation.
//
// The pre-broadcast and post-aggregation components are described by the
// `tensorflow_spec_prepare` and `tensorflow_spec_result` TensorflowSpec
// messages, respectively. Combined with the server IORouter messages, they
// specify how to set up a single TF sess.run call for each component.
//
// The pre-broadcast logic is obtained by transforming the server_prepare TFF
// computation in the DistributeAggregateForm. It takes the server state as
// input and generates the checkpoint to broadcast to the clients and,
// potentially, an intermediate server state. The intermediate server state
// may be used by the aggregation and post-aggregation logic.
//
// The aggregation logic represents the aggregation of client results at the
// server. It is described by a list of ServerAggregationConfig messages, each
// describing a single aggregation operation on a set of input/output tensors.
// The input tensors may represent parts of either the client results or the
// intermediate server state. These messages are obtained by transforming the
// client_to_server_aggregation TFF computation in the
// DistributeAggregateForm.
//
// The post-aggregation logic is obtained by transforming the server_result
// TFF computation in the DistributeAggregateForm. It takes the intermediate
// server state and the aggregated client results as input, and generates the
// new server state and, potentially, other server-side output.
//
// A ServerPhaseV2 message can be generated for all types of intrinsics.
// However, it is currently only compatible with the ClientPhase message if
// the aggregations used are exclusively federated_sum (not SecAgg). When
// that requirement holds, it is also valid to run the aggregation portion of
// this ServerPhaseV2 message alongside the pre- and post-aggregation logic of
// the original ServerPhase message. Ultimately, we expect the full
// ServerPhaseV2 message to be run and the ServerPhase message to be
// deprecated.
message ServerPhaseV2 {
  // A short CamelCase name for the ServerPhaseV2.
  string name = 1;

  // Functional interface for the TensorFlow logic the server performs before
  // the server-to-client broadcast. To be used with the TensorFlow graph in
  // `Plan.server_graph_prepare_bytes`.
  TensorflowSpec tensorflow_spec_prepare = 3;

  // Specification of the inputs needed by the server_prepare TF logic.
  oneof server_prepare_io_router {
    ServerPrepareIORouter prepare_router = 4;
  }

  // The client-to-server aggregations to perform.
  repeated ServerAggregationConfig aggregations = 2;

  // Functional interface for the TensorFlow logic the server performs after
  // aggregation. To be used with the TensorFlow graph in
  // `Plan.server_graph_result_bytes`.
  TensorflowSpec tensorflow_spec_result = 5;

  // Specification of the inputs and outputs needed by the server_result TF
  // logic.
  oneof server_result_io_router {
    ServerResultIORouter result_router = 6;
  }
}
// Input/output routing for the server_prepare graph.
message ServerPrepareIORouter {
  // Name of the scalar string tensor in the server_prepare TF graph that is
  // fed the filepath of the initial server state checkpoint. The
  // server_prepare logic reads from this filepath.
  string prepare_server_state_input_filepath_tensor_name = 1;

  // Name of the scalar string tensor in the server_prepare TF graph that is
  // fed the filepath at which the client checkpoint should be stored. The
  // server_prepare logic writes to this filepath.
  string prepare_output_filepath_tensor_name = 2;

  // Name of the scalar string tensor in the server_prepare TF graph that is
  // fed the filepath at which the intermediate state checkpoint should be
  // stored. The server_prepare logic writes to this filepath. Both the logic
  // that sets aggregation parameters and the post-aggregation logic consume
  // the resulting intermediate state checkpoint.
  string prepare_intermediate_state_output_filepath_tensor_name = 3;
}
// Input/output routing for the server_result graph.
message ServerResultIORouter {
  // Name of the scalar string tensor in the server_result TF graph that is
  // fed the filepath of the intermediate state checkpoint. The server_result
  // logic reads from this filepath.
  string result_intermediate_state_input_filepath_tensor_name = 1;

  // Name of the scalar string tensor in the server_result TF graph that is
  // fed the filepath of the aggregated client result checkpoint. The
  // server_result logic reads from this filepath.
  string result_aggregate_result_input_filepath_tensor_name = 2;

  // Name of the scalar string tensor in the server_result TF graph that is
  // fed the filepath at which the updated server state should be stored. The
  // server_result logic writes to this filepath.
  string result_server_state_output_filepath_tensor_name = 3;
}
// A single aggregation operation, which combines one or more input tensors
// from a collection of clients into one or more output tensors on the
// server.
message ServerAggregationConfig {
  // URI of the aggregation intrinsic (e.g. 'federated_sum').
  string intrinsic_uri = 1;

  // One argument to the aggregation operation.
  message IntrinsicArg {
    oneof arg {
      // A tensor within the checkpoint provided by each client.
      tensorflow.TensorSpecProto input_tensor = 2;

      // A tensor within the intermediate server state checkpoint.
      tensorflow.TensorSpecProto state_tensor = 3;
    }
  }

  // Arguments for the aggregation operation, in the order the server expects
  // them. An argument may depend on client data, in which case it must be
  // retrieved from the clients. Otherwise it is independent of client data
  // and can be configured server-side. For now, all client-independent
  // arguments are assumed to be constants.
  // NOTE(review): `IntrinsicArg.state_tensor` refers to the intermediate
  // server state, which is client-independent but not obviously constant;
  // confirm this assumption is still current.
  repeated IntrinsicArg intrinsic_args = 4;

  // Server-side outputs produced by the aggregation operation.
  repeated tensorflow.TensorSpecProto output_tensors = 5;
}
// Server Phase
// ============
// Represents a server phase that implements TF-based aggregation of multiple
// client updates.
//
// The values in this message describe two different modes of aggregation.
// The first is aggregation over coordinated sets of clients. This includes
// aggregation done via checkpoints from clients, or aggregation done over a
// set of clients by a process like secure aggregation. The results of this
// first aggregation are saved to intermediate aggregation checkpoints. The
// second aggregation then takes these intermediate checkpoints and
// aggregates over them.
//
// These two modes of aggregation run on different servers: the first on the
// 'L1' servers and the second on the 'L2' servers. This nomenclature is used
// to describe the phases below.
//
// The ServerPhase message is currently being replaced by the ServerPhaseV2
// message, as the plan building pipeline switches from MapReduceForm to
// DistributeAggregateForm. During the migration, both messages may be
// generated, and components from either message may be used during
// execution.
message ServerPhase {
  // A short CamelCase name for the ServerPhase.
  string name = 8;

  // ===========================================================================
  // L1 "Intermediate" Aggregation.
  //
  // This is the initial aggregation, which creates partial aggregates from
  // client results. L1 Aggregation may be run on many different instances.
  //
  // Pre-condition:
  //   The execution environment has loaded the graph from
  //   `server_graph_bytes` in the parent `Plan` message.
  //
  // 1. Initialize the phase.
  //
  // Operation to run before the first aggregation happens. For instance, it
  // clears the accumulators so that a new aggregation can begin.
  string phase_init_op = 1;

  // 2. For each client in the set of clients:
  //   a. Restore variables from the client checkpoint.
  //
  // Loads a checkpoint from a single client, written via
  // `FederatedComputeIORouter.output_filepath_tensor_name`. This is done once
  // for every client checkpoint in a round.
  CheckpointOp read_update = 3;

  //   b. Aggregate the data coming from the client checkpoint.
  //
  // An operation that aggregates the data from `read_update`. Generally this
  // adds to accumulators, and it may leverage internal data inside the graph
  // to adjust the weights of the Tensors.
  //
  // Executed once for each `read_update`, for example to update accumulator
  // variables using the values loaded during `read_update`.
  string aggregate_into_accumulators_op = 4;

  // 3. After all clients have been aggregated, possibly restore variables
  //    that have been aggregated via a separate process.
  //
  // Optionally restores variables whose aggregation is done across an entire
  // round of client data updates. In contrast to `read_update`, which
  // restores once per client, this happens after all clients in a round have
  // been processed. This allows, for example, side channels where aggregation
  // is done by a separate process (such as secure aggregation). The
  // side-channel aggregated tensor is passed to the `before_restore_op`,
  // which ensures the variables are restored properly. The
  // `after_restore_op` is then responsible for performing the accumulation.
  //
  // Note that in current use this should not have a SaverDef, and should
  // only be used for side channels.
  CheckpointOp read_aggregated_update = 10;

  // 4. Write the aggregated variables to an intermediate checkpoint.
  //
  // We require `aggregate_into_accumulators_op` to be associative and
  // commutative, so that the aggregates can be computed across multiple
  // TensorFlow sessions. For example, to compute the sum of 5 client
  // updates:
  //   A = X1 + X2 + X3 + X4 + X5
  // we can always do this in one session by calling `read_update` and
  // `aggregate_into_accumulators_op` once for each client checkpoint.
  //
  // Alternatively, we could compute:
  //   A1 = X1 + X2 in one TensorFlow session, and
  //   A2 = X3 + X4 + X5 in a different session.
  // Each of these sessions can then write its accumulator state with the
  // `write_intermediate_update` CheckpointOp. A third session can then call
  // `read_intermediate_update` and `aggregate_into_accumulators_op` on each
  // of these checkpoints to compute:
  //   A = A1 + A2 = (X1 + X2) + (X3 + X4 + X5).
  // NOTE(review): L2 step 3b below pairs `read_intermediate_update` with
  // `intermediate_aggregate_into_accumulators_op` instead; confirm which op
  // this example intends.
  CheckpointOp write_intermediate_update = 7;

  // End L1 "Intermediate" Aggregation.
  // ===========================================================================

  // ===========================================================================
  // L2 Aggregation and Coordinator.
  //
  // This aggregates the intermediate checkpoints from L1 Aggregation and
  // finalizes the update. Unlike L1, only one instance performs this
  // aggregation.
  //
  // Pre-condition:
  //   The execution environment has loaded the graph from
  //   `server_graph_bytes` and restored the global model using
  //   `server_savepoint`, both from the parent `Plan` message.
  //
  // 1. Initialize the phase.
  //
  // This currently re-uses the `phase_init_op` from L1 aggregation above.
  //
  // 2. Write a checkpoint that can be sent to the client.
  //
  // Generates a checkpoint to be sent to the client, to be read by
  // `FederatedComputeIORouter.input_filepath_tensor_name`.
  CheckpointOp write_client_init = 2;

  // 3. For each intermediate checkpoint:
  //   a. Restore variables from the intermediate checkpoint.
  //
  // The read counterpart of `write_intermediate_update`. It is used instead
  // of `read_update` for intermediate checkpoints because the format of these
  // updates may differ from that of updates from clients (which may, for
  // example, be compressed).
  CheckpointOp read_intermediate_update = 9;

  //   b. Aggregate the data coming from the intermediate checkpoint.
  //
  // An operation that aggregates the data from `read_intermediate_update`.
  // Generally this adds to accumulators, and it may leverage internal data
  // inside the graph to adjust the weights of the Tensors.
  string intermediate_aggregate_into_accumulators_op = 11;

  // 4. Write the aggregated intermediate variables to a checkpoint.
  //
  // This is used for downstream, cross-round aggregation of metrics. These
  // variables will be read back into a session with
  // `read_intermediate_update`.
  //
  // Tasks that do not use FL metrics may unset the CheckpointOp.saver_def to
  // disable writing accumulator checkpoints.
  CheckpointOp write_accumulators = 12;

  // 5. Finalize the round.
  //
  // This can include:
  //   - Applying the update aggregated from the intermediate checkpoints to
  //     the global model, and other updates to cross-round state variables.
  //   - Computing final round metric values (e.g. the `report` of a
  //     `tff.federated_aggregate`).
  //
  // NOTE: the field name misspells "aggregated". It is kept as-is because
  // renaming it would break generated code and the JSON/text formats.
  string apply_aggregrated_updates_op = 5;

  // 6. Fetch the server aggregated metrics.
  //
  // The metric variables to fetch from the TensorFlow session.
  repeated Metric metrics = 6;

  // 7. Serialize the updated server state (e.g. the coefficients of the
  //    global model in FL) using `server_savepoint` in the parent `Plan`
  //    message.
  //
  // End L2 Aggregation.
  // ===========================================================================
}
// The server phase of an eligibility computation.
//
// This phase produces a checkpoint that is sent to clients, where it is used
// as an input to the clients' task eligibility computations. This phase *does
// not include any aggregation.*
message ServerEligibilityComputationPhase {
  // A short CamelCase name for the ServerEligibilityComputationPhase.
  string name = 1;

  // Names of the TensorFlow nodes to run in order to produce the output.
  repeated string target_node_names = 2;

  // Specification of the inputs and outputs of the TensorFlow graph.
  oneof server_eligibility_io_router {
    TEContextServerEligibilityIORouter task_eligibility = 3 [lazy = true];
  }
}
// Inputs and outputs of a `ServerEligibilityComputationPhase` that takes a
// single `TaskEligibilityContext` as its input.
message TEContextServerEligibilityIORouter {
  // Name of the scalar string tensor that must be fed a serialized
  // `TaskEligibilityContext`.
  string context_proto_input_tensor_name = 1;

  // Name of the scalar string tensor that must be fed the path at which the
  // server graph should write the checkpoint file to be sent to the client.
  string output_filepath_tensor_name = 2;
}
// Plan
// =====
// Represents the overall plan for performing federated optimization or
// personalization, as handed over to the production system. It is typically
// split into individual pieces for the different production parts, e.g. the
// server side and the client side.
// NEXT_TAG: 15
message Plan {
  reserved 1, 3, 5;

  // The actual type of the server_*_graph_bytes fields below is expected to
  // be tensorflow.GraphDef. The TensorFlow graphs are stored in serialized
  // form for two reasons:
  //   1) We may use execution engines other than TensorFlow.
  //   2) We wish to avoid the cost of deserializing and re-serializing large
  //      graphs in the Federated Learning service.
  //
  // While we migrate from ServerPhase to ServerPhaseV2, server_graph_bytes,
  // server_graph_prepare_bytes, and server_graph_result_bytes may all be set.
  // If we're using a MapReduceForm-based server implementation, only
  // server_graph_bytes will be used. If we're using a
  // DistributeAggregateForm-based server implementation, only
  // server_graph_prepare_bytes and server_graph_result_bytes will be used.
  //
  // Optional. The TensorFlow graph used for all server processing described
  // by ServerPhase. For personalization, this will not be set.
  google.protobuf.Any server_graph_bytes = 7;

  // Optional. The TensorFlow graph used for all server processing described
  // by ServerPhaseV2.tensorflow_spec_prepare.
  google.protobuf.Any server_graph_prepare_bytes = 13;

  // Optional. The TensorFlow graph used for all server processing described
  // by ServerPhaseV2.tensorflow_spec_result.
  google.protobuf.Any server_graph_result_bytes = 14;

  // A savepoint to sync the server checkpoint with a persistent storage
  // system. The storage initially holds a seeded checkpoint, which can
  // subsequently be read and updated by this savepoint.
  // Optional: not present in eligibility computation plans (those with a
  // ServerEligibilityComputationPhase). This is used in conjunction with
  // ServerPhase only.
  CheckpointOp server_savepoint = 2;

  // Required. The TensorFlow graph that describes the TensorFlow logic a
  // client should perform. It should be consistent with the `TensorflowSpec`
  // field in the `client_phase`. The actual type is expected to be
  // tensorflow.GraphDef. The TensorFlow graph is stored in serialized form
  // for two reasons:
  //   1) We may use execution engines other than TensorFlow.
  //   2) We wish to avoid the cost of deserializing and re-serializing large
  //      graphs in the Federated Learning service.
  google.protobuf.Any client_graph_bytes = 8;

  // Optional. The FlatBuffer used for TFLite training. It contains the same
  // model information as client_graph_bytes, but in a different format.
  bytes client_tflite_graph_bytes = 12;

  // A pair of a client phase and a server phase that are processed in sync.
  // The server execution defines how the results of a client phase are
  // aggregated, and how the checkpoints for clients are generated.
  message Phase {
    // Required. The client phase.
    ClientPhase client_phase = 1;

    // Optional. Server phase for TF-based aggregation; not provided for
    // personalization or eligibility tasks.
    ServerPhase server_phase = 2;

    // Optional. Server phase for native aggregation; only provided for tasks
    // that have enabled the corresponding flag.
    ServerPhaseV2 server_phase_v2 = 4;

    // Optional. Only provided for eligibility tasks.
    ServerEligibilityComputationPhase server_eligibility_phase = 3;
  }

  // The pairs of client and server computations to run.
  repeated Phase phase = 4;

  // Metrics that persist across different phases. These include, for
  // example, counters that track how much work of different kinds has been
  // done.
  repeated Metric metrics = 6;

  // Describes how metrics in both the client and server phases should be
  // aggregated.
  repeated OutputMetric output_metrics = 10;

  // Version of the plan:
  //   version == 0 - Old plan without version field, containing b/65131070
  //   version >= 1 - plan supports multi-shard aggregation mode (L1/L2)
  int32 version = 9;

  // A TensorFlow ConfigProto packed in an Any.
  //
  // If this field is unset, if the Any proto is set but empty, or if the Any
  // proto is populated with an empty ConfigProto (i.e. its `type_url` field
  // is set, but the `value` field is empty), then the client implementation
  // may choose a set of configuration parameters to provide to TensorFlow by
  // default.
  //
  // In all other cases this field must contain a valid packed ConfigProto
  // (invalid values will result in an error at execution time). In that case
  // the client will not provide any other configuration parameters by
  // default.
  google.protobuf.Any tensorflow_config_proto = 11;
}
// The client part of a federated optimization plan. It is also used to
// describe a client-only plan for standalone on-device training, known as
// personalization.
// NEXT_TAG: 6
message ClientOnlyPlan {
  reserved 3;

  // The graph to use for training, in binary form.
  bytes graph = 1;

  // Optional. The FlatBuffer used for TFLite training. The client code
  // decides whether `graph` or `tflite_graph` is used for training, which
  // allows a flag-controlled A/B rollout.
  bytes tflite_graph = 5;

  // The client phase to execute.
  ClientPhase phase = 2;

  // A TensorFlow ConfigProto, packed in an Any.
  google.protobuf.Any tensorflow_config_proto = 4;
}
// The cross-round aggregation portion for user-defined measurements. Tools
// that process or analyze accumulator checkpoints after a round of
// computation use it to aggregate beyond a single round.
message CrossRoundAggregationExecution {
  // Operation to run before reading an accumulator checkpoint.
  string init_op = 1;

  // Reads an accumulator checkpoint.
  CheckpointOp read_aggregated_update = 2;

  // Operation that merges the loaded checkpoint into the accumulator.
  string merge_op = 3;

  // Reads and writes the final aggregated accumulator variables.
  CheckpointOp read_write_final_accumulators = 6;

  // Metadata that maps the TensorFlow `name` attribute of each `tf.Variable`
  // to the user-defined name of the signal.
  repeated Measurement measurements = 4;

  // The `tf.Graph` used to aggregate accumulator checkpoints when loading
  // metrics.
  google.protobuf.Any cross_round_aggregation_graph_bytes = 5;
}
// A single user-defined measurement, as listed in
// `CrossRoundAggregationExecution.measurements`.
message Measurement {
  // Name of the TensorFlow op to run to read/fetch the value of this
  // measurement.
  string read_op_name = 1;

  // Human-readable name of the measurement. By convention, names are usually
  // camel case, e.g. 'Loss', 'AbsLoss', or 'Accuracy'.
  string name = 2;

  reserved 3;

  // A serialized `tff.Type` for the measurement.
  bytes tff_type = 4;
}