jax.distributed and XLA Coordination Walkthrough
Commit: JAX 0.4.35 (81991d8); XLA (76da730)
Use case:
import jax

# RANK is this process's index (0 or 1 here); it must be set before this call.
jax.distributed.initialize(
    coordinator_address="192.168.0.1:1234",
    num_processes=2,
    process_id=RANK,
)
jitted_function()
jax.distributed.shutdown()
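The use case above leaves RANK undefined. One way to supply it, shown here only as a minimal sketch, is to read the three arguments from environment variables set by whatever launches the processes. The variable names COORDINATOR_ADDRESS, NUM_PROCESSES, and RANK, and the helper distributed_init_kwargs, are hypothetical choices for this example; JAX does not read them itself.

```python
import os


def distributed_init_kwargs(env=None):
    """Build keyword arguments for jax.distributed.initialize().

    COORDINATOR_ADDRESS, NUM_PROCESSES, and RANK are hypothetical variable
    names chosen for this sketch; they are not part of JAX.
    """
    env = os.environ if env is None else env
    num_processes = int(env["NUM_PROCESSES"])
    process_id = int(env["RANK"])
    if not 0 <= process_id < num_processes:
        raise ValueError(f"RANK={process_id} is outside [0, {num_processes})")
    return {
        "coordinator_address": env["COORDINATOR_ADDRESS"],
        "num_processes": num_processes,
        "process_id": process_id,
    }


def main():
    import jax  # imported here so the helper above works without JAX installed

    jax.distributed.initialize(**distributed_init_kwargs())
    try:
        pass  # run jitted functions here
    finally:
        # Always disconnect from the coordination service, even on error.
        jax.distributed.shutdown()
```

Each process would call main() with its own RANK. The try/finally is a design choice of this sketch so that shutdown() still runs if the jitted work raises.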
What happens under the hood of jax.distributed.initialize()?
COORDINATION SERVICE
jax.distributed.initialize()
- Process 0 starts the service by calling xla_extension.get_distributed_runtime_service in jax/_src/distributed.py:
  self.service = xla_extension.get_distributed_runtime_service(
      coordinator_bind_address, num_processes
  )
- → get_distributed_runtime_service in xla/python/xla.cc
- → GetDistributedRuntimeService in xla/pjrt/distributed/distributed.cc
- → DistributedRuntimeService in xla/pjrt/distributed/service.cc
- → CoordinationServiceImpl in xla/pjrt/distributed/service.cc
CoordinationServiceImpl creates the following two services:
- CoordinationServiceStandaloneImpl (registered for tsl::CoordinationServiceInterface) in xla/tsl/distributed_runtime/coordination/coordination_service.cc
- tsl::GrpcCoordinationServiceImpl in xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h
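The service is created from coordinator_bind_address, not from the coordinator_address that clients dial. The sketch below assumes the bind address defaults to all interfaces on the coordinator's port; that default is an assumption about jax/_src/distributed.py and should be checked against the commit above. default_bind_address is a hypothetical helper name.

```python
def default_bind_address(coordinator_address: str) -> str:
    """Return a listen address on all interfaces, reusing the coordinator's port.

    Assumption for this sketch: the service listens on "[::]:PORT".
    """
    _host, sep, port = coordinator_address.rpartition(":")
    if not sep or not port.isdigit():
        raise ValueError(f"expected HOST:PORT, got {coordinator_address!r}")
    return "[::]:" + port


bind_address = default_bind_address("192.168.0.1:1234")
```

Binding to a wildcard address rather than the advertised host can help when the coordinator's externally visible address differs from its local interface address.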
COORDINATION CLIENT (AGENT)
jax.distributed.initialize()
- All processes start clients by calling xla_extension.get_distributed_runtime_client in jax/_src/distributed.py:
  self.client = xla_extension.get_distributed_runtime_client(
      coordinator_address, process_id, init_timeout=initialization_timeout
  )
- → get_distributed_runtime_client in xla/python/xla.cc
- → GetDistributedRuntimeClient in xla/pjrt/distributed/client.cc
- → DistributedRuntimeCoordinationServiceClient in xla/pjrt/distributed/client.cc
- → coord_agent_->Initialize(options.env, "jax_worker", options.node_id, config, std::move(leader_client), error_fn); in xla/pjrt/distributed/client.cc
  - task_id = node_id
- → CoordinationServiceAgentImpl in xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc
- → CoordinationServiceAgentImpl::Initialize
Client Connection Flow
self.client.connect()
- → CoordinationServiceAgentImpl::Connect() in coordination_service_agent.cc
- → leader_client_->RegisterTaskAsync(req, ...) in coordination_service_agent.cc (in 76da730 the Async call is converted to done(RegisterTask) in the rpc handler)
- → CoordinationServiceStandaloneImpl::RegisterTask in coordination_service.cc
- → task_cluster_state->SetConnected(incarnation)
  - task_cluster_state is the value stored under the key "jax_worker/NODE_ID"
  - incarnation is a random number that identifies the process instance; it is different for a standby
void CoordinationServiceStandaloneImpl::TaskState::SetConnected(
    uint64_t task_incarnation
) {
  state_ = CoordinatedTaskState::TASKSTATE_CONNECTED;
  status_ = absl::OkStatus();
  task_incarnation_ = task_incarnation;
  absl::MutexLock l(&last_heartbeat_mu_);
  last_heartbeat_us_ = Env::Default()->NowMicros();
}
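The registration step can be pictured with a small Python toy model. This is not the XLA implementation: ToyTaskState and ToyCoordinationService are hypothetical names, and pre-populating one entry per expected task is an assumption of the sketch. It only mirrors what the trace above shows: state is looked up under "jax_worker/NODE_ID", and connecting stores the incarnation and stamps a heartbeat time.

```python
import random
import time
from dataclasses import dataclass
from enum import Enum


class TaskStatus(Enum):
    DISCONNECTED = "disconnected"
    CONNECTED = "connected"


@dataclass
class ToyTaskState:
    status: TaskStatus = TaskStatus.DISCONNECTED
    incarnation: int = 0
    last_heartbeat_us: int = 0

    def set_connected(self, incarnation: int) -> None:
        # Mark connected, remember the incarnation, stamp the heartbeat time.
        self.status = TaskStatus.CONNECTED
        self.incarnation = incarnation
        self.last_heartbeat_us = time.time_ns() // 1_000


class ToyCoordinationService:
    def __init__(self, num_processes: int):
        # Assumption of this sketch: one entry per expected task, created up front.
        self.cluster_state = {
            f"jax_worker/{node_id}": ToyTaskState()
            for node_id in range(num_processes)
        }

    def register_task(self, node_id: int, incarnation: int) -> None:
        key = f"jax_worker/{node_id}"
        if key not in self.cluster_state:
            raise KeyError(f"unexpected task {key}")
        self.cluster_state[key].set_connected(incarnation)

    def all_connected(self) -> bool:
        return all(
            state.status is TaskStatus.CONNECTED
            for state in self.cluster_state.values()
        )


service = ToyCoordinationService(num_processes=2)
for node_id in range(2):
    # Each process instance picks its own random incarnation number.
    service.register_task(node_id, incarnation=random.getrandbits(64))
```

Because the incarnation is random per process instance, a restarted or standby process registering under the same key would present a different number, which is what lets a service tell instances apart.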
Shutdown Flow
jax.distributed.shutdown()
- → self.client.shutdown() in jax/_src/distributed.py
- → CoordinationServiceAgentImpl::ShutdownInternal() in xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc
  - Sends an RPC request → leader_client → CoordinationServiceStandaloneImpl::ShutdownTaskAsync
  - The client shuts down regardless of whether the service call succeeds
- → self.service.shutdown() in jax/_src/distributed.py
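The point that the client shuts down regardless of the service's answer can be illustrated with a toy Python model. ToyAgent and unreachable_service are hypothetical names for this sketch, and ConnectionError stands in for whatever error status a failed RPC would return in the real agent.

```python
class ToyAgent:
    def __init__(self, send_shutdown_rpc):
        # send_shutdown_rpc is a callable standing in for the RPC to the leader.
        self._send_shutdown_rpc = send_shutdown_rpc
        self.state = "CONNECTED"

    def shutdown(self) -> str:
        try:
            self._send_shutdown_rpc()
            outcome = "service acknowledged shutdown"
        except ConnectionError as err:
            outcome = f"service unreachable: {err}"
        # The local state changes whether or not the RPC succeeded.
        self.state = "SHUTDOWN"
        return outcome


def unreachable_service():
    raise ConnectionError("coordinator is gone")


agent = ToyAgent(unreachable_service)
outcome = agent.shutdown()
```

In this model the RPC failure is reported to the caller but never blocks the local teardown, so a worker can still exit cleanly after the coordinator has gone away.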