jax.distributed and XLA Coordination Walkthrough
Commit: JAX 0.4.35 (81991d8); XLA (76da730)
Use case:

jax.distributed.initialize(
    coordinator_address="192.168.0.1:1234",
    num_processes=2,
    process_id=RANK,
)
jitted_function()
jax.distributed.shutdown()
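As a minimal sketch of how the use case above is usually wired up, the per-process values can be derived from launcher-provided environment variables. The variable names below (COORDINATOR_ADDRESS, WORLD_SIZE, RANK) are assumptions matching common launchers, not part of the JAX API:

```python
import os

def initialize_kwargs(env):
    """Hypothetical helper: build jax.distributed.initialize() kwargs from
    launcher-provided environment variables (variable names are assumed,
    in the style of torchrun/SLURM wrappers)."""
    return {
        "coordinator_address": env["COORDINATOR_ADDRESS"],  # e.g. "192.168.0.1:1234"
        "num_processes": int(env["WORLD_SIZE"]),
        "process_id": int(env["RANK"]),
    }

kwargs = initialize_kwargs({
    "COORDINATOR_ADDRESS": "192.168.0.1:1234",
    "WORLD_SIZE": "2",
    "RANK": "0",
})
# On each process: jax.distributed.initialize(**kwargs)
```

The same three values flow into the service (process 0) and client (all processes) paths traced below.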
What happens under the hood of jax.distributed.initialize()?
COORDINATION SERVICE
jax.distributed.initialize()
- Process 0 starts the service by xla_extension.get_distributed_runtime_service in jax/_src/distributed.py:

self.service = xla_extension.get_distributed_runtime_service(
    coordinator_bind_address, num_processes
)
- → get_distributed_runtime_service in xla/python/xla.cc
- → GetDistributedRuntimeService in xla/pjrt/distributed/distributed.cc
- → DistributedRuntimeService in xla/pjrt/distributed/service.cc
- → CoordinationServiceImpl in xla/pjrt/distributed/service.cc
Creates the following two services:
- CoordinationServiceStandaloneImpl (registered for tsl::CoordinationServiceInterface) in xla/tsl/distributed_runtime/coordination/coordination_service.cc
- tsl::GrpcCoordinationServiceImpl in xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h
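The service-startup decision above can be sketched in a few lines. Note the bind-address computation is an assumption about the JAX default (binding on all interfaces at the coordinator's port); check jax/_src/distributed.py for your version:

```python
def coordinator_bind_address(coordinator_address, process_id):
    """Sketch: only process 0 starts the coordination service; the other
    processes only create clients. The "[::]:<port>" default bind address
    is an assumption about jax/_src/distributed.py, not a guarantee."""
    if process_id != 0:
        return None  # non-coordinator processes start no service
    port = coordinator_address.rsplit(":", 1)[1]
    return f"[::]:{port}"

assert coordinator_bind_address("192.168.0.1:1234", 0) == "[::]:1234"
assert coordinator_bind_address("192.168.0.1:1234", 1) is None
```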
COORDINATION CLIENT (AGENT)
jax.distributed.initialize()
- All processes start clients by xla_extension.get_distributed_runtime_client in jax/_src/distributed.py:

self.client = xla_extension.get_distributed_runtime_client(
    coordinator_address, process_id, init_timeout=initialization_timeout
)
- → get_distributed_runtime_client in xla/python/xla.cc
- → GetDistributedRuntimeClient in xla/pjrt/distributed/client.cc
- → DistributedRuntimeCoordinationServiceClient in xla/pjrt/distributed/client.cc
- → coord_agent_->Initialize(options.env, "jax_worker", options.node_id, config, std::move(leader_client), error_fn); in xla/pjrt/distributed/client.cc (task_id = node_id)
- → CoordinationServiceAgentImpl in xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc
- → CoordinationServiceAgentImpl::Initialize
Client Connection Flow
self.client.connect()
- → CoordinationServiceAgentImpl::Connect() in coordination_service_agent.cc
- → leader_client_->RegisterTaskAsync(req, ...) in coordination_service_agent.cc
  (in 76da730 the Async variant is converted to done(RegisterTask) in the rpc_handle)
- → CoordinationServiceStandaloneImpl::RegisterTask in coordination_service.cc
- → task_cluster_state->SetConnected(incarnation), where task_cluster_state is the value stored under the key "jax_worker/NODE_ID", and incarnation is a random number identifying this task instance (it differs for a standby/restarted instance)
void CoordinationServiceStandaloneImpl::TaskState::SetConnected(
uint64_t task_incarnation
) {
state_ = CoordinatedTaskState::TASKSTATE_CONNECTED;
status_ = absl::OkStatus();
task_incarnation_ = task_incarnation;
absl::MutexLock l(&last_heartbeat_mu_);
last_heartbeat_us_ = Env::Default()->NowMicros();
}
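A toy Python model of the bookkeeping in SetConnected() above. The state names mirror the C++; the staleness check in heartbeat() is an assumption about how task_incarnation_ is used to reject a restarted instance, not a transcription of the real code:

```python
import time
from enum import Enum

class CoordinatedTaskState(Enum):
    DISCONNECTED = 0
    CONNECTED = 1

class TaskState:
    """Toy model of CoordinationServiceStandaloneImpl::TaskState: one entry
    per task key (e.g. "jax_worker/0"), tracking connection state, the
    incarnation recorded at registration, and the last heartbeat time."""
    def __init__(self):
        self.state = CoordinatedTaskState.DISCONNECTED
        self.task_incarnation = None
        self.last_heartbeat_us = None

    def set_connected(self, task_incarnation):
        # Mirrors SetConnected(): mark connected, record the incarnation,
        # and stamp the heartbeat clock.
        self.state = CoordinatedTaskState.CONNECTED
        self.task_incarnation = task_incarnation
        self.last_heartbeat_us = int(time.time() * 1e6)

    def heartbeat(self, task_incarnation):
        # Assumption: a heartbeat carrying a different incarnation
        # (e.g. from a restarted task) is rejected as stale.
        if task_incarnation != self.task_incarnation:
            return False
        self.last_heartbeat_us = int(time.time() * 1e6)
        return True

tasks = {"jax_worker/0": TaskState()}
tasks["jax_worker/0"].set_connected(task_incarnation=0xDEAD)
assert tasks["jax_worker/0"].heartbeat(0xDEAD)      # same incarnation: accepted
assert not tasks["jax_worker/0"].heartbeat(0xBEEF)  # stale incarnation: rejected
```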
Shutdown Flow
jax.distributed.shutdown()
- → self.client.shutdown() in jax/_src/distributed.py
- → CoordinationServiceAgentImpl::ShutdownInternal() in xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc
  - Sends an RPC request via leader_client → CoordinationServiceStandaloneImpl::ShutdownTaskAsync
  - The client shuts down regardless of whether the service call succeeds
- → self.service.shutdown() in jax/_src/distributed.py (process 0 only, which owns the service)
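The shutdown ordering above can be sketched with stand-in objects (ToyClient and ToyService are hypothetical, not JAX classes): each process shuts down its client first, and the client tears itself down even if the ShutdownTask RPC fails; only then does process 0 stop the service it owns.

```python
class ToyClient:
    """Stand-in for the per-process distributed runtime client."""
    def __init__(self, process_id, log):
        self.process_id = process_id
        self.log = log

    def shutdown(self):
        # Mirrors the flow above: send ShutdownTask to the service, but
        # perform the local shutdown regardless of the RPC's outcome.
        try:
            self.log.append(f"client {self.process_id}: ShutdownTask RPC")
        finally:
            self.log.append(f"client {self.process_id}: local shutdown")

class ToyService:
    """Stand-in for the coordination service owned by process 0."""
    def __init__(self, log):
        self.log = log

    def shutdown(self):
        self.log.append("service: shutdown")

log = []
client = ToyClient(process_id=0, log=log)
service = ToyService(log)
# Order as in jax.distributed.shutdown(): client first, then the
# service (on process 0 only).
client.shutdown()
service.shutdown()
```

The try/finally models the "client will shutdown regardless of service success" note from the trace.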