224 lines
8.4 KiB
Java
224 lines
8.4 KiB
Java
|
|
/*
 * Copyright (C) 2023 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
||
|
|
|
||
|
|
package android.federatedcompute;
|
||
|
|
|
||
|
|
import static java.util.concurrent.TimeUnit.MILLISECONDS;
|
||
|
|
|
||
|
|
import android.annotation.CallbackExecutor;
|
||
|
|
import android.annotation.NonNull;
|
||
|
|
import android.annotation.Nullable;
|
||
|
|
import android.content.ComponentName;
|
||
|
|
import android.content.Context;
|
||
|
|
import android.content.Intent;
|
||
|
|
import android.content.ServiceConnection;
|
||
|
|
import android.content.pm.ResolveInfo;
|
||
|
|
import android.content.pm.ServiceInfo;
|
||
|
|
import android.federatedcompute.aidl.IFederatedComputeCallback;
|
||
|
|
import android.federatedcompute.aidl.IFederatedComputeService;
|
||
|
|
import android.federatedcompute.common.ScheduleFederatedComputeRequest;
|
||
|
|
import android.ondevicepersonalization.OnDevicePersonalizationException;
|
||
|
|
import android.os.IBinder;
|
||
|
|
import android.os.OutcomeReceiver;
|
||
|
|
import android.os.RemoteException;
|
||
|
|
import android.util.Log;
|
||
|
|
|
||
|
|
import com.android.internal.annotations.GuardedBy;
|
||
|
|
|
||
|
|
import java.util.List;
|
||
|
|
import java.util.Objects;
|
||
|
|
import java.util.concurrent.CountDownLatch;
|
||
|
|
import java.util.concurrent.Executor;
|
||
|
|
|
||
|
|
/**
|
||
|
|
* FederatedCompute Manager.
|
||
|
|
*
|
||
|
|
* @hide
|
||
|
|
*/
|
||
|
|
final class FederatedComputeManager {
|
||
|
|
private static final String TAG = "FederatedComputeManager";
|
||
|
|
private static final String FEDERATED_COMPUTATION_SERVICE_INTENT_FILTER_NAME =
|
||
|
|
"android.federatedcompute.FederatedComputeService";
|
||
|
|
private static final int BINDER_CONNECTION_TIMEOUT_MS = 5000;
|
||
|
|
|
||
|
|
// A CountDownloadLatch which will be opened when the connection is established or any error
|
||
|
|
// occurs.
|
||
|
|
private CountDownLatch mConnectionCountDownLatch;
|
||
|
|
// Concurrency mLock.
|
||
|
|
private final Object mLock = new Object();
|
||
|
|
|
||
|
|
@GuardedBy("mLock")
|
||
|
|
private IFederatedComputeService mFcpService;
|
||
|
|
|
||
|
|
@GuardedBy("mLock")
|
||
|
|
private ServiceConnection mServiceConnection;
|
||
|
|
|
||
|
|
private final Context mContext;
|
||
|
|
|
||
|
|
FederatedComputeManager(Context context) {
|
||
|
|
this.mContext = context;
|
||
|
|
}
|
||
|
|
|
||
|
|
/**
|
||
|
|
* Schedule FederatedCompute task.
|
||
|
|
*
|
||
|
|
* @hide
|
||
|
|
*/
|
||
|
|
public void scheduleFederatedCompute(
|
||
|
|
@NonNull ScheduleFederatedComputeRequest request,
|
||
|
|
@NonNull @CallbackExecutor Executor executor,
|
||
|
|
@NonNull OutcomeReceiver<Object, Exception> callback) {
|
||
|
|
Objects.requireNonNull(request);
|
||
|
|
final IFederatedComputeService service = getService(executor);
|
||
|
|
try {
|
||
|
|
IFederatedComputeCallback federatedComputeCallback =
|
||
|
|
new IFederatedComputeCallback.Stub() {
|
||
|
|
@Override
|
||
|
|
public void onSuccess() {
|
||
|
|
executor.execute(() -> callback.onResult(null));
|
||
|
|
}
|
||
|
|
|
||
|
|
@Override
|
||
|
|
public void onFailure(int errorCode) {
|
||
|
|
executor.execute(
|
||
|
|
() ->
|
||
|
|
callback.onError(
|
||
|
|
new OnDevicePersonalizationException(
|
||
|
|
errorCode)));
|
||
|
|
}
|
||
|
|
};
|
||
|
|
service.scheduleFederatedCompute(
|
||
|
|
request.getTrainingOptions(), federatedComputeCallback);
|
||
|
|
} catch (RemoteException e) {
|
||
|
|
Log.e(TAG, "Remote Exception", e);
|
||
|
|
executor.execute(() -> callback.onError(e));
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
private IFederatedComputeService getService(@NonNull Executor executor) {
|
||
|
|
synchronized (mLock) {
|
||
|
|
if (mFcpService != null) {
|
||
|
|
return mFcpService;
|
||
|
|
}
|
||
|
|
if (mServiceConnection == null) {
|
||
|
|
Intent intent = new Intent(FEDERATED_COMPUTATION_SERVICE_INTENT_FILTER_NAME);
|
||
|
|
ComponentName serviceComponent = resolveService(intent);
|
||
|
|
if (serviceComponent == null) {
|
||
|
|
Log.e(TAG, "Invalid component for federatedcompute service");
|
||
|
|
throw new IllegalStateException(
|
||
|
|
"Invalid component for federatedcompute service");
|
||
|
|
}
|
||
|
|
intent.setComponent(serviceComponent);
|
||
|
|
// This latch will open when the connection is established or any error occurs.
|
||
|
|
mConnectionCountDownLatch = new CountDownLatch(1);
|
||
|
|
mServiceConnection = new FederatedComputeServiceConnection();
|
||
|
|
boolean result =
|
||
|
|
mContext.bindService(
|
||
|
|
intent, Context.BIND_AUTO_CREATE, executor, mServiceConnection);
|
||
|
|
if (!result) {
|
||
|
|
mServiceConnection = null;
|
||
|
|
throw new IllegalStateException("Unable to bind to the service");
|
||
|
|
} else {
|
||
|
|
Log.i(TAG, "bindService() succeeded...");
|
||
|
|
}
|
||
|
|
} else {
|
||
|
|
Log.i(TAG, "bindService() already pending...");
|
||
|
|
}
|
||
|
|
try {
|
||
|
|
mConnectionCountDownLatch.await(BINDER_CONNECTION_TIMEOUT_MS, MILLISECONDS);
|
||
|
|
} catch (InterruptedException e) {
|
||
|
|
throw new IllegalStateException("Thread interrupted"); // TODO Handle it better.
|
||
|
|
}
|
||
|
|
synchronized (mLock) {
|
||
|
|
if (mFcpService == null) {
|
||
|
|
throw new IllegalStateException("Failed to connect to the service");
|
||
|
|
}
|
||
|
|
return mFcpService;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/**
|
||
|
|
* Find the ComponentName of the service, given its intent and package manager.
|
||
|
|
*
|
||
|
|
* @return ComponentName of the service. Null if the service is not found.
|
||
|
|
*/
|
||
|
|
@Nullable
|
||
|
|
private ComponentName resolveService(@NonNull Intent intent) {
|
||
|
|
List<ResolveInfo> services = mContext.getPackageManager().queryIntentServices(intent, 0);
|
||
|
|
if (services == null || services.isEmpty()) {
|
||
|
|
Log.e(TAG, "Failed to find federatedcompute service");
|
||
|
|
return null;
|
||
|
|
}
|
||
|
|
|
||
|
|
for (int i = 0; i < services.size(); i++) {
|
||
|
|
ServiceInfo serviceInfo = services.get(i).serviceInfo;
|
||
|
|
if (serviceInfo == null) {
|
||
|
|
Log.e(TAG, "Failed to find serviceInfo for federatedcompute service.");
|
||
|
|
return null;
|
||
|
|
}
|
||
|
|
// There should only be one matching service inside the given package.
|
||
|
|
// If there's more than one, return the first one found.
|
||
|
|
return new ComponentName(serviceInfo.packageName, serviceInfo.name);
|
||
|
|
}
|
||
|
|
Log.e(TAG, "Didn't find any matching federatedcompute service.");
|
||
|
|
return null;
|
||
|
|
}
|
||
|
|
|
||
|
|
public void unbindFromService() {
|
||
|
|
synchronized (mLock) {
|
||
|
|
if (mServiceConnection != null) {
|
||
|
|
Log.i(TAG, "unbinding...");
|
||
|
|
mContext.unbindService(mServiceConnection);
|
||
|
|
}
|
||
|
|
mServiceConnection = null;
|
||
|
|
mFcpService = null;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
private class FederatedComputeServiceConnection implements ServiceConnection {
|
||
|
|
@Override
|
||
|
|
public void onServiceConnected(ComponentName name, IBinder service) {
|
||
|
|
Log.d(TAG, "onServiceConnected");
|
||
|
|
synchronized (mLock) {
|
||
|
|
mFcpService = IFederatedComputeService.Stub.asInterface(service);
|
||
|
|
}
|
||
|
|
mConnectionCountDownLatch.countDown();
|
||
|
|
}
|
||
|
|
|
||
|
|
@Override
|
||
|
|
public void onServiceDisconnected(ComponentName name) {
|
||
|
|
Log.d(TAG, "onServiceDisconnected");
|
||
|
|
unbindFromService();
|
||
|
|
mConnectionCountDownLatch.countDown();
|
||
|
|
}
|
||
|
|
|
||
|
|
@Override
|
||
|
|
public void onBindingDied(ComponentName name) {
|
||
|
|
Log.e(TAG, "onBindingDied");
|
||
|
|
unbindFromService();
|
||
|
|
mConnectionCountDownLatch.countDown();
|
||
|
|
}
|
||
|
|
|
||
|
|
@Override
|
||
|
|
public void onNullBinding(ComponentName name) {
|
||
|
|
Log.e(TAG, "onNullBinding shouldn't happen.");
|
||
|
|
unbindFromService();
|
||
|
|
mConnectionCountDownLatch.countDown();
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|