
Python SDK Reference

Complete API reference for the EdgeML Python SDK.

Installation

pip install edgeml-sdk

PyPI: pypi.org/project/edgeml-sdk

Modules

The SDK provides three main classes:

from edgeml import Federation, FederatedClient, ModelRegistry

Federation

Server-side orchestrator for coordinating federated training across devices.

Constructor

Federation(
    api_key: str,
    name: str | None = None,
    org_id: str = "default",
    api_base: str = "https://api.edgeml.io/api/v1"
)

Parameters:

  • api_key - Your EdgeML API key (starts with ek_live_...)
  • name - Federation name (default: None, which selects the "default" federation)
  • org_id - Organization ID (default: "default")
  • api_base - API endpoint (default: production)

Example:

federation = Federation(
    api_key="ek_live_...",
    name="my-federation"
)

Methods

train()

Run federated training with FedAvg aggregation.

federation.train(
    model: str,
    algorithm: str = "fedavg",
    rounds: int = 1,
    min_updates: int = 1,
    base_version: str | None = None,
    new_version: str | None = None,
    publish: bool = True,
    strategy: str = "metrics",
    update_format: str = "delta"
) -> dict

Parameters:

  • model - Model name or ID
  • algorithm - Aggregation algorithm (currently only "fedavg")
  • rounds - Number of training rounds
  • min_updates - Minimum device updates required per round
  • base_version - Starting model version (default: latest)
  • new_version - New version name (default: auto-generated)
  • publish - Auto-publish new version (default: True)
  • strategy - Selection strategy (default: "metrics")
  • update_format - Update format: "delta" or "weights"

Returns: Dict with new_version, model_id, and aggregation metadata

Example:

result = federation.train(
    model="my-classifier",
    rounds=10,
    min_updates=100,
    base_version="1.0.0",
    new_version="1.1.0"
)
print(f"New version: {result['new_version']}")
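For orientation, FedAvg combines device updates by averaging each parameter, weighted by the number of samples each device trained on. The sketch below shows that idea on plain Python lists. The function name fedavg and the (parameters, sample_count) update layout are assumptions made for illustration; this is not the SDK's server-side implementation.

```python
from typing import Dict, List, Tuple

# Hypothetical update layout: (parameters, sample_count) per device.
Update = Tuple[Dict[str, List[float]], int]


def fedavg(updates: List[Update]) -> Dict[str, List[float]]:
    """Average parameters across devices, weighted by sample count."""
    total = sum(count for _, count in updates)
    if total <= 0:
        raise ValueError("at least one update must report samples")

    first_params, _ = updates[0]
    averaged = {name: [0.0] * len(values) for name, values in first_params.items()}
    for params, count in updates:
        weight = count / total  # devices with more samples count for more
        for name, values in params.items():
            for i, value in enumerate(values):
                averaged[name][i] += weight * value
    return averaged


updates = [
    ({"w": [1.0, 2.0]}, 100),  # device A trained on 100 samples
    ({"w": [3.0, 4.0]}, 300),  # device B trained on 300 samples
]
print(fedavg(updates))
```

Device B contributes three times the weight of device A here, which is why sample_count matters when devices submit updates.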

deploy()

Deploy a model version with progressive rollout.

federation.deploy(
    model_id: str | None = None,
    version: str | None = None,
    rollout_percentage: int = 10,
    target_percentage: int = 100,
    increment_step: int = 10,
    start_immediately: bool = True
) -> dict

Parameters:

  • model_id - Model ID (default: last trained model)
  • version - Version to deploy (default: latest)
  • rollout_percentage - Initial rollout % (default: 10)
  • target_percentage - Target rollout % (default: 100)
  • increment_step - Increment step % (default: 10)
  • start_immediately - Start rollout now (default: True)

Returns: Dict with deployment ID and status

Example:

deployment = federation.deploy(
    version="1.1.0",
    rollout_percentage=5,
    target_percentage=100,
    increment_step=5
)
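To see how the three percentage parameters might interact, the sketch below lists the stages a rollout would pass through. It assumes the rollout grows by increment_step at each stage and is capped at target_percentage; this reference does not state how or when the server advances stages, and rollout_schedule is a hypothetical helper, not part of the SDK.

```python
from typing import List


def rollout_schedule(rollout_percentage: int = 10,
                     target_percentage: int = 100,
                     increment_step: int = 10) -> List[int]:
    """List the percentages a rollout passes through under the assumed model."""
    if not 0 < rollout_percentage <= target_percentage <= 100:
        raise ValueError("expected 0 < rollout_percentage <= target_percentage <= 100")
    if increment_step <= 0:
        raise ValueError("increment_step must be positive")

    stages = [rollout_percentage]
    while stages[-1] < target_percentage:
        # Grow by one step, but never overshoot the target.
        stages.append(min(stages[-1] + increment_step, target_percentage))
    return stages


print(rollout_schedule(5, 100, 5))
```

Under this assumption the example deployment above would take 20 stages to reach every device.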

FederatedClient

Edge device client for participating in federated training.

Constructor

FederatedClient(
    api_key: str,
    org_id: str = "default",
    api_base: str = "https://api.edgeml.io/api/v1",
    device_identifier: str | None = None,
    platform: str = "python"
)

Parameters:

  • api_key - Your EdgeML API key
  • org_id - Organization ID
  • api_base - API endpoint
  • device_identifier - Unique device ID (default: auto-generated)
  • platform - Platform name (default: "python")

Example:

client = FederatedClient(
    api_key="ek_live_...",
    device_identifier="laptop-001",
    platform="python"
)

Methods

register()

Register this device with EdgeML.

client.register() -> str

Returns: Device ID

Example:

device_id = client.register()
print(f"Registered as: {device_id}")

join_federation()

Join a federation.

client.join_federation(federation_name: str) -> dict

Example:

client.join_federation("my-federation")

train()

Submit local training weights.

client.train(
    model: str,
    local_data: Any,
    rounds: int = 1,
    version: str | None = None,
    sample_count: int = 0,
    metrics: dict[str, float] | None = None,
    update_format: str = "delta"
) -> list[dict]

Parameters:

  • model - Model name or ID
  • local_data - Weights (torch.nn.Module, state_dict, bytes, or callable)
  • rounds - Number of rounds
  • version - Model version (default: latest)
  • sample_count - Number of training samples
  • metrics - Training metrics (e.g., {"loss": 0.5, "accuracy": 0.95})
  • update_format - "delta" (differences) or "weights" (full weights)

Example:

# Submit pre-trained weights
results = client.train(
    model="my-classifier",
    local_data=model.state_dict(),
    sample_count=1000,
    metrics={"loss": 0.42, "accuracy": 0.89}
)

# Or pass a callable so the weights are produced lazily.
# It returns (state_dict, sample_count, metrics).
def get_weights():
    # Train model...
    return model.state_dict(), 1000, {"loss": 0.42}

results = client.train(
    model="my-classifier",
    local_data=get_weights,
    rounds=5
)
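When passing a callable, it can help to check the returned tuple before it reaches the SDK. The sketch below wraps a training function so the result always has the (weights, sample_count, metrics) shape shown in the example above. The helper name checked_update and the validation rules are assumptions for illustration; the SDK may apply its own checks.

```python
from typing import Callable, Dict, List, Tuple

Weights = Dict[str, List[float]]  # stand-in for a PyTorch state_dict
UpdateTuple = Tuple[Weights, int, Dict[str, float]]


def checked_update(train_fn: Callable[[], UpdateTuple]) -> Callable[[], UpdateTuple]:
    """Wrap train_fn so its result is validated when it is eventually called."""
    def get_weights() -> UpdateTuple:
        weights, sample_count, metrics = train_fn()  # training runs here, not earlier
        if sample_count < 0:
            raise ValueError("sample_count must be non-negative")
        # Normalize metric values to float, matching dict[str, float].
        return weights, int(sample_count), {k: float(v) for k, v in metrics.items()}
    return get_weights


def fake_training() -> UpdateTuple:
    # Placeholder for a real local training loop.
    return {"w": [0.1, 0.2]}, 1000, {"loss": 0.42}


get_weights = checked_update(fake_training)
print(get_weights())
```

The wrapped function could then be passed as local_data in place of get_weights from the example above.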

train_from_remote()

Pull the model, train it locally, and submit the update.

client.train_from_remote(
    model: str,
    local_train_fn: Callable,
    rounds: int = 1,
    version: str | None = None,
    update_format: str = "weights",
    format: str = "pytorch"
) -> list[dict]

Parameters:

  • model - Model name or ID
  • local_train_fn - Training function: (state_dict) -> (new_state_dict, sample_count, metrics)
  • rounds - Number of training rounds
  • version - Starting version (default: latest)
  • update_format - "delta" or "weights"
  • format - Model format: "pytorch", "onnx", "tflite", "coreml"

Example:

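# MyModel, train_one_epoch, and local_dataloader are placeholders for your own code.
def train_locally(base_state_dict):
    # Load model
    model = MyModel()
    model.load_state_dict(base_state_dict)

    # Train on local data
    for epoch in range(3):
        train_one_epoch(model, local_dataloader)

    # Return updated weights, the local sample count, and metrics
    # (assumes local_dataloader is a PyTorch DataLoader)
    return model.state_dict(), len(local_dataloader.dataset), {"loss": 0.42}

# Train for 5 rounds
results = client.train_from_remote(
    model="my-classifier",
    local_train_fn=train_locally,
    rounds=5,
    update_format="delta"  # Send only the differences from the base weights
)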
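The pull, train, submit cycle described above can be pictured as a simple loop. The sketch below runs that loop against an in-memory stand-in so it works without network access. InMemoryServer and run_rounds are hypothetical names, and the real method's internals (versioning, delta encoding, aggregation across devices) are not documented here.

```python
from typing import Callable, Dict, List, Tuple

Weights = Dict[str, List[float]]  # stand-in for a PyTorch state_dict
TrainFn = Callable[[Weights], Tuple[Weights, int, Dict[str, float]]]


class InMemoryServer:
    """Hypothetical stand-in for the EdgeML backend, used only in this sketch."""

    def __init__(self, weights: Weights) -> None:
        self.weights = weights

    def pull(self) -> Weights:
        # Return a copy so local training cannot mutate server state.
        return {name: list(values) for name, values in self.weights.items()}

    def submit(self, weights: Weights, sample_count: int,
               metrics: Dict[str, float]) -> dict:
        self.weights = weights  # a real server would aggregate across devices
        return {"sample_count": sample_count, "metrics": metrics}


def run_rounds(server: InMemoryServer, local_train_fn: TrainFn,
               rounds: int = 1) -> List[dict]:
    """Pull, train locally, and submit once per round."""
    results = []
    for _ in range(rounds):
        base = server.pull()
        new_weights, sample_count, metrics = local_train_fn(base)
        results.append(server.submit(new_weights, sample_count, metrics))
    return results


def add_one(base: Weights) -> Tuple[Weights, int, Dict[str, float]]:
    # Toy "training" step: shift every parameter by 1.0.
    updated = {name: [v + 1.0 for v in values] for name, values in base.items()}
    return updated, 10, {"loss": 0.0}


server = InMemoryServer({"w": [0.0]})
results = run_rounds(server, add_one, rounds=3)
print(server.weights, len(results))
```

Each round starts from whatever the server currently holds, which is why local_train_fn receives the base state dict as its argument.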

pull_model()

Download a model version.

client.pull_model(
    model: str,
    version: str | None = None,
    format: str = "pytorch"
) -> bytes

Parameters:

  • model - Model name or ID
  • version - Version (default: latest)
  • format - "pytorch", "onnx", "tflite", "coreml"

Returns: Model bytes

Example:

import io

import torch

# Download PyTorch model
model_bytes = client.pull_model(
    model="my-classifier",
    version="1.0.0",
    format="pytorch"
)

# Load it on the CPU so the example also works on devices without a GPU
state_dict = torch.load(io.BytesIO(model_bytes), map_location="cpu")

ModelRegistry

Model management for versioning, uploads, and conversions.

Constructor

ModelRegistry(
    api_key: str,
    org_id: str = "default",
    api_base: str = "https://api.edgeml.io/api/v1",
    timeout: float = 60.0
)

Methods

ensure_model()

Create a model, or return the existing one with the same name.

registry.ensure_model(
    name: str,
    framework: str,
    use_case: str,
    description: str | None = None
) -> dict

Parameters:

  • name - Model name (unique per org)
  • framework - "pytorch", "tensorflow", "onnx", etc.
  • use_case - "image_classification", "text_classification", "object_detection", etc.
  • description - Optional description

Returns: Model dict with id, name, metadata

Example:

model = registry.ensure_model(
    name="mnist-classifier",
    framework="pytorch",
    use_case="image_classification",
    description="Handwritten digit classifier"
)
print(f"Model ID: {model['id']}")

upload_version_from_path()

Upload a model version from a local file.

registry.upload_version_from_path(
    model_id: str,
    file_path: str,
    version: str,
    description: str | None = None,
    formats: str | None = None
) -> dict

Parameters:

  • model_id - Model ID
  • file_path - Path to model file (.pt, .pth, .onnx)
  • version - Version string (e.g., "1.0.0")
  • description - Optional description
  • formats - Comma-separated formats to convert to: "onnx,tflite,coreml"

Returns: Upload result with version metadata

Example:

result = registry.upload_version_from_path(
    model_id=model['id'],
    file_path="model.pt",
    version="1.0.0",
    description="Initial release",
    formats="onnx,tflite,coreml"  # Auto-convert for mobile
)
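Because formats is a single comma-separated string, a typo in one name is easy to miss. The sketch below builds that string from a list and rejects names outside the three conversion targets listed in this reference. join_formats is a hypothetical helper, and the server may accept formats beyond those listed.

```python
from typing import Iterable

# Conversion targets named in this reference; the server may accept others.
KNOWN_FORMATS = ("onnx", "tflite", "coreml")


def join_formats(formats: Iterable[str]) -> str:
    """Build the comma-separated formats string, rejecting unknown names."""
    cleaned = []
    for fmt in formats:
        name = fmt.strip().lower()
        if name not in KNOWN_FORMATS:
            raise ValueError(f"unsupported format: {fmt!r}")
        if name not in cleaned:  # drop duplicates, keep first-seen order
            cleaned.append(name)
    return ",".join(cleaned)


print(join_formats(["onnx", "tflite", "coreml"]))
```

The result can be passed directly as the formats argument.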

publish_version()

Publish a model version (make it available for download).

registry.publish_version(model_id: str, version: str) -> dict

create_rollout()

Create a progressive rollout for a model version.

registry.create_rollout(
    model_id: str,
    version: str,
    rollout_percentage: int = 10,
    target_percentage: int = 100,
    increment_step: int = 10,
    start_immediately: bool = True
) -> dict

Utility Functions

compute_state_dict_delta()

Compute weight delta between two PyTorch state dicts.

from edgeml import compute_state_dict_delta

compute_state_dict_delta(
    base_state: dict,
    updated_state: dict
) -> dict

Example:

# Compute delta
delta = compute_state_dict_delta(old_weights, new_weights)

# Delta is much smaller than full weights
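Conceptually, a weight delta is the per-parameter difference between the updated and base weights, and adding it back to the base recovers the update. The sketch below shows this on plain Python lists. state_delta and apply_delta are hypothetical stand-ins; the SDK function operates on PyTorch state dicts, and its handling of mismatched keys or non-float buffers is not described here.

```python
from typing import Dict, List

Weights = Dict[str, List[float]]  # stand-in for tensors in a state_dict


def state_delta(base: Weights, updated: Weights) -> Weights:
    """Return updated - base for every parameter name."""
    if base.keys() != updated.keys():
        raise KeyError("state dicts must have the same parameter names")
    return {
        name: [u - b for b, u in zip(base[name], updated[name])]
        for name in base
    }


def apply_delta(base: Weights, delta: Weights) -> Weights:
    """Reconstruct the updated weights as base + delta."""
    return {
        name: [b + d for b, d in zip(base[name], delta[name])]
        for name in base
    }


old_weights = {"layer.weight": [1.0, 2.0], "layer.bias": [0.5]}
new_weights = {"layer.weight": [1.5, 2.0], "layer.bias": [0.25]}

delta = state_delta(old_weights, new_weights)
print(delta)
print(apply_delta(old_weights, delta) == new_weights)
```

Parameters that did not change produce zeros in the delta.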

Error Handling

All SDK methods raise EdgeMLClientError on failure:

from edgeml import EdgeMLClientError

try:
    client.train(model="nonexistent", local_data=weights)
except EdgeMLClientError as e:
    print(f"Training failed: {e}")
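For transient failures such as a dropped connection, one common pattern is to retry with exponential backoff. The sketch below is one way to do that. It defines a stand-in EdgeMLClientError so it runs without the SDK installed, and with_retries is a hypothetical helper. Whether a given error is worth retrying is an assumption: a request for a nonexistent model, for example, will fail every time.

```python
import time
from typing import Callable, List, TypeVar

T = TypeVar("T")


class EdgeMLClientError(Exception):
    """Stand-in for this sketch; real code would import it from edgeml."""


def with_retries(call: Callable[[], T], attempts: int = 3,
                 base_delay: float = 0.5) -> T:
    """Run call(), retrying on EdgeMLClientError with exponential backoff."""
    for attempt in range(attempts):
        try:
            return call()
        except EdgeMLClientError:
            if attempt == attempts - 1:
                raise  # out of attempts: surface the last error
            time.sleep(base_delay * 2 ** attempt)  # 0.5s, 1s, 2s, ...
    raise ValueError("attempts must be at least 1")


calls: List[int] = []


def flaky_submit() -> str:
    # Simulated call that fails twice, then succeeds.
    calls.append(1)
    if len(calls) < 3:
        raise EdgeMLClientError("temporary failure")
    return "ok"


print(with_retries(flaky_submit, attempts=3, base_delay=0.0))
```

In real code the wrapped call might be a lambda around client.train(...), with a nonzero base_delay.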

Next Steps