/*
 * Copyright (C) 2023 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

syntax = "proto3";

package com.android.federatedcompute.proto;

import "google/protobuf/any.proto";
import "tensorflow/core/framework/tensor.proto";
import "tensorflow/core/framework/tensor_shape.proto";
import "tensorflow/core/framework/types.proto";
import "tensorflow/core/protobuf/saver.proto";
import "tensorflow/core/protobuf/struct.proto";

option java_package = "com.android.federatedcompute.proto";
option java_multiple_files = true;
option java_outer_classname = "PlanProto";

// Primitives
// ===========

// Represents an operation to save or restore from a checkpoint. Some
// instances of this message may only be used either for restore or for
// save, others for both directions. This is documented together with
// their usage.
//
// This op has four essential uses:
// 1. read and apply a checkpoint.
// 2. write a checkpoint.
// 3. read and apply from an aggregated side channel.
// 4. write to a side channel (grouped with write a checkpoint).
// We should consider splitting this into four separate messages.
message CheckpointOp {
  // An optional standard saver def. If not provided, only the
  // op(s) below will be executed. This must be a version 1 SaverDef.
  tensorflow.SaverDef saver_def = 1;

  // An optional operation to run before the saver_def is executed for
  // restore.
  string before_restore_op = 2;

  // An optional operation to run after the saver_def has been
  // executed for restore. If side_channel_tensors are provided, then
  // they should be provided in a feed_dict to this op.
  string after_restore_op = 3;

  // An optional operation to run before the saver_def will be
  // executed for save.
  string before_save_op = 4;

  // An optional operation to run after the saver_def has been
  // executed for save. If there are side_channel_tensors, this op
  // should be run after the side_channel_tensors have been fetched.
  string after_save_op = 5;

  // In addition to being saved and restored from a checkpoint, one can
  // also save and restore via a side channel. The keys in this map are
  // the names of the tensors transmitted by the side channel. These (key)
  // tensors should be read off just before saving a SaveDef and used
  // by the code that handles the side channel. Any variables provided this
  // way should NOT be saved in the SaveDef.
  //
  // For restoring, the variables that are provided by the side channel
  // are restored differently than those for a checkpoint. For those from
  // the side channel, these should be restored by calling the before_restore_op
  // with a feed dict whose keys are the restore_names in the SideChannel and
  // whose values are the values to be restored.
  map<string, SideChannel> side_channel_tensors = 6;

  // An optional name of a tensor into which a unique token for the current
  // session should be written.
  //
  // This session identifier allows TensorFlow ops such as `ServeSlices` or
  // `ExternalDataset` to refer to callbacks and other session-global objects
  // registered before running the session.
  string session_token_tensor_name = 7;
}
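
// Illustrative only: a minimal CheckpointOp sketch in proto text format,
// showing the restore-side use with one secure-aggregation side channel. The
// op names, tensor names, and modulus below are hypothetical placeholders,
// not values required by the framework.
//
//   before_restore_op: "restore/before"
//   after_restore_op: "restore/after"
//   side_channel_tensors {
//     key: "update/sum"
//     value {
//       secure_aggregand {
//         dimension { size: 10 }
//         dtype: DT_INT32
//         fixed_modulus { modulus: 4294967296 }
//       }
//       restore_name: "update/sum_restored"
//     }
//   }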

message SideChannel {
  // A side channel whose variables are processed via SecureAggregation.
  // This side channel implements aggregation via sum over a set of
  // clients, so the restored tensor will be a sum of multiple clients'
  // inputs into the side channel. Hence this will restore during the
  // read_aggregate_update restore, not the per-client read_update restore.
  message SecureAggregand {
    message Dimension {
      int64 size = 1;
    }

    // Dimensions of the aggregand. This is used by the secure aggregation
    // protocol in its early rounds, not as redundant info which could be
    // obtained by reading the dimensions of the tensor itself.
    repeated Dimension dimension = 3;

    // The data type anticipated by the server-side graph.
    tensorflow.DataType dtype = 4;

    // SecureAggregation will compute sum modulo this modulus.
    message FixedModulus {
      uint64 modulus = 1;
    }

    // SecureAggregation will for each shard compute sum modulo m with m at
    // least (1 + shard_size * (base_modulus - 1)), then aggregate
    // shard results with non-modular addition. Here, shard_size is the number
    // of clients in the shard.
    //
    // Note that the modulus for each shard will be greater than the largest
    // possible (non-modular) sum of the inputs to that shard. That is,
    // assuming each client has input on range [0, base_modulus), the result
    // will be identical to non-modular addition (i.e. federated_sum).
    //
    // While any m >= (1 + shard_size * (base_modulus - 1)) would work, the
    // current implementation takes
    // m = 2**ceil(log_2(1 + shard_size * (base_modulus - 1))), which is the
    // smallest possible value of m that is also a power of 2. This choice is
    // made because (a) it uses the same number of bits per vector entry as any
    // valid smaller m, using the current on-the-wire encoding scheme, and (b)
    // it enables the underlying mask-generation PRNG to run in its most
    // computationally efficient mode, which can be up to 2x faster.
    message ModulusTimesShardSize {
      uint64 base_modulus = 1;
    }
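
    // As a worked example of the rounding described above (illustrative
    // numbers only): with base_modulus = 256 and shard_size = 1000 clients,
    // the minimum modulus is 1 + 1000 * 255 = 255001, and the implementation
    // rounds this up to the next power of 2, m = 2**18 = 262144.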

    oneof modulus_scheme {
      // Bitwidth of the aggregand.
      //
      // This is the bitwidth of an input value (i.e. the bitwidth that
      // quantization should target). The Secure Aggregation bitwidth (i.e.,
      // the bitwidth of the *sum* of the input values) will be a function of
      // this bitwidth and the number of participating clients, as negotiated
      // with the server when the protocol is initiated.
      //
      // Deprecated; prefer fixed_modulus instead.
      int32 quantized_input_bitwidth = 2 [deprecated = true];

      FixedModulus fixed_modulus = 5;
      ModulusTimesShardSize modulus_times_shard_size = 6;
    }

    reserved 1;
  }

  // What type of side channel is used.
  oneof type {
    SecureAggregand secure_aggregand = 1;
  }

  // When restoring, the name of the tensor to restore to. This is the name
  // (key) supplied in the feed_dict in the before_restore_op in order to
  // restore the tensor provided by the side channel (which will be the
  // value in the feed_dict).
  string restore_name = 2;
}

// Container for a metric used by the internal toolkit.
message Metric {
  // Name of an Op to run to read the value.
  string variable_name = 1;

  // A human-readable name for the statistic. Metric names are usually
  // camel case by convention, e.g., 'Loss', 'AbsLoss', or 'Accuracy'.
  // Must be 7-bit ASCII and under 122 characters.
  string stat_name = 2;

  // The human-readable name of another metric by which this metric should be
  // normalized, if any. If empty, this Metric should be aggregated with simple
  // summation. If not empty, the Metric is aggregated according to
  //   weighted_metric_sum = sum_i (metric_i * weight_i)
  //   weight_sum = sum_i weight_i
  //   average_metric_value = weighted_metric_sum / weight_sum
  string weight_name = 3;
}
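
// Illustrative only: a Metric in proto text format whose value is normalized
// by an example-count metric, following the weighted aggregation described
// above. The op and stat names are hypothetical.
//
//   variable_name: "metrics/loss_sum:0"
//   stat_name: "Loss"
//   weight_name: "ExampleCount"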

// Controls the format of the output metrics users receive, i.e., instructions
// for how metrics are to be output to users.
message OutputMetric {
  // Metric name.
  string name = 1;

  oneof value_source {
    // A metric representing one stat with aggregation type sum.
    SumOptions sum = 2;

    // A metric representing a ratio between metrics with aggregation
    // type average.
    AverageOptions average = 3;

    // A metric that is not aggregated by the MetricReportAggregator or
    // metrics_loader. This includes metrics like 'num_server_updates' that are
    // aggregated in TensorFlow.
    NoneOptions none = 4;

    // A metric representing one stat with aggregation type only sample.
    // Samples at most 101 clients' values.
    OnlySampleOptions only_sample = 5;
  }

  // Iff True, the metric will be plotted in the default view of the
  // task level Colab automatically.
  oneof visualization_info {
    bool auto_plot = 6 [deprecated = true];
    VisualizationSpec plot_spec = 7;
  }
}
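
// Illustrative only: an OutputMetric in proto text format that reports a
// weighted-average metric and requests a line plot against rounds. The stat
// names are hypothetical.
//
//   name: "Loss"
//   average {
//     numerator_stat_name: "LossSum"
//     denominator_stat_name: "ExampleCount"
//     average_stat_name: "Loss"
//     include_client_samples: true
//   }
//   plot_spec {
//     plot_type: LINE_PLOT
//     x_axis: "round"
//   }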

message VisualizationSpec {
  // Different allowable plot types.
  enum VisualizationType {
    NONE = 0;
    DEFAULT_PLOT_FOR_TASK_TYPE = 1;
    LINE_PLOT = 2;
    LINE_PLOT_WITH_PERCENTILES = 3;
    HISTOGRAM = 4;
  }

  // Defines the plot type to provide downstream.
  VisualizationType plot_type = 1;

  // The x-axis to provide for the given metric. Must be the name of a
  // metric or counter. Recommended x_axis options are source_round, round,
  // or time.
  string x_axis = 2;

  // Iff True, metric will be displayed on a population level dashboard.
  bool plot_on_population_dashboard = 3;
}

// A metric representing one stat with aggregation type sum.
message SumOptions {
  // Name for corresponding Metric stat_name field.
  string stat_name = 1;

  // Iff True, a cumulative sum over rounds will be provided in addition to a
  // sum per round for the value metric.
  bool include_cumulative_sum = 2;

  // Iff True, sample of at most 101 clients' values.
  // Used to calculate quantiles in downstream visualization pipeline.
  bool include_client_samples = 3;
}

// A metric representing a ratio between metrics with aggregation type average.
// Represents: numerator stat / denominator stat.
message AverageOptions {
  // Numerator stat name pointing to corresponding Metric stat_name.
  string numerator_stat_name = 1;

  // Denominator stat name pointing to corresponding Metric stat_name.
  string denominator_stat_name = 2;

  // Name for corresponding Metric stat_name that is the ratio of the
  // numerator stat / denominator stat.
  string average_stat_name = 3;

  // Iff True, sample of at most 101 clients' values.
  // Used to calculate quantiles in downstream visualization pipeline.
  bool include_client_samples = 4;
}

// A metric representing one stat with aggregation type none.
message NoneOptions {
  // Name for corresponding Metric stat_name field.
  string stat_name = 1;
}

// A metric representing one stat with aggregation type only sample.
message OnlySampleOptions {
  // Name for corresponding Metric stat_name field.
  string stat_name = 1;
}

// Represents a data set. This is used for testing.
message Dataset {
  // Represents the data set for one client.
  message ClientDataset {
    // A string identifying the client.
    string client_id = 1;

    // A list of serialized tf.Example protos.
    repeated bytes example = 2;

    // Represents a dataset whose examples are selected by an ExampleSelector.
    message SelectedExample {
      ExampleSelector selector = 1;
      repeated bytes example = 2;
    }

    // A list of (selector, dataset) pairs. Used in testing some *TFF-based
    // tasks* that require multiple datasets as client input, e.g., a TFF-based
    // personalization eval task requires each client to provide at least two
    // datasets: one for train, and the other for test.
    repeated SelectedExample selected_example = 3;
  }

  // A list of client data.
  repeated ClientDataset client_data = 1;
}

// Represents predicates over metrics - i.e., expectations. This is used in
// training/eval tests to encode metric names and values expected to be reported
// by a client execution.
message MetricTestPredicates {
  // The value must lie in [lower_bound; upper_bound]. Can also be used for
  // approximate matching (lower == value - epsilon; upper = value + epsilon).
  message Interval {
    double lower_bound = 1;
    double upper_bound = 2;
  }

  // The value must be a real value as long as the value of the weight_name
  // metric is non-zero. If the weight metric is zero, then it is acceptable for
  // the value to be non-real.
  message RealIfNonzeroWeight {
    string weight_name = 1;
  }

  message MetricCriterion {
    // Name of the metric.
    string name = 1;

    // FL training round this metric is expected to appear in.
    int32 training_round_index = 2;

    // If none of the following is set, no matching is performed; but the
    // metric is still expected to be present (with whatever value).
    oneof Criterion {
      // The reported metric must be < lt.
      float lt = 3;
      // The reported metric must be > gt.
      float gt = 4;
      // The reported metric must be <= le.
      float le = 5;
      // The reported metric must be >= ge.
      float ge = 6;
      // The reported metric must be == eq.
      float eq = 7;
      // The reported metric must lie in the interval.
      Interval interval = 8;
      // The reported metric is not NaN or +/- infinity.
      bool real = 9;
      // The reported metric is real (i.e., not NaN or +/- infinity) if the
      // value of an associated weight is not 0.
      RealIfNonzeroWeight real_if_nonzero_weight = 10;
    }
  }

  repeated MetricCriterion metric_criterion = 1;

  reserved 2;
}

// Client Phase
// ============

// A `TensorflowSpec` that is executed on the client in a single `tf.Session`.
// In federated optimization, this will correspond to one `ServerPhase`.
message ClientPhase {
  // A short CamelCase name for the ClientPhase.
  string name = 2;

  // Minimum number of clients in aggregation.
  // In secure aggregation mode this is used to configure the protocol instance
  // in a way that the server can't learn aggregated values with a number of
  // participants lower than this number.
  // Without secure aggregation, the server still respects this parameter,
  // ensuring that aggregated values never leave server RAM unless they include
  // data from (at least) the specified number of participants.
  int32 minimum_number_of_participants = 3;

  // If populated, `io_router` must be specified.
  oneof spec {
    // A functional interface for the TensorFlow logic the client should
    // perform.
    TensorflowSpec tensorflow_spec = 4 [lazy = true];
    // Spec for client plans that issue example queries and send the query
    // results directly to an aggregator with no or little additional
    // processing.
    ExampleQuerySpec example_query_spec = 9 [lazy = true];
  }

  // The specification of the inputs coming either from customer apps
  // (Local Compute) or the federated protocol (Federated Compute).
  oneof io_router {
    FederatedComputeIORouter federated_compute = 5 [lazy = true];
    LocalComputeIORouter local_compute = 6 [lazy = true];
    FederatedComputeEligibilityIORouter federated_compute_eligibility = 7
        [lazy = true];
    FederatedExampleQueryIORouter federated_example_query = 8 [lazy = true];
  }

  reserved 1;
}
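
// Illustrative only: a ClientPhase sketch in proto text format that pairs a
// TensorflowSpec with the FederatedComputeIORouter. All tensor and node names
// are hypothetical placeholders.
//
//   name: "TrainingClientPhase"
//   minimum_number_of_participants: 2
//   tensorflow_spec {
//     dataset_token_tensor_name: "dataset_token"
//     input_tensor_specs { name: "input_ckpt_path" dtype: DT_STRING }
//     input_tensor_specs { name: "output_ckpt_path" dtype: DT_STRING }
//     target_node_names: "run_training_and_save"
//   }
//   federated_compute {
//     input_filepath_tensor_name: "input_ckpt_path"
//     output_filepath_tensor_name: "output_ckpt_path"
//   }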

// TensorflowSpec message describes a single call into TensorFlow, including the
// expected input tensors that must be fed when making that call, which
// output tensors to be fetched, and any operations that have no output but must
// be run. The TensorFlow session will then use the input tensors to do some
// computation, generally reading from one or more datasets, and provide some
// outputs.
//
// Conceptually, client or server code uses this proto along with an IORouter
// to build maps of names to input tensors, vectors of output tensor names,
// and vectors of target nodes:
//
//   CreateTensorflowArguments(
//       TensorflowSpec& spec,
//       IORouter& io_router,
//       const vector<pair<string, Tensor>>* input_tensors,
//       const vector<string>* output_tensor_names,
//       const vector<string>* target_node_names);
//
// Where `input_tensors`, `output_tensor_names` and `target_node_names`
// correspond to the arguments of the TensorFlow C++ API for
// `tensorflow::Session::Run()`, and the client executes only a single
// invocation.
//
// Note: the execution engine never sees any concepts related to the federated
// protocol, e.g. input checkpoints or aggregation protocols. This is a "tensors
// in, tensors out" interface. New aggregation methods can be added without
// having to modify the execution engine / TensorflowSpec message; instead they
// should modify the IORouter messages.
//
// Note: both `input_tensor_specs` and `output_tensor_specs` are full
// `tensorflow.TensorSpecProto` messages, though TensorFlow technically
// only requires the names to feed the values into the session. The additional
// dtypes/shape information must always be included in case the runtime
// executing this TensorflowSpec wants to perform additional, optional static
// assertions. The runtimes however are free to ignore the dtype/shapes and only
// rely on the names if so desired.
//
// Assertions:
//   - all names in `input_tensor_specs`, `output_tensor_specs`, and
//     `target_node_names` must appear in the serialized GraphDef where
//     the TF execution will be invoked.
//   - `output_tensor_specs` or `target_node_names` must be non-empty, otherwise
//     there is nothing to execute in the graph.
message TensorflowSpec {
  // The name of a tensor into which a unique token for the current session
  // should be written. The corresponding tensor is a scalar string tensor and
  // is separate from `input_tensors` as there is only one.
  //
  // A session token allows TensorFlow ops such as `ServeSlices` or
  // `ExternalDataset` to refer to callbacks and other session-global objects
  // registered before running the session. In the `ExternalDataset` case, a
  // single dataset_token is valid for multiple `tf.data.Dataset` objects as
  // the token can be thought of as a handle to a dataset factory.
  string dataset_token_tensor_name = 1;

  // TensorSpecs of inputs which will be passed to TF.
  //
  // Corresponds to the `feed_dict` parameter of `tf.Session.run()` in
  // TensorFlow's Python API, excluding the dataset_token listed above.
  //
  // Assertions:
  //   - All the tensor names designated as inputs in the corresponding IORouter
  //     must be listed (otherwise the IORouter input work is unused).
  //   - All placeholders in the TF graph must be listed here, with the
  //     exception of the dataset_token which is explicitly set above (otherwise
  //     TensorFlow will fail to execute).
  repeated tensorflow.TensorSpecProto input_tensor_specs = 2;

  // TensorSpecs that should be fetched from TF after execution.
  //
  // Corresponds to the `fetches` parameter of `tf.Session.run()` in
  // TensorFlow's Python API, and the `output_tensor_names` in TensorFlow's C++
  // API.
  //
  // Assertions:
  //   - The set of tensor names here must strictly match the tensor names
  //     designated as outputs in the corresponding IORouter (if any exist).
  repeated tensorflow.TensorSpecProto output_tensor_specs = 3;

  // Node names in the graph that should be executed, but the output not
  // returned.
  //
  // Corresponds to the `fetches` parameter of `tf.Session.run()` in
  // TensorFlow's Python API, and the `target_node_names` in TensorFlow's C++
  // API.
  //
  // This is intended for use with operations that do not produce tensors, but
  // nonetheless are required to run (e.g. serializing checkpoints).
  repeated string target_node_names = 4;

  // Map of Tensor names to constant inputs.
  // Note: tensors specified via this message should not be included in
  // input_tensor_specs.
  map<string, tensorflow.TensorProto> constant_inputs = 5;
}
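
// Illustrative only: a TensorflowSpec sketch in proto text format with one fed
// input, one fetched output, and one constant input. All names are
// hypothetical placeholders.
//
//   dataset_token_tensor_name: "dataset_token"
//   input_tensor_specs { name: "context_proto" dtype: DT_STRING }
//   output_tensor_specs { name: "task_eligibility_info" dtype: DT_STRING }
//   constant_inputs {
//     key: "num_epochs"
//     value { dtype: DT_INT32 tensor_shape {} int_val: 1 }
//   }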

// ExampleQuerySpec message describes client execution that issues example
// queries and sends the query results directly to an aggregator with no or
// little additional processing.
// This message describes one or more example store queries that perform the
// client side analytics computation in C++. The corresponding output vectors
// will be converted into the expected federated protocol output format.
// This must be used in conjunction with the `FederatedExampleQueryIORouter`.
message ExampleQuerySpec {
  message OutputVectorSpec {
    // The output vector name.
    string vector_name = 1;

    // Supported data types for the vector of information.
    enum DataType {
      UNSPECIFIED = 0;
      INT32 = 1;
      INT64 = 2;
      BOOL = 3;
      FLOAT = 4;
      DOUBLE = 5;
      BYTES = 6;
      STRING = 7;
    }

    // The data type for each entry in the vector.
    DataType data_type = 2;
  }

  message ExampleQuery {
    // The `ExampleSelector` to issue the query with.
    ExampleSelector example_selector = 1;

    // Indicates that the query returns vector data and must return a single
    // ExampleQueryResult result containing a VectorData entry matching each
    // OutputVectorSpec.vector_name.
    //
    // If the query instead returns no result, then it will be treated as if
    // an error was returned. In that case, or if the query explicitly returns
    // an error, then the client will abort its session.
    //
    // The keys in the map are the names the vectors should be aggregated under,
    // and must match the keys in FederatedExampleQueryIORouter.aggregations.
    map<string, OutputVectorSpec> output_vector_specs = 2;
  }

  // The queries to run.
  repeated ExampleQuery example_queries = 1;
}
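
// Illustrative only: an ExampleQuerySpec sketch in proto text format with a
// single query producing one INT64 output vector. The collection URI and
// vector name are hypothetical; the map key must match the corresponding
// FederatedExampleQueryIORouter.aggregations key.
//
//   example_queries {
//     example_selector {
//       collection_uri: "app:/analytics/events"
//     }
//     output_vector_specs {
//       key: "event_counts"
//       value { vector_name: "event_counts" data_type: INT64 }
//     }
//   }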

// The input and output router for Federated Compute plans.
//
// This proto is the glue between the federated protocol and the TensorFlow
// execution engine. This message describes how to prepare data coming from the
// incoming `CheckinResponse` (defined in
// fcp/protos/federated_api.proto) for the `TensorflowSpec`, and what
// to do with outputs from `TensorflowSpec` (e.g. how to aggregate them back on
// the server).
//
// TODO(team) we could replace `input_checkpoint_file_tensor_name` with
// an `input_tensors` field, which would then be a tensor that contains the
// input TensorProtos directly, skipping disk I/O, rather than referring to a
// checkpoint file path.
message FederatedComputeIORouter {
  // ===========================================================================
  // Inputs
  // ===========================================================================
  // The name of the scalar string tensor that is fed the file path to the
  // initial checkpoint (e.g. as provided via AcceptanceInfo.init_checkpoint).
  //
  // The federated protocol code would copy the `CheckinResponse`'s initial
  // checkpoint to a temporary file and then pass that file path through this
  // tensor.
  //
  // Ops may be added to the client graph that take this tensor as input and
  // read the path.
  //
  // This field is optional. It may be omitted if the client graph does not use
  // an initial checkpoint.
  string input_filepath_tensor_name = 1;

  // The name of the scalar string tensor that is fed the file path to which
  // client work should serialize the bytes to send back to the server.
  //
  // The federated protocol code generates a temporary file and passes the file
  // path through this tensor.
  //
  // Ops may be added to the client graph that use this tensor as an argument
  // to write files (e.g. writing checkpoints to disk).
  //
  // This field is optional. It must be omitted if the client graph does not
  // generate any output files (e.g. when all output tensors of `TensorflowSpec`
  // use Secure Aggregation). If this field is not set, then the `ReportRequest`
  // message in the federated protocol will not have the
  // `Report.update_checkpoint` field set. This absence of a value here can be
  // used to validate that the plan only uses Secure Aggregation.
  //
  // Conversely, if this field is set but executing the associated
  // TensorflowSpec does not write to the path, that is an indication of an
  // internal framework error. The runtime should notify the caller that the
  // computation was set up incorrectly.
  string output_filepath_tensor_name = 2;

  // ===========================================================================
  // Outputs
  // ===========================================================================
  // Describes which output tensors should be aggregated using an aggregation
  // protocol, and the configuration for those protocols.
  //
  // Assertions:
  //   - All keys must exist in the associated `TensorflowSpec` as
  //     `output_tensor_specs.name` values.
  map<string, AggregationConfig> aggregations = 3;
}
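
// Illustrative only: a FederatedComputeIORouter sketch in proto text format
// where one output tensor is summed via Secure Aggregation and everything else
// goes into the output checkpoint. Tensor names are hypothetical placeholders.
//
//   input_filepath_tensor_name: "input_ckpt_path"
//   output_filepath_tensor_name: "output_ckpt_path"
//   aggregations {
//     key: "secagg_update"
//     value { secure_aggregation {} }
//   }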

// The input and output router for client plans that do not use TensorFlow.
//
// This proto is the glue between the federated protocol and the example query
// execution engine, describing how the query results should ultimately be
// aggregated.
message FederatedExampleQueryIORouter {
  // Describes how each output vector should be aggregated using an aggregation
  // protocol, and the configuration for those protocols.
  // Keys must match the keys in ExampleQuerySpec.output_vector_specs.
  // Note that currently only the TFV1CheckpointAggregation config is supported.
  map<string, AggregationConfig> aggregations = 1;
}

// The specification for how to aggregate the associated tensor across clients
// on the server.
message AggregationConfig {
  oneof protocol_config {
    // Indicates that the given output tensor should be processed using Secure
    // Aggregation, using the specified config options.
    SecureAggregationConfig secure_aggregation = 2;

    // Note: in the future we could add a `SimpleAggregationConfig` to add
    // support for simple aggregation without writing to an intermediate
    // checkpoint file first.

    // Indicates that the given output tensor or vector (e.g. as produced by an
    // ExampleQuerySpec) should be placed in an output TF v1 checkpoint.
    //
    // Currently only ExampleQuerySpec output vectors are supported by this
    // aggregation type (i.e. it cannot be used with TensorflowSpec output
    // tensors). The vectors will be stored in the checkpoint as a 1-D Tensor of
    // its corresponding data type.
    TFV1CheckpointAggregation tf_v1_checkpoint_aggregation = 3;
  }
}

// Parameters for the SecAgg protocol (go/secagg).
//
// Currently only the server uses the SecAgg parameters, so we only use this
// message to signify usage of SecAgg.
message SecureAggregationConfig {}

// Parameters for the TFV1 Checkpoint Aggregation protocol.
//
// Currently only ExampleQuerySpec output vectors are supported by this
// aggregation type (i.e. it cannot be used with TensorflowSpec output
// tensors). The vectors will be stored in the checkpoint as a 1-D Tensor of
// its corresponding data type.
message TFV1CheckpointAggregation {}

// The input and output router for eligibility-computing plans. These plans
// compute which other plans a client is eligible to run, and are returned to
// clients via an `EligibilityEvalCheckinResponse` (defined in
// fcp/protos/federated_api.proto).
message FederatedComputeEligibilityIORouter {
  // The name of the scalar string tensor that is fed the file path to the
  // initial checkpoint (e.g. as provided via
  // `EligibilityEvalPayload.init_checkpoint`).
  //
  // For more detail see the
  // `FederatedComputeIORouter.input_filepath_tensor_name`, which has the same
  // semantics.
  //
  // This field is optional. It may be omitted if the client graph does not use
  // an initial checkpoint.
  //
  // This tensor name must exist in the associated
  // `TensorflowSpec.input_tensor_specs` list.
  string input_filepath_tensor_name = 1;

  // Name of the output tensor (a string scalar) containing the serialized
  // `google.internal.federatedml.v2.TaskEligibilityInfo` proto output. The
  // client code will parse this proto and place it in the
  // `task_eligibility_info` field of the subsequent `CheckinRequest`.
  //
  // This tensor name must exist in the associated
  // `TensorflowSpec.output_tensor_specs` list.
  string task_eligibility_info_tensor_name = 2;
}

// The input and output router for Local Compute plans.
//
// This proto is the glue between the customer's app and the TensorFlow
// execution engine. This message describes how to prepare data coming from the
// customer app (e.g. the input directory the app set up), and the temporary,
// scratch output directory that the customer app will be notified of upon
// completion of `TensorflowSpec`.
message LocalComputeIORouter {
  // ===========================================================================
  // Inputs
  // ===========================================================================
  // The name of the placeholder tensor representing the input resource path(s).
  // It can be a single input directory or file path (in this case the
  // `input_dir_tensor_name` is populated) or multiple input resources
  // represented as a map from names to input directories or file paths (in this
  // case the `multiple_input_resources` is populated).
  //
  // In the multiple input resources case, the placeholder tensors are
  // represented as a map: the keys are the input resource names defined by the
  // users when constructing the `LocalComputation` Python object, and the
  // values are the corresponding placeholder tensor names created by the local
  // computation plan builder.
  //
  // Apps will have the ability to create contracts between their Android code
  // and `LocalComputation` toolkit code to place files inside the input
  // resource paths with known names (Android code) and create graphs with ops
  // to read from these paths (file names can be specified in toolkit code).
  oneof input_resource {
    string input_dir_tensor_name = 1;
    // Directly using the `map` field is not allowed in `oneof`, so we have to
    // wrap it in a new message.
    MultipleInputResources multiple_input_resources = 3;
  }

  // Scalar string tensor name that will contain the output directory path.
  //
  // The provided directory should be considered temporary scratch that will be
  // deleted, not persisted. It is the responsibility of the calling app to
  // move the desired files to a permanent location once the client returns this
  // directory back to the calling app.
  string output_dir_tensor_name = 2;

  // ===========================================================================
  // Outputs
  // ===========================================================================
  // NOTE: LocalCompute has no outputs other than what the client graph writes
  // to `output_dir` specified above.
}

// Describes the multiple input resources in `LocalComputeIORouter`.
message MultipleInputResources {
  // The keys are the input resource names (defined by the users when
  // constructing the `LocalComputation` Python object), and the values are the
  // corresponding placeholder tensor names created by the local computation
  // plan builder.
  map<string, string> input_resource_tensor_name_map = 1;
}

// Describes a queue to which input is fed.
message AsyncInputFeed {
  // The op for enqueuing an example input.
  string enqueue_op = 1;

  // The input placeholders for the enqueue op.
  repeated string enqueue_params = 2;

  // The op for closing the input queue.
  string close_op = 3;

  // Whether the work that should be fed asynchronously is the data itself
  // or a description of where that data lives.
  bool feed_values_are_data = 4;
}

message DatasetInput {
  // Initializer of iterator corresponding to tf.data.Dataset object which
  // handles the input data. Stores name of an op in the graph.
  string initializer = 1;

  // Placeholders necessary to initialize the dataset.
  DatasetInputPlaceholders placeholders = 2;

  // Batch size to be used in tf.data.Dataset.
  int32 batch_size = 3;
}

message DatasetInputPlaceholders {
  // Name of placeholder corresponding to filename(s) of SSTable(s) to read data
  // from.
  string filename = 1;

  // Name of placeholder corresponding to key_prefix initializing the
  // SSTableDataset. Note the value fed should be a unique user id, not a
  // prefix.
  string key_prefix = 2;

  // Name of placeholder corresponding to number of rounds the local training
  // should be run for.
  string num_epochs = 3;

  // Name of placeholder corresponding to batch size.
  string batch_size = 4;
}

// Specifies an example selection procedure.
message ExampleSelector {
  // Selection criteria following a contract agreed upon between client and
  // model designers.
  google.protobuf.Any criteria = 1;

  // A URI identifying the example collection to read from. Format should adhere
  // to "${COLLECTION}://${APP_NAME}${COLLECTION_NAME}". The URI segments
  // should adhere to the following rules:
  // - The scheme ${COLLECTION} should be one of:
  //   - "app" for app-hosted examples
  //   - "simulation" for collections not connected to an app (e.g., if used
  //     purely for simulation)
  // - The authority ${APP_NAME} identifies the owner of the example
  //   collection and should be either the app's package name, or be left empty
  //   (which means "the current app package name").
  // - The path ${COLLECTION_NAME} can be any valid URI path. NB It starts with
  //   a forward slash ("/").
  // - The query and fragment are currently not used, but they may become used
  //   for something in the future. To keep open that possibility they must
  //   currently be left empty.
  //
  // Example: "app://com.google.some.app/someCollection/name"
  // identifies the collection "/someCollection/name" owned and hosted by the
  // app with package name "com.google.some.app".
  //
  // Example: "app:/someCollection/name" or "app:///someCollection/name"
  // both identify the collection "/someCollection/name" owned and hosted by the
  // app associated with the training job in which this URI appears.
  //
  // The path will not be interpreted by the runtime, and will be passed to the
  // example collection implementation for interpretation. Thus, in the case of
  // app-hosted example stores, the path segment's interpretation is a contract
  // between the app's example store developers, and the app's model designers.
  //
  // If an `app://` URI is set, then the `TrainerOptions` collection name must
  // not be set.
  string collection_uri = 2;

  // Resumption token following a contract agreed upon between client and
  // model designers.
  google.protobuf.Any resumption_token = 3;
}

// Selector for slices to fetch as part of a `federated_select` operation.
message SlicesSelector {
  // The string ID under which the slices are served.
  //
  // This value must have been returned by a previous call to the `serve_slices`
  // op run during the `write_client_init` operation.
  string served_at_id = 1;

  // The indices of slices to fetch.
  repeated int32 keys = 2;
}

// Represents slice data to be served as part of a `federated_select` operation.
// This is used for testing.
message SlicesTestDataset {
  // The test data to use. The keys map to the `SlicesSelector.served_at_id`
  // field. E.g. test slice data for a slice with `served_at_id`="foo" and
  // `keys`=2 would be stored in `dataset["foo"].slice_data[2]`.
  map<string, SlicesTestData> dataset = 1;
}

message SlicesTestData {
  // The test slice data to serve. Each entry's index corresponds to the slice
  // key it is the test data for.
  repeated bytes slice_data = 2;
}
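
// Illustrative only: the "foo" example from the SlicesTestDataset comment
// above, written out in proto text format. dataset["foo"].slice_data[2] is the
// third serialized slice; the byte contents are hypothetical.
//
//   dataset {
//     key: "foo"
//     value {
//       slice_data: "slice-0"
//       slice_data: "slice-1"
//       slice_data: "slice-2"
//     }
//   }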

// Server Phase V2
// ===============

// Represents a server phase with three distinct components: pre-broadcast,
// aggregation, and post-aggregation.
//
// The pre-broadcast and post-aggregation components are described with
// the tensorflow_spec_prepare and tensorflow_spec_result TensorflowSpec
// messages, respectively. These messages in combination with the server
// IORouter messages specify how to set up a single TF sess.run call for each
// component.
//
// The pre-broadcast logic is obtained by transforming the server_prepare TFF
// computation in the DistributeAggregateForm. It takes the server state as
// input, and it generates the checkpoint to broadcast to the clients and
// potentially an intermediate server state. The intermediate server state may
// be used by the aggregation and post-aggregation logic.
//
// The aggregation logic represents the aggregation of client results at the
// server and is described using a list of ServerAggregationConfig messages.
// Each ServerAggregationConfig message describes a single aggregation operation
// on a set of input/output tensors. The input tensors may represent parts of
// either the client results or the intermediate server state. These messages
// are obtained by transforming the client_to_server_aggregation TFF computation
// in the DistributeAggregateForm.
//
// The post-aggregation logic is obtained by transforming the server_result TFF
// computation in the DistributeAggregateForm. It takes the intermediate server
// state and the aggregated client results as input, and it generates the new
// server state and potentially other server-side output.
//
// Note that while a ServerPhaseV2 message can be generated for all types of
// intrinsics, it is currently only compatible with the ClientPhase message if
// the aggregations being used are exclusively federated_sum (not SecAgg). If
// this compatibility requirement is satisfied, it is also valid to run the
// aggregation portion of this ServerPhaseV2 message alongside the pre- and
// post-aggregation logic from the original ServerPhase message. Ultimately,
// we expect the full ServerPhaseV2 message to be run and the ServerPhase
// message to be deprecated.
message ServerPhaseV2 {
  // A short CamelCase name for the ServerPhaseV2.
  string name = 1;

  // A functional interface for the TensorFlow logic the server should perform
  // prior to the server-to-client broadcast. This should be used with the
  // TensorFlow graph defined in server_graph_prepare_bytes.
  TensorflowSpec tensorflow_spec_prepare = 3;

  // The specification of inputs needed by the server_prepare TF logic.
  oneof server_prepare_io_router {
    ServerPrepareIORouter prepare_router = 4;
  }

  // A list of client-to-server aggregations to perform.
  repeated ServerAggregationConfig aggregations = 2;

  // A functional interface for the TensorFlow logic the server should perform
  // post-aggregation. This should be used with the TensorFlow graph defined
  // in server_graph_result_bytes.
  TensorflowSpec tensorflow_spec_result = 5;

  // The specification of inputs and outputs needed by the server_result TF
  // logic.
  oneof server_result_io_router {
    ServerResultIORouter result_router = 6;
  }
}

// Routing for server_prepare graph
message ServerPrepareIORouter {
  // The name of the scalar string tensor in the server_prepare TF graph that
  // is fed the filepath to the initial server state checkpoint. The
  // server_prepare logic reads from this filepath.
  string prepare_server_state_input_filepath_tensor_name = 1;

  // The name of the scalar string tensor in the server_prepare TF graph that
  // is fed the filepath where the client checkpoint should be stored. The
  // server_prepare logic writes to this filepath.
  string prepare_output_filepath_tensor_name = 2;

  // The name of the scalar string tensor in the server_prepare TF graph that
  // is fed the filepath where the intermediate state checkpoint should be
  // stored. The server_prepare logic writes to this filepath. The intermediate
  // state checkpoint will be consumed by both the logic used to set parameters
  // for aggregation and the post-aggregation logic.
  string prepare_intermediate_state_output_filepath_tensor_name = 3;
}

// Routing for server_result graph
message ServerResultIORouter {
  // The name of the scalar string tensor in the server_result TF graph that is
  // fed the filepath to the intermediate state checkpoint. The server_result
  // logic reads from this filepath.
  string result_intermediate_state_input_filepath_tensor_name = 1;

  // The name of the scalar string tensor in the server_result TF graph that is
  // fed the filepath to the aggregated client result checkpoint. The
  // server_result logic reads from this filepath.
  string result_aggregate_result_input_filepath_tensor_name = 2;

  // The name of the scalar string tensor in the server_result TF graph that is
  // fed the filepath where the updated server state should be stored. The
  // server_result logic writes to this filepath.
  string result_server_state_output_filepath_tensor_name = 3;
}

// Represents a single aggregation operation, combining one or more input
// tensors from a collection of clients into one or more output tensors on the
// server.
message ServerAggregationConfig {
  // The uri of the aggregation intrinsic (e.g. 'federated_sum').
  string intrinsic_uri = 1;

  // Describes an argument to the aggregation operation.
  message IntrinsicArg {
    oneof arg {
      // Refers to a tensor within the checkpoint provided by each client.
      tensorflow.TensorSpecProto input_tensor = 2;

      // Refers to a tensor within the intermediate server state checkpoint.
      tensorflow.TensorSpecProto state_tensor = 3;
    }
  }

  // List of arguments for the aggregation operation. The arguments can be
  // dependent on client data (in which case they must be retrieved from
  // clients) or they can be independent of client data (in which case they
  // can be configured server-side). For now we assume all client-independent
  // arguments are constants. The arguments must be in the order expected by
  // the server.
  repeated IntrinsicArg intrinsic_args = 4;

  // List of server-side outputs produced by the aggregation operation.
  repeated tensorflow.TensorSpecProto output_tensors = 5;
}
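
// Illustrative only: a ServerAggregationConfig sketch in proto text format for
// a federated_sum over a single client tensor. The tensor names are
// hypothetical placeholders.
//
//   intrinsic_uri: "federated_sum"
//   intrinsic_args {
//     input_tensor { name: "update/weights" dtype: DT_FLOAT }
//   }
//   output_tensors { name: "aggregated/weights" dtype: DT_FLOAT }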
|
||
|
|
|
||
|
|
// Server Phase
|
||
|
|
// ============
|
||
|
|
|
||
|
|
// Represents a server phase which implements TF-based aggregation of multiple
|
||
|
|
// client updates.
|
||
|
|
//
|
||
|
|
// There are two different modes of aggregation that are described
|
||
|
|
// by the values in this message. The first is aggregation that is
|
||
|
|
// coming from coordinated sets of clients. This includes aggregation
|
||
|
|
// done via checkpoints from clients or aggregation done over a set
|
||
|
|
// of clients by a process like secure aggregation. The results of
|
||
|
|
// this first aggregation are saved to intermediate aggregation
|
||
|
|
// checkpoints. The second aggregation then comes from taking
|
||
|
|
// these intermediate checkpoints and aggregating over them.
|
||
|
|
//
|
||
|
|
// These two different modes of aggregation are done on different
|
||
|
|
// servers, the first in the 'L1' servers and the second in the
|
||
|
|
// 'L2' servers, so we use this nomenclature to describe these
|
||
|
|
// phases below.
|
||
|
|
//
|
||
|
|
// The ServerPhase message is currently in the process of being replaced by the
|
||
|
|
// ServerPhaseV2 message as we switch the plan building pipeline to use
|
||
|
|
// DistributeAggregateForm instead of MapReduceForm. During the migration
|
||
|
|
// process, we may generate both messages and use components from either
|
||
|
|
// message during execution.
|
||
|
|
//
|
||
|
|
message ServerPhase {
|
||
|
|
// A short CamelCase name for the ServerPhase.
|
||
|
|
string name = 8;
|
||
|
|
|
||
|
|
// ===========================================================================
|
||
|
|
// L1 "Intermediate" Aggregation.
|
||
|
|
//
|
||
|
|
// This is the initial aggregation that creates partial aggregates from client
|
||
|
|
// results. L1 Aggregation may be run on many different instances.
|
||
|
|
//
|
||
|
|
// Pre-condition:
|
||
|
|
// The execution environment has loaded the graph from `server_graph_bytes`.
|
||
|
|
|
||
|
|
// 1. Initialize the phase.
|
||
|
|
//
|
||
|
|
// Operation to run before the first aggregation happens.
|
||
|
|
// For instance, clears the accumulators so that a new aggregation can begin.
|
||
|
|
string phase_init_op = 1;
|
||
|
|
|
||
|
|
// 2. For each client in set of clients:
|
||
|
|
// a. Restore variables from the client checkpoint.
|
||
|
|
//
|
||
|
|
// Loads a checkpoint from a single client written via
|
||
|
|
// `FederatedComputeIORouter.output_filepath_tensor_name`. This is done once
|
||
|
|
// for every client checkpoint in a round.
|
||
|
|
CheckpointOp read_update = 3;
|
||
|
|
// b. Aggregate the data coming from the client checkpoint.
|
||
|
|
//
|
||
|
|
// An operation that aggregates the data from read_update.
|
||
|
|
// Generally this will add to accumulators and it may leverage internal data
|
||
|
|
// inside the graph to adjust the weights of the Tensors.
|
||
|
|
//
|
||
|
|
// Executed once for each `read_update`, to (for example) update accumulator
|
||
|
|
// variables using the values loaded during `read_update`.
|
||
|
|
string aggregate_into_accumulators_op = 4;
|
||
|
|
|
||
|
|
// 3. After all clients have been aggregated, possibly restore
|
||
|
|
// variables that have been aggregated via a separate process.
|
||
|
|
//
|
||
|
|
// Optionally restores variables where aggregation is done across
|
||
|
|
// an entire round of client data updates. In contrast to `read_update`,
|
||
|
|
// which restores once per client, this occurs after all clients
|
||
|
|
// in a round have been processed. This allows, for example, side
|
||
|
|
// channels where aggregation is done by a separate process (such
|
||
|
|
// as in secure aggregation), in which the side channel aggregated
|
||
|
|
// tensor is passed to the `before_restore_op` which ensure the
|
||
|
|
// variables are restored properly. The `after_restore_op` will then
|
||
|
|
// be responsible for performing the accumulation.
|
||
|
|
//
|
||
|
|
// Note that in current use this should not have a SaverDef, but
|
||
|
|
// should only be used for side channels.
|
||
|
|
CheckpointOp read_aggregated_update = 10;
|
||
|
|
|
||
|
|
// 4. Write the aggregated variables to an intermediate checkpoint.
|
||
|
|
//
|
||
|
|
// We require that `aggregate_into_accumulators_op` is associative and
|
||
|
|
// commutative, so that the aggregates can be computed across
|
||
|
|
// multiple TensorFlow sessions.
|
||
|
|
// As an example, say we are computing the sum of 5 client updates:
|
||
|
|
// A = X1 + X2 + X3 + X4 + X5
|
||
|
|
// We can always do this in one session by calling `read_update`j and
|
||
|
|
// `aggregate_into_accumulators_op` once for each client checkpoint.
|
||
|
|
//
|
||
|
|
// Alternatively, we could compute:
|
||
|
|
// A1 = X1 + X2 in one TensorFlow session, and
|
||
|
|
// A2 = X3 + X4 + X5 in a different session.
|
||
|
|
// Each of these sessions can then write their accumulator state
|
||
|
|
// with the `write_intermediate_update` CheckpointOp, and a yet another third
|
||
|
|
// session can then call `read_intermediate_update` and
|
||
|
|
// `aggregate_into_accumulators_op` on each of these checkpoints to compute:
|
||
|
|
// A = A1 + A2 = (X1 + X2) + (X3 + X4 + X5).
|
||
|
|
CheckpointOp write_intermediate_update = 7;

  // End L1 "Intermediate" Aggregation.
  // ===========================================================================

  // ===========================================================================
  // L2 Aggregation and Coordinator.
  //
  // This aggregates intermediate checkpoints from L1 Aggregation and performs
  // the finalizing of the update. Unlike L1, there will only be one instance
  // that does this aggregation.

  // Pre-condition:
  // The execution environment has loaded the graph from `server_graph_bytes`
  // and restored the global model using `server_savepoint` from the parent
  // `Plan` message.

  // 1. Initialize the phase.
  //
  // This currently re-uses the `phase_init_op` from L1 aggregation above.

  // 2. Write a checkpoint that can be sent to the client.
  //
  // Generates a checkpoint to be sent to the client, to be read by
  // `FederatedComputeIORouter.input_filepath_tensor_name`.
  CheckpointOp write_client_init = 2;

  // 3. For each intermediate checkpoint:
  // a. Restore variables from the intermediate checkpoint.
  //
  // The corresponding read checkpoint op to the write_intermediate_update.
  // This is used instead of read_update for intermediate checkpoints because
  // the format of these updates may differ from the format of updates coming
  // from clients (which may, for example, be compressed).
  CheckpointOp read_intermediate_update = 9;

  // b. Aggregate the data coming from the intermediate checkpoint.
  //
  // An operation that aggregates the data from `read_intermediate_update`.
  // Generally this will add to accumulators, and it may leverage internal data
  // inside the graph to adjust the weights of the Tensors.
  string intermediate_aggregate_into_accumulators_op = 11;

  // 4. Write the aggregated intermediate variables to a checkpoint.
  //
  // This is used for downstream, cross-round aggregation of metrics.
  // These variables will be read back into a session with
  // read_intermediate_update.
  //
  // Tasks which do not use FL metrics may unset the CheckpointOp.saver_def
  // to disable writing accumulator checkpoints.
  CheckpointOp write_accumulators = 12;

  // 5. Finalize the round.
  //
  // This can include:
  // - Applying the update aggregated from the intermediate checkpoints to the
  //   global model and other updates to cross-round state variables.
  // - Computing final round metric values (e.g. the `report` of a
  //   `tff.federated_aggregate`).
  string apply_aggregrated_updates_op = 5;

  // 6. Fetch the server aggregated metrics.
  //
  // A list of names of metric variables to fetch from the TensorFlow session.
  repeated Metric metrics = 6;

  // 7. Serialize the updated server state (e.g. the coefficients of the global
  //    model in FL) using `server_savepoint` in the parent `Plan` message.

  // End L2 Aggregation.
  // ===========================================================================
}
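
// A minimal sketch (in Python, using the TF1 session API) of how a
// coordinator might drive the L2 steps above. `run_restore` and `run_save`
// are hypothetical helpers that execute a CheckpointOp's before/after ops and
// its saver_def; fetching metrics via `Metric.variable_name` and the
// checkpoint-path variables are illustrative assumptions:
//
//   sess.run(server_phase.phase_init_op)                           # step 1
//   run_save(sess, server_phase.write_client_init, client_ckpt)    # step 2
//   for ckpt in intermediate_checkpoints:                          # step 3
//     run_restore(sess, server_phase.read_intermediate_update, ckpt)
//     sess.run(server_phase.intermediate_aggregate_into_accumulators_op)
//   run_save(sess, server_phase.write_accumulators, acc_ckpt)      # step 4
//   sess.run(server_phase.apply_aggregrated_updates_op)            # step 5
//   metric_values = sess.run(                                      # step 6
//       [m.variable_name for m in server_phase.metrics])
//   run_save(sess, plan.server_savepoint, global_model_ckpt)       # step 7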

// Represents the server phase in an eligibility computation.
//
// This phase produces a checkpoint to be sent to clients. This checkpoint is
// then used as an input to the clients' task eligibility computations.
// This phase *does not include any aggregation.*
message ServerEligibilityComputationPhase {
  // A short CamelCase name for the ServerEligibilityComputationPhase.
  string name = 1;

  // The names of the TensorFlow nodes to run in order to produce output.
  repeated string target_node_names = 2;

  // The specification of inputs and outputs to the TensorFlow graph.
  oneof server_eligibility_io_router {
    TEContextServerEligibilityIORouter task_eligibility = 3 [lazy = true];
  }
}

// Represents the inputs and outputs of a `ServerEligibilityComputationPhase`
// which takes a single `TaskEligibilityContext` as input.
message TEContextServerEligibilityIORouter {
  // The name of the scalar string tensor that must be fed a serialized
  // `TaskEligibilityContext`.
  string context_proto_input_tensor_name = 1;

  // The name of the scalar string tensor that must be fed the path to which
  // the server graph should write the checkpoint file to be sent to the
  // client.
  string output_filepath_tensor_name = 2;
}
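
// A minimal sketch (in Python, using the TF1 session API) of feeding this
// router when running a ServerEligibilityComputationPhase; the variable
// names and the output path are illustrative:
//
//   router = eligibility_phase.task_eligibility
//   sess.run(
//       list(eligibility_phase.target_node_names),
//       feed_dict={
//           router.context_proto_input_tensor_name:
//               task_eligibility_context.SerializeToString(),
//           router.output_filepath_tensor_name: "/tmp/eligibility_init.ckpt",
//       })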

// Plan
// =====

// Represents the overall plan for performing federated optimization or
// personalization, as handed over to the production system. This will
// typically be split into individual pieces for different production
// parts, e.g. server and client side.
// NEXT_TAG: 15
message Plan {
  reserved 1, 3, 5;

  // The actual type of the server_*_graph_bytes fields below is expected to be
  // tensorflow.GraphDef. The TensorFlow graphs are stored in serialized form
  // for two reasons.
  // 1) We may use execution engines other than TensorFlow.
  // 2) We wish to avoid the cost of deserializing and re-serializing large
  //    graphs in the Federated Learning service.

  // While we migrate from ServerPhase to ServerPhaseV2, server_graph_bytes,
  // server_graph_prepare_bytes, and server_graph_result_bytes may all be set.
  // If we're using a MapReduceForm-based server implementation, only
  // server_graph_bytes will be used. If we're using a DistributeAggregateForm-
  // based server implementation, only server_graph_prepare_bytes and
  // server_graph_result_bytes will be used.

  // Optional. The TensorFlow graph used for all server processing described by
  // ServerPhase. For personalization, this will not be set.
  google.protobuf.Any server_graph_bytes = 7;

  // Optional. The TensorFlow graph used for all server processing described by
  // ServerPhaseV2.tensorflow_spec_prepare.
  google.protobuf.Any server_graph_prepare_bytes = 13;

  // Optional. The TensorFlow graph used for all server processing described by
  // ServerPhaseV2.tensorflow_spec_result.
  google.protobuf.Any server_graph_result_bytes = 14;
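
  // A minimal sketch (Python) of unpacking one of these graph fields; which
  // field to unpack depends on the server implementation, as described above:
  //
  //   graph_def = tf.compat.v1.GraphDef()
  //   plan.server_graph_bytes.Unpack(graph_def)  # or server_graph_prepare_bytes
  //                                              # / server_graph_result_bytes
  //   tf.compat.v1.import_graph_def(graph_def, name="")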

  // A savepoint to sync the server checkpoint with a persistent
  // storage system. The storage initially holds a seeded checkpoint
  // which can subsequently be read and updated by this savepoint.
  // Optional -- not present in eligibility computation plans (those with a
  // ServerEligibilityComputationPhase). This is used in conjunction with
  // ServerPhase only.
  CheckpointOp server_savepoint = 2;

  // Required. The TensorFlow graph that describes the TensorFlow logic a
  // client should perform. It should be consistent with the `TensorflowSpec`
  // field in the `client_phase`. The actual type is expected to be
  // tensorflow.GraphDef. The TensorFlow graph is stored in serialized form for
  // two reasons.
  // 1) We may use execution engines other than TensorFlow.
  // 2) We wish to avoid the cost of deserializing and re-serializing large
  //    graphs in the Federated Learning service.
  google.protobuf.Any client_graph_bytes = 8;

  // Optional. The FlatBuffer used for TFLite training.
  // It contains the same model information as the client_graph_bytes, but with
  // a different format.
  bytes client_tflite_graph_bytes = 12;

  // A pair of client phase and server phase which are processed in
  // sync. The server execution defines how the results of a client
  // phase are aggregated, and how the checkpoints for clients are
  // generated.
  message Phase {
    // Required. The client phase.
    ClientPhase client_phase = 1;

    // Optional. Server phase for TF-based aggregation; not provided for
    // personalization or eligibility tasks.
    ServerPhase server_phase = 2;

    // Optional. Server phase for native aggregation; only provided for tasks
    // that have enabled the corresponding flag.
    ServerPhaseV2 server_phase_v2 = 4;

    // Optional. Only provided for eligibility tasks.
    ServerEligibilityComputationPhase server_eligibility_phase = 3;
  }

  // A pair of client and server computations to run.
  repeated Phase phase = 4;

  // Metrics that are persistent across different phases. This
  // includes, for example, counters that track how much work of
  // different kinds has been done.
  repeated Metric metrics = 6;

  // Describes how metrics in both the client and server phases should be
  // aggregated.
  repeated OutputMetric output_metrics = 10;

  // Version of the plan:
  // version == 0 - Old plan without version field, containing b/65131070
  // version >= 1 - plan supports multi-shard aggregation mode (L1/L2)
  int32 version = 9;

  // A TensorFlow ConfigProto packed in an Any.
  //
  // If this field is unset, if the Any proto is set but empty, or if the Any
  // proto is populated with an empty ConfigProto (i.e. its `type_url` field is
  // set, but the `value` field is empty), then the client implementation may
  // choose a set of configuration parameters to provide to TensorFlow by
  // default.
  //
  // In all other cases this field must contain a valid packed ConfigProto
  // (invalid values will result in an error at execution time), and in this
  // case the client will not provide any other configuration parameters by
  // default.
  google.protobuf.Any tensorflow_config_proto = 11;
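
  // A minimal sketch (Python) of populating this field; the specific
  // ConfigProto option shown is illustrative:
  //
  //   config = tf.compat.v1.ConfigProto()
  //   config.intra_op_parallelism_threads = 1
  //   plan.tensorflow_config_proto.Pack(config)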
}

// Represents a client part of the plan of federated optimization.
// This is also used to describe a client-only plan for standalone on-device
// training, known as personalization.
// NEXT_TAG: 6
message ClientOnlyPlan {
  reserved 3;

  // The graph to use for training, in binary form.
  bytes graph = 1;

  // Optional. The flatbuffer used for TFLite training.
  // Whether "graph" or "tflite_graph" is used for training is up to the client
  // code, to allow for a flag-controlled a/b rollout.
  bytes tflite_graph = 5;

  // The client phase to execute.
  ClientPhase phase = 2;

  // A TensorFlow ConfigProto.
  google.protobuf.Any tensorflow_config_proto = 4;
}
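
// A minimal sketch (Python) of how client code might load this message;
// `plan_pb2` stands in for the generated module of this file, and the file
// path and `use_tflite_flag` feature flag are illustrative:
//
//   client_only_plan = plan_pb2.ClientOnlyPlan()
//   with open("/data/client_only_plan.pb", "rb") as f:
//     client_only_plan.ParseFromString(f.read())
//   use_tflite = use_tflite_flag and bool(client_only_plan.tflite_graph)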

// Represents the cross-round aggregation portion for user defined measurements.
// This is used by tools that process / analyze accumulator checkpoints
// after a round of computation, to achieve aggregation beyond a round.
message CrossRoundAggregationExecution {
  // Operation to run before reading accumulator checkpoint.
  string init_op = 1;

  // Reads accumulator checkpoint.
  CheckpointOp read_aggregated_update = 2;

  // Operation to merge loaded checkpoint into accumulator.
  string merge_op = 3;

  // Reads and writes the final aggregated accumulator vars.
  CheckpointOp read_write_final_accumulators = 6;

  // Metadata for mapping the TensorFlow `name` attribute of the `tf.Variable`
  // to the user defined name of the signal.
  repeated Measurement measurements = 4;

  // The `tf.Graph` used for aggregating accumulator checkpoints when
  // loading metrics.
  google.protobuf.Any cross_round_aggregation_graph_bytes = 5;
}
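
// A minimal sketch (in Python, using the TF1 session API) of the cross-round
// aggregation loop a tool might run; `run_restore` and `run_save` are
// hypothetical helpers that execute a CheckpointOp's before/after ops and its
// saver_def, and the checkpoint-path variables are illustrative:
//
//   sess.run(cross_round.init_op)
//   for ckpt in accumulator_checkpoints:
//     run_restore(sess, cross_round.read_aggregated_update, ckpt)
//     sess.run(cross_round.merge_op)
//   run_save(sess, cross_round.read_write_final_accumulators, final_ckpt)
//   values = sess.run([m.read_op_name for m in cross_round.measurements])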

message Measurement {
  // Name of a TensorFlow op to run to read/fetch the value of this
  // measurement.
  string read_op_name = 1;

  // A human-readable name for the measurement. Names are usually
  // camel case by convention, e.g., 'Loss', 'AbsLoss', or 'Accuracy'.
  string name = 2;

  reserved 3;

  // A serialized `tff.Type` for the measurement.
  bytes tff_type = 4;
}