
Training Rounds

A training round is the fundamental unit of work in federated learning. Each round orchestrates the entire cycle of device selection, model distribution, local training, update collection, and aggregation.

Round Architecture

Complete Round Lifecycle

[Round flow diagram] Figure 2: Detailed sequence diagram of a single training round with 100 devices. Note the parallel training phase where devices work independently.

1. Device Selection

The server selects a cohort of devices to participate based on:

  • Availability: Device is online and idle
  • Resource constraints: Sufficient battery, storage, and network bandwidth
  • Eligibility criteria: Meets minimum requirements for participation
  • Sampling strategy: Random, stratified, or metric-based selection

In EdgeML, you control selection through the min_updates parameter:

federation.train(
    model="my-classifier",
    rounds=10,
    min_updates=100  # Wait for at least 100 device updates per round
)
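
To make the selection criteria concrete, a server-side selector might first filter for eligibility and then sample a cohort, as in the sketch below. The DeviceInfo fields, the thresholds, and the select_cohort helper are illustrative assumptions, not part of the EdgeML API.

import random
from dataclasses import dataclass

# Hypothetical device metadata; field names are illustrative, not EdgeML's schema
@dataclass
class DeviceInfo:
    device_id: str
    online: bool
    idle: bool
    battery_pct: float
    free_storage_mb: float
    bandwidth_mbps: float

def select_cohort(devices, cohort_size, min_battery=30.0,
                  min_storage_mb=50.0, min_bandwidth_mbps=1.0):
    """Filter for eligible devices, then sample a cohort uniformly at random."""
    eligible = [
        d for d in devices
        if d.online and d.idle
        and d.battery_pct >= min_battery
        and d.free_storage_mb >= min_storage_mb
        and d.bandwidth_mbps >= min_bandwidth_mbps
    ]
    # Uniform random sampling; stratified or metric-based strategies would replace this step
    return random.sample(eligible, min(cohort_size, len(eligible)))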

2. Model Distribution

Selected devices download the current global model. EdgeML optimizes this through:

  • Format conversion: Automatic conversion to ONNX, TFLite, or CoreML for efficient mobile deployment
  • Compression: Model quantization and pruning to reduce download size
  • Caching: Devices cache models to avoid redundant downloads

From the Python SDK:

# Edge device downloads the requested version of the global model
model_bytes = client.pull_model(
    model="my-classifier",
    version="1.0.0",
    format="pytorch"
)
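
On the device, caching can be as simple as keying downloaded weights by model name and version, as in the following sketch; the cache location and the pull_model_cached helper are illustrative assumptions rather than part of the EdgeML SDK.

from pathlib import Path

CACHE_DIR = Path.home() / ".edgeml_cache"  # illustrative cache location

def pull_model_cached(client, model, version, format="pytorch"):
    """Return cached model bytes if present; otherwise download and cache them."""
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    cache_file = CACHE_DIR / f"{model}-{version}.{format}"
    if cache_file.exists():
        return cache_file.read_bytes()
    model_bytes = client.pull_model(model=model, version=version, format=format)
    cache_file.write_bytes(model_bytes)
    return model_bytes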

3. Local Training

Each device trains the model on its private local data for several epochs. This is where the actual learning happens.

Key parameters:

  • Local epochs: Number of training passes over local data (typically 1-5)
  • Batch size: Training batch size (device-dependent)
  • Learning rate: Step size for gradient updates (often 0.01-0.001)

Example local training function:

import torch
import torch.nn as nn

def train_locally(base_state_dict):
    """Train the model on local data and return updated weights, sample count, and metrics"""
    model = MyModel()
    model.load_state_dict(base_state_dict)
    model.train()

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = nn.CrossEntropyLoss()

    # Train for multiple epochs
    for epoch in range(3):
        for X_batch, y_batch in local_dataloader:
            optimizer.zero_grad()
            loss = loss_fn(model(X_batch), y_batch)
            loss.backward()
            optimizer.step()

    return model.state_dict(), len(local_dataloader.dataset), {"loss": loss.item()}

4. Update Upload

Devices compute and upload model updates. EdgeML supports two formats:

Full Weights (update_format="weights"):

  • Send complete trained model weights
  • Simpler but larger payload (~10-100MB for typical models)
  • Use for first round or when model architecture changes

Delta Updates (update_format="delta"):

  • Send only the difference between base and trained weights
  • 10-100x smaller payload
  • More efficient for iterative training

client.train_from_remote(
    model="my-classifier",
    local_train_fn=train_locally,
    update_format="delta",  # Send only weight changes
    rounds=5
)

5. Aggregation

The server aggregates updates using Federated Averaging (FedAvg):

w_global = Σ_k (n_k / n_total) × w_k

where:
  • w_k = weights from device k
  • n_k = number of training samples on device k
  • n_total = total samples across all participating devices

Devices with more training data have proportionally more influence on the global model. This weighted approach ensures:

  • Proportional representation: Every training sample carries equal weight, whether it lives on a data-rich or data-poor device
  • Faster convergence: More data means more reliable updates
  • Better generalization: Global model learns from diverse data distributions
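
A minimal server-side FedAvg sketch follows, assuming PyTorch state_dicts and that each update arrives as a (state_dict, num_samples) pair, e.g. the first two elements of the tuple returned by train_locally above. It illustrates the formula; it is not EdgeML's actual aggregation code.

def fedavg(updates):
    """updates: list of (state_dict, num_samples) pairs from participating devices"""
    n_total = sum(n for _, n in updates)
    keys = updates[0][0].keys()
    # Weighted sum: each device contributes in proportion to its sample count
    return {
        k: sum(state_dict[k].float() * (n / n_total) for state_dict, n in updates)
        for k in keys
    }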

6. Publication

The new global model is versioned and published:

result = federation.train(
    model="my-classifier",
    base_version="1.0.0",
    new_version="1.1.0",  # Explicit version
    publish=True          # Auto-publish for devices
)

EdgeML stores:

  • Model weights
  • Training metadata (round number, participant count, convergence metrics)
  • Version lineage (which version this was trained from)
  • Deployment status (which devices should receive this version)

Round Parameters and Tuning

Critical Parameters

| Parameter | Default | Purpose | Tuning Guidance |
|---|---|---|---|
| rounds | 1 | Number of training iterations | 10-100 for convergence; more for complex models |
| min_updates | 1 | Minimum device updates per round | Higher = better quality, slower rounds |
| local_epochs | 1-3 | Training passes per device | 1-3 typical; more can cause overfitting |
| update_format | "weights" | Update payload type | Use "delta" after first round |

Convergence Monitoring

Track these metrics to know when to stop training:

  • Loss: Should decrease over rounds
  • Validation accuracy: Should plateau when model converges
  • Update variance: High variance indicates non-IID data or instability
  • Participation rate: Consistent participation improves convergence

Example:

for round_num in range(20):
    result = federation.train(
        model="my-classifier",
        rounds=1,  # One round at a time
        min_updates=100
    )

    # Check convergence
    if result['metrics']['loss'] < 0.1:
        print(f"Converged at round {round_num}")
        break

Round Coordination Challenges

Stragglers

Some devices take longer to train due to:

  • Limited computational resources
  • Poor network connectivity
  • Background activity

EdgeML handles stragglers through:

  • Timeouts: Rounds don't wait indefinitely
  • Partial aggregation: Accept updates from fast devices
  • Asynchronous rounds: Devices can join the next available round
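
The timeout-plus-partial-aggregation pattern can be sketched roughly as follows; collect_update, the polling loop, and the parameter names are assumptions for illustration, not EdgeML's coordinator API.

import time

def collect_round_updates(pending_devices, collect_update, min_updates, timeout_s=600):
    """Gather updates until min_updates have arrived or the timeout expires.

    collect_update(device) is assumed to return an update, or None if the
    device has not finished local training yet (a straggler).
    """
    deadline = time.time() + timeout_s
    updates = []
    while time.time() < deadline and len(updates) < min_updates:
        for device in list(pending_devices):
            update = collect_update(device)
            if update is not None:
                updates.append(update)
                pending_devices.remove(device)
        time.sleep(1)
    # Aggregate whatever arrived in time; stragglers simply join a later round
    return updates

If the timeout fires before min_updates is reached, the round can proceed with the partial set of updates or be retried, depending on policy.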

Communication Efficiency

Federated learning requires multiple communication rounds. Optimize by:

  1. Reducing round count: Better local training (more epochs, better optimizers)
  2. Compressing updates: Delta encoding, quantization, sparsification
  3. Batching updates: Group multiple local epochs before uploading

In practice, tens of rounds are often enough to reach good accuracy on simpler tasks, and FedAvg reduces the number of communication rounds needed by 10-100x compared with naive synchronized federated SGD [McMahan et al., 2017].
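
As a concrete example of update compression (item 2 above), the sketch below keeps only the largest-magnitude entries of each delta tensor, a simple form of top-k sparsification. The 1% keep ratio is an arbitrary illustrative choice, and EdgeML's actual compression pipeline may differ.

import torch

def sparsify_delta(delta, keep_ratio=0.01):
    """Zero out all but the top keep_ratio fraction of entries (by magnitude) per tensor"""
    sparse = {}
    for name, tensor in delta.items():
        flat = tensor.flatten()
        k = max(1, int(flat.numel() * keep_ratio))
        # Indices of the k largest-magnitude entries
        _, idx = torch.topk(flat.abs(), k)
        kept = torch.zeros_like(flat)
        kept[idx] = flat[idx]
        sparse[name] = kept.view_as(tensor)
    return sparse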

Byzantine Failures

Malicious or buggy devices might send corrupted updates. Mitigation strategies:

  • Statistical outlier detection: Reject updates far from the median
  • Secure aggregation: Cryptographically verify update integrity
  • Reputation systems: Track device reliability over time

EdgeML's MVP uses basic outlier detection; advanced cryptographic protections are planned for future releases.
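
A basic median-distance filter of the kind described above might look like the sketch below; the flattening approach and the 3x-median cutoff are illustrative assumptions, not the exact rule EdgeML applies.

import torch

def filter_outliers(updates, threshold=3.0):
    """Drop updates whose distance from the element-wise median update is anomalously large"""
    # Flatten each state_dict into a single vector (assumes identical keys and shapes)
    keys = sorted(updates[0].keys())
    vectors = torch.stack([
        torch.cat([u[k].flatten().float() for k in keys]) for u in updates
    ])
    median_vec = vectors.median(dim=0).values
    dists = torch.norm(vectors - median_vec, dim=1)
    cutoff = threshold * dists.median()
    return [u for u, d in zip(updates, dists) if d <= cutoff]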

Real-World Round Examples

Mobile Keyboard (Gboard)

  • Rounds: 100-500 over several days
  • Devices per round: 5,000-50,000
  • Local epochs: 1-2
  • Update format: Compressed deltas
  • Result: Improved next-word prediction without collecting keystrokes

Healthcare Consortium

  • Rounds: 20-50
  • Devices per round: 10-100 hospitals
  • Local epochs: 5-10 (more data per site)
  • Update format: Full weights (fewer participants)
  • Result: Collaborative disease prediction model across institutions

References

  1. McMahan, B., et al. (2017). "Communication-Efficient Learning of Deep Networks from Decentralized Data." AISTATS. [arXiv:1602.05629]
     • Foundational paper on federated learning rounds and FedAvg
  2. Bonawitz, K., et al. (2019). "Towards Federated Learning at Scale: System Design." MLSys. [arXiv:1902.01046]
     • Production system design for handling millions of devices
  3. Li, T., et al. (2020). "Federated Learning: Challenges, Methods, and Future Directions." IEEE Signal Processing Magazine. [arXiv:1908.07873]
     • Comprehensive overview of federated learning systems and round optimization
  4. Wang, S., et al. (2019). "Adaptive Federated Learning in Resource-Constrained Edge Computing Systems." IEEE JSAC. [arXiv:1804.05271]
     • Handling device heterogeneity and stragglers

Next Steps