85 lines
3.5 KiB
Diff
85 lines
3.5 KiB
Diff
Temporary patch to TFRT to submit a TF/TFRT cross-cutting change.
|
|
This patch will be applied only until TF's TFRT commit is automatically bumped.
|
|
|
|
---
|
|
|
|
diff --git a/backends/gpu/include/tfrt/gpu/gpu_types.h b/backends/gpu/include/tfrt/gpu/gpu_types.h
|
|
index 3d311c3..a216716 100644
|
|
--- a/backends/gpu/include/tfrt/gpu/gpu_types.h
|
|
+++ b/backends/gpu/include/tfrt/gpu/gpu_types.h
|
|
@@ -295,11 +295,7 @@
|
|
wrapper::CurrentContext current, wrapper::Stream stream,
|
|
wrapper::CclComm comm)>;
|
|
|
|
- explicit GpuCclHandle(AsyncValueRef<GpuContext> context,
|
|
- wrapper::OwningCclComm comm, int num_ranks);
|
|
- // TODO(hanbinyoon): Remove after transitioning to the above constructor.
|
|
- explicit GpuCclHandle(AsyncValueRef<GpuContext> context,
|
|
- wrapper::OwningCclComm comm);
|
|
+ GpuCclHandle(AsyncValueRef<GpuContext> context, wrapper::OwningCclComm comm);
|
|
~GpuCclHandle();
|
|
|
|
GpuCclHandle(GpuCclHandle&&) = default;
|
|
@@ -311,8 +307,6 @@
|
|
llvm::Error ExecuteCallbacks(wrapper::CurrentContext current,
|
|
wrapper::Stream stream);
|
|
|
|
- int num_ranks() const { return num_ranks_; }
|
|
-
|
|
const wrapper::OwningCclComm& operator->() const { return comm_; }
|
|
wrapper::CclComm get() const { return comm_.get(); }
|
|
wrapper::CclComm release();
|
|
@@ -322,7 +316,6 @@
|
|
private:
|
|
AsyncValueRef<GpuContext> context_;
|
|
wrapper::OwningCclComm comm_;
|
|
- int num_ranks_;
|
|
std::vector<Callback> callbacks_;
|
|
};
|
|
|
|
diff --git a/backends/gpu/lib/gpu_types.cc b/backends/gpu/lib/gpu_types.cc
|
|
index 38529bc..01e3dba 100644
|
|
--- a/backends/gpu/lib/gpu_types.cc
|
|
+++ b/backends/gpu/lib/gpu_types.cc
|
|
@@ -214,15 +214,8 @@
|
|
GpuBlasHandle::~GpuBlasHandle() = default;
|
|
|
|
GpuCclHandle::GpuCclHandle(AsyncValueRef<GpuContext> context,
|
|
- wrapper::OwningCclComm comm, int num_ranks)
|
|
- : context_(std::move(context)),
|
|
- comm_(std::move(comm)),
|
|
- num_ranks_(num_ranks) {}
|
|
-
|
|
-// TODO(hanbinyoon): Remove after transitioning to the above constructor.
|
|
-GpuCclHandle::GpuCclHandle(AsyncValueRef<GpuContext> context,
|
|
wrapper::OwningCclComm comm)
|
|
- : context_(std::move(context)), comm_(std::move(comm)), num_ranks_(0) {}
|
|
+ : context_(std::move(context)), comm_(std::move(comm)) {}
|
|
|
|
GpuCclHandle::~GpuCclHandle() = default;
|
|
|
|
diff --git a/backends/gpu/lib/kernels/ccl_kernels.cc b/backends/gpu/lib/kernels/ccl_kernels.cc
|
|
index 52ce820..9cfc1de 100644
|
|
--- a/backends/gpu/lib/kernels/ccl_kernels.cc
|
|
+++ b/backends/gpu/lib/kernels/ccl_kernels.cc
|
|
@@ -107,8 +107,6 @@
|
|
auto width = ToWidthInBytes(type);
|
|
if (!width) return width.takeError();
|
|
assert(*width != 0);
|
|
- if (input->size() != output->size() * handle->num_ranks())
|
|
- return MakeStringError("Input size must be output size times ranks.");
|
|
|
|
handle->AddCallback([input = input.ValueRef(), output = output.ValueRef(),
|
|
recvcount = output->size() / *width, type,
|
|
@@ -116,6 +114,10 @@
|
|
wrapper::CurrentContext current,
|
|
wrapper::Stream stream,
|
|
wrapper::CclComm comm) -> llvm::Error {
|
|
+ auto count = wrapper::CclCommCount(comm);
|
|
+ if (!count) return count.takeError();
|
|
+ if (input->size() != output->size() * *count)
|
|
+ return MakeStringError("Input size must be output size times ranks.");
|
|
return wrapper::CclReduceScatter(current, input->pointer(),
|
|
output->pointer(), recvcount, type, op,
|
|
comm, stream);
|