push.particle
- class push.particle.Particle(node_event_loop, device: int, pid: int, module: Module, state: Dict[str, any])
Bases: Waitable
User-facing particle interface.
Implements concurrent particles.
- forward(x: Tensor, *args: any) → PFuture
Performs a forward pass.
- Parameters:
x (torch.Tensor) – Input to the particle.
- Returns:
Future that eventually contains the value of the forward pass.
- Return type:
PFuture
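The sketch below is a minimal illustration of the non-blocking contract of forward. It assumes a hypothetical ToyParticle class and uses concurrent.futures.Future in place of PFuture, because PFuture's own waiting API is not documented in this section. Each call schedules the module on a worker thread and returns at once, so the forward passes of several particles can run at the same time.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Callable


class ToyParticle:
    """Hypothetical stand-in for push.particle.Particle; not the real class."""

    def __init__(self, pid: int, module: Callable[[float], float], executor: ThreadPoolExecutor) -> None:
        self.pid = pid            # particle identifier
        self.module = module      # any callable plays the role of the torch Module
        self.executor = executor  # shared worker pool, standing in for the node event loop

    def forward(self, x: float) -> Future:
        # Non-blocking: schedule the module call and hand back a Future immediately.
        return self.executor.submit(self.module, x)


with ThreadPoolExecutor(max_workers=2) as pool:
    particles = [ToyParticle(0, lambda x: 2 * x, pool), ToyParticle(1, lambda x: x + 1, pool)]
    futures = [p.forward(3.0) for p in particles]  # both forward passes are now in flight
    outputs = [f.result() for f in futures]        # block until each value is available

print(outputs)  # [6.0, 4.0]
```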
- other_particles() → List[int]
Returns the identifiers of all particles except the current particle.
- Returns:
List of all particle identifiers visible to the current particle, excluding the current particle itself.
- Return type:
List[int]
- particle_ids() → List[int]
Returns the identifiers of all particles.
- Returns:
List of all particle identifiers visible to the current particle.
- Return type:
List[int]
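The relationship between other_particles and particle_ids can be sketched as a filter over the visible identifiers. Here other_particles is a hypothetical free function that takes the caller's pid explicitly. It is not the method itself, and the order of the returned identifiers is an assumption.

```python
from typing import List


def other_particles(pid: int, particle_ids: List[int]) -> List[int]:
    """Hypothetical helper: every visible particle id except the caller's own, order preserved."""
    return [other for other in particle_ids if other != pid]


visible = [0, 1, 2, 3]               # what particle_ids() might return in a 4-particle run
print(other_particles(2, visible))   # [0, 1, 3]
```

A typical use is to iterate over the result to contact every peer without accidentally messaging the current particle.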
- register_receive(msg: str, fn: Callable, state: dict[str, any]) → None
Registers fn as the handler the current particle runs when it receives the message msg.
- scheduler_step(*args) → PFuture
Performs a forward and backward pass using the registered optimizer.
- Parameters:
loss_fn (Callable) – Loss function to take a step with respect to.
data (torch.Tensor) – Input data.
label (torch.Tensor) – Labels for data.
- Returns:
Future that eventually contains the loss of the step.
- Return type:
PFuture
- send(pid: int, msg: str, *args: any) → PFuture
Sends the message msg with arguments *args from the current particle to the particle with identifier pid.
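The sketch below is a minimal, single-threaded illustration of how register_receive and send might pair up. Each particle keeps a table that maps a message name to a handler, and send looks up the receiver's handler for that name. ToyNode, the on_add handler, and the handler calling convention (receiver state, sender pid, *args) are all assumptions made for illustration. The real PusH dispatch is asynchronous, and its handler signature may differ.

```python
from concurrent.futures import Future
from typing import Any, Callable, Dict, Tuple

Handler = Callable[..., Any]


class ToyNode:
    """Hypothetical model of register_receive/send; not PusH's implementation."""

    def __init__(self) -> None:
        # pid -> {msg -> (handler, state)}
        self.handlers: Dict[int, Dict[str, Tuple[Handler, Dict[str, Any]]]] = {}

    def register_receive(self, pid: int, msg: str, fn: Handler, state: Dict[str, Any]) -> None:
        self.handlers.setdefault(pid, {})[msg] = (fn, state)

    def send(self, src: int, dst: int, msg: str, *args: Any) -> Future:
        fn, state = self.handlers[dst][msg]  # KeyError if dst never registered msg
        fut: Future = Future()
        # Assumed convention: handler gets (receiver state, sender pid, *args).
        fut.set_result(fn(state, src, *args))
        return fut


def on_add(state: Dict[str, Any], src: int, value: int) -> int:
    # Accumulate values sent by other particles into the receiver's state.
    state["total"] += value
    return state["total"]


node = ToyNode()
node.register_receive(1, "ADD", on_add, {"total": 0})
print(node.send(0, 1, "ADD", 5).result())  # 5
print(node.send(2, 1, "ADD", 7).result())  # 12
```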
- step(loss_fn: Callable, data: Tensor, label: Tensor, *args) → PFuture
Performs a forward and backward pass using the registered optimizer.
- Parameters:
loss_fn (Callable) – Loss function to take a step with respect to.
data (torch.Tensor) – Input data.
label (torch.Tensor) – Labels for data.
- Returns:
Future that eventually contains the loss of the step.
- Return type:
PFuture
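The sketch below illustrates the step contract under several simplifying assumptions. It uses a hypothetical one-weight model y = w * x, with plain SGD standing in for the registered optimizer and a central finite difference standing in for autograd's backward pass. Lists of floats replace torch.Tensor so that the example runs without PyTorch, and concurrent.futures.Future stands in for PFuture.

```python
from concurrent.futures import Future
from typing import Callable, List

Loss = Callable[[List[float], List[float]], float]


def mse(pred: List[float], label: List[float]) -> float:
    """Mean squared error between predictions and labels."""
    return sum((p - t) ** 2 for p, t in zip(pred, label)) / len(label)


class ToyLinearParticle:
    """Hypothetical particle whose module is y = w * x (not PusH code)."""

    def __init__(self, w: float, lr: float, h: float = 1e-6) -> None:
        self.w = w    # the single trainable parameter
        self.lr = lr  # learning rate of the stand-in SGD optimizer
        self.h = h    # finite-difference width used instead of autograd

    def _loss_at(self, w: float, loss_fn: Loss, data: List[float], label: List[float]) -> float:
        return loss_fn([w * x for x in data], label)

    def step(self, loss_fn: Loss, data: List[float], label: List[float]) -> Future:
        loss = self._loss_at(self.w, loss_fn, data, label)  # forward pass
        # Backward pass approximated by a central difference of the loss in w.
        grad = (self._loss_at(self.w + self.h, loss_fn, data, label)
                - self._loss_at(self.w - self.h, loss_fn, data, label)) / (2 * self.h)
        self.w -= self.lr * grad  # SGD update
        fut: Future = Future()
        fut.set_result(loss)      # loss measured before the update
        return fut


p = ToyLinearParticle(w=0.0, lr=0.1)
data, label = [1.0, 2.0], [2.0, 4.0]
first = p.step(mse, data, label).result()   # 10.0
second = p.step(mse, data, label).result()  # close to 2.5, since w moved from 0.0 to about 1.0
print(first, second)
```

The returned loss is the one measured before the update, which is the usual convention for a training step. Whether PusH follows this convention is not stated in this section.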