Slurm agent example usage
Once you have a Union account, install union:
pip install union
Export the following environment variable to build and push images to your own container registry:
# replace with your registry name
export IMAGE_SPEC_REGISTRY="<your-container-registry>"
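The examples in the repository typically consume this variable when defining task images with flytekit's ImageSpec. A minimal sketch of how it might be wired up (the image name and package list below are illustrative assumptions, not taken from this tutorial):
import os
from flytekit import ImageSpec

# The registry comes from the IMAGE_SPEC_REGISTRY variable exported above.
image_spec = ImageSpec(
    name="slurm-agent-example",  # hypothetical image name
    registry=os.environ.get("IMAGE_SPEC_REGISTRY"),
    packages=["flytekitplugins-slurm"],  # assumed package name for the Slurm plugin
)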
Then run the following commands to execute the workflow:
git clone https://github.com/unionai/unionai-examples
cd unionai-examples
union run --remote tutorials/sentiment_classifier/sentiment_classifier.py main --model distilbert-base-uncased
The source code for this tutorial is available on GitHub.
import os
from flytekit import task, workflow
from flytekitplugins.slurm import Slurm, SlurmFunction, SlurmRemoteScript, SlurmShellTask, SlurmTask
SlurmTask
First, SlurmTask
is the most basic use case, allowing users to directly run a pre-existing shell script on the Slurm cluster. To configure this task, you need to specify the following fields:
- ssh_config: Options for the SSH client connection. Authentication is done via key-pair verification. For the available options, please refer to the SSH connection documentation.
- batch_script_path: Path to the shell script on the Slurm cluster.
- sbatch_conf (optional): Options for the sbatch command, sketched briefly after this list; if not provided, it defaults to an empty dict. For the available options, please refer to the official Slurm documentation.
- batch_script_args (optional): Additional arguments for the batch script on the Slurm cluster.
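Each key/value pair in sbatch_conf is passed through to sbatch as the corresponding option. As a rough sketch (the time limit below is an illustrative assumption, not part of this example):
# Roughly equivalent to: sbatch --partition=debug --job-name=job0 --time=00:10:00 <script>
sbatch_conf = {
    "partition": "debug",
    "job-name": "job0",
    "time": "00:10:00",  # assumed extra option, shown only for illustration
}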
slurm_task = SlurmTask(
name="basic",
task_config=SlurmRemoteScript(
ssh_config={
"host": "ec2-11-22-33-444.us-west-2.compute.amazonaws.com",
"username": "ubuntu",
},
sbatch_conf={
"partition": "debug",
"job-name": "job0",
},
batch_script_path="/home/ubuntu/echo.sh",
),
)
@workflow
def basic_wf() -> None:
slurm_task()
Then, you can execute the workflow locally as below:
if __name__ == "__main__":
from click.testing import CliRunner
from flytekit.clis.sdk_in_container import pyflyte
runner = CliRunner()
path = os.path.realpath(__file__)
print(">>> LOCAL EXEC <<<")
result = runner.invoke(pyflyte.main, ["run", path, "basic_wf"])
print(result.output)
SlurmShellTask
Instead of running a pre-existing shell script on the Slurm cluster, SlurmShellTask
allows users to define the script content directly within the task definition, as shown below:
shell_task = SlurmShellTask(
name="shell0",
script="""#!/bin/bash -i
echo [TEST SLURM SHELL TASK 1] Run the user-defined script...
""",
task_config=Slurm(
ssh_config={
"host": "ec2-11-22-33-444.us-west-2.compute.amazonaws.com",
"username": "ubuntu",
"client_keys": ["~/.ssh/private_key.pem"],
},
sbatch_conf={
"partition": "debug",
"job-name": "job1",
},
),
)
shell_task_with_args = SlurmShellTask(
name="shell1",
script="""#!/bin/bash -i
echo [TEST SLURM SHELL TASK 2] Run the user-defined script with args...
echo Arg1: $1
echo Arg2: $2
echo Arg3: $3
""",
task_config=Slurm(
ssh_config={
"host": "ec2-11-22-33-444.us-west-2.compute.amazonaws.com",
"username": "ubuntu",
},
sbatch_conf={
"partition": "debug",
"job-name": "job2",
},
batch_script_args=["0", "a", "xyz"],
),
)
@workflow
def shell_wf() -> None:
shell_task()
shell_task_with_args()
Once again, execute the workflow locally to view the results:
if __name__ == "__main__":
from click.testing import CliRunner
from flytekit.clis.sdk_in_container import pyflyte
runner = CliRunner()
path = os.path.realpath(__file__)
print(">>> LOCAL EXEC <<<")
result = runner.invoke(pyflyte.main, ["run", path, "shell_wf"])
print(result.output)
SlurmFunctionTask
Finally, SlurmFunctionTask
is a highly flexible task type that allows you to run a user-defined task function on a Slurm cluster. To configure this task, you need to specify the following fields:
- ssh_config: Options for the SSH client connection. Authentication is done via key-pair verification. For the available options, please refer to the SSH connection documentation.
- sbatch_conf (optional): Options for the sbatch command; if not provided, it defaults to an empty dict. For the available options, please refer to the official Slurm documentation.
- script (optional): A user-defined script in which {task.fn} serves as a placeholder for the task function execution. You should insert {task.fn} at the desired execution point within the script. If no script is provided, the task function is executed directly, as shown in the sketch after this list.
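For instance, omitting script entirely runs the decorated function on the Slurm cluster as-is. A minimal sketch (the host is a placeholder):
@task(
    task_config=SlurmFunction(
        ssh_config={
            "host": "<your-slurm-host>",
            "username": "ubuntu",
        },
    )
)
def double(x: int) -> int:
    # No wrapper script: the function body itself is executed on the Slurm cluster.
    return x * 2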
@task(
task_config=SlurmFunction(
ssh_config={
"host": "ec2-11-22-33-444.us-west-2.compute.amazonaws.com",
"username": "ubuntu",
"client_keys": ["~/.ssh/private_key.pem"],
},
sbatch_conf={"partition": "debug", "job-name": "job3", "output": "/home/ubuntu/fn_task.log"},
script="""#!/bin/bash -i
echo [TEST SLURM FN TASK 1] Run the first user-defined task function...
# Setup env vars
export MY_ENV_VAR=123
# Source the virtual env
. /home/ubuntu/.cache/pypoetry/virtualenvs/demo-4A8TrTN7-py3.12/bin/activate
# Run the user-defined task function
{task.fn}
""",
)
)
def plus_one(x: int) -> int:
print(os.getenv("MY_ENV_VAR"))
return x + 1
@task(
task_config=SlurmFunction(
ssh_config={
"host": "ec2-11-22-33-444.us-west-2.compute.amazonaws.com",
"username": "ubuntu",
},
script="""#!/bin/bash -i
echo [TEST SLURM FN TASK 2] Run the second user-defined task function...
. /home/ubuntu/.cache/pypoetry/virtualenvs/demo-4A8TrTN7-py3.12/bin/activate
{task.fn}
""",
)
)
def greet(year: int) -> str:
return f"Hello {year}!!!"
@workflow
def function_wf(x: int) -> str:
x = plus_one(x=x)
msg = greet(year=x)
return msg
Let’s execute the workflow:
if __name__ == "__main__":
from click.testing import CliRunner
from flytekit.clis.sdk_in_container import pyflyte
runner = CliRunner()
path = os.path.realpath(__file__)
print(">>> LOCAL EXEC <<<")
result = runner.invoke(
pyflyte.main, ["run", "--raw-output-data-prefix", "s3://my-flyte-slurm-agent", path, "function_wf", "--x", 2024]
)
print(result.output)
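For reference, the CliRunner invocation above is equivalent to running pyflyte run --raw-output-data-prefix s3://my-flyte-slurm-agent <path-to-this-file> function_wf --x 2024 directly from your shell; the --raw-output-data-prefix flag points at an object store prefix (assumed here to be reachable from both your machine and the agent) where offloaded task outputs are written.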
Train and Evaluate a DL Model with SlurmFunctionTask
The following example demonstrates how SlurmFunctionTask
can be integrated into a standard deep learning model training workflow. At the highest level, this workflow consists of three main components:
- dataset: Manage dataset downloading and data preprocessing (MNIST is used as an example).
- model: Define the deep learning model architecture (e.g., a convolutional neural network).
- trainer: Handle the training process, including train_epoch and eval_epoch.

Let’s first take a closer look at each component before diving into the main training workflow.
Dataset
from typing import Tuple
from torch.utils.data import Dataset
from torchvision import datasets, transforms
def get_dataset(download_path: str = "/tmp/torch_data") -> Tuple[Dataset, Dataset]:
"""Process data and build training and validation sets.
Args:
download_path: Directory to store the raw data.
Returns:
        A tuple (tr_ds, val_ds), where tr_ds is a training set and val_ds is a validation set.
"""
# Define data processing pipeline
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
tr_ds = datasets.MNIST(root=download_path, train=True, download=True, transform=transform)
    val_ds = datasets.MNIST(root=download_path, train=False, download=True, transform=transform)  # held-out split used for validation
return tr_ds, val_ds
Model
from typing import Dict
import torch.nn as nn
from torch import Tensor
class Model(nn.Module):
def __init__(self) -> None:
super(Model, self).__init__()
self.cnn_encoder = nn.Sequential(
# Block 1
nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, padding=0),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2),
# Block 2
nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, padding=0),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2),
)
self.clf = nn.Linear(32 * 4 * 4, 10)
def forward(self, inputs: Dict[str, Tensor]) -> Tensor:
x = inputs["x"]
bs = x.size(0)
x = self.cnn_encoder(x)
x = x.reshape(bs, -1)
logits = self.clf(x)
return logits
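The 32 * 4 * 4 input size of the final linear layer follows from applying the two convolution/pooling blocks to a 28x28 MNIST image: 28 -> 24 (conv, kernel 5) -> 12 (pool) -> 8 (conv, kernel 5) -> 4 (pool), with 32 output channels. A quick shape sanity check (a hypothetical snippet, not part of the original example):
import torch

model = Model()
dummy = {"x": torch.randn(2, 1, 28, 28)}  # a batch of two fake MNIST images
logits = model(dummy)
print(logits.shape)  # torch.Size([2, 10])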
Trainer
import gc
from typing import Tuple
import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from tqdm import tqdm
def train_epoch(
tr_loader: DataLoader, model: nn.Module, loss_fn: nn.Module, optimizer: Optimizer, debug: bool = False
) -> float:
"""Run training for one epoch.
Args:
tr_loader: Training dataloader.
model: Model instance.
loss_fn: Loss criterion.
optimizer: Optimizer.
debug: If True, run one batch only.
Returns:
The average training loss over batches.
"""
tr_loss_tot = 0.0
model.train()
for i, batch_data in tqdm(enumerate(tr_loader), total=len(tr_loader)):
optimizer.zero_grad(set_to_none=True)
# Retrieve batched raw data
x, y = batch_data
inputs = {"x": x}
# Forward pass
logits = model(inputs)
# Derive loss
loss = loss_fn(logits, y)
tr_loss_tot += loss.item()
# Backpropagation
loss.backward()
optimizer.step()
del x, y, inputs, logits
_ = gc.collect()
if debug:
break
tr_loss_avg = tr_loss_tot / len(tr_loader)
return tr_loss_avg
@torch.no_grad()
def eval_epoch(
eval_loader: DataLoader, model: nn.Module, loss_fn: nn.Module, debug: bool = False
) -> Tuple[float, float]:
"""Run evaluation for one epoch.
Args:
eval_loader: Evaluation dataloader.
model: Model instance.
loss_fn: Loss criterion.
debug: If True, run one batch only.
Returns:
A tuple (eval_loss_avg, acc), where eval_loss_avg is the average evaluation loss over batches
and acc is the accuracy.
"""
eval_loss_tot = 0
y_true, y_pred = [], []
model.eval()
for i, batch_data in tqdm(enumerate(eval_loader), total=len(eval_loader)):
# Retrieve batched raw data
x, y = batch_data
inputs = {"x": x}
# Forward pass
logits = model(inputs)
# Derive loss
loss = loss_fn(logits, y)
eval_loss_tot += loss.item()
# Record batched output
y_true.append(y.detach())
y_pred.append(logits.detach())
del x, y, inputs, logits
_ = gc.collect()
if debug:
break
eval_loss_avg = eval_loss_tot / len(eval_loader)
# Derive accuracy
y_true = torch.cat(y_true, dim=0)
y_pred = torch.cat(y_pred, dim=0)
y_pred = torch.argmax(y_pred, dim=1)
    acc = ((y_true == y_pred).sum() / len(y_true)).item()
return eval_loss_avg, acc
Deep Learning Model Training Workflow
Once the three main components are in place, you can now train your deep learning model using GPUs on the Slurm cluster with the highly flexible SlurmFunctionTask.
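Note that the GPU itself has to be requested from Slurm: the training task below asserts that CUDA is available, so in practice you would typically add a GPU request to its sbatch_conf, mirroring the "gres" entry used by the evaluation task further down. A sketch of such a configuration (values copied from this example, with the GPU request added as an assumption about your cluster):
sbatch_conf = {
    "partition": "debug",
    "job-name": "train-model",
    "output": "/home/ubuntu/train.log",
    "gres": "gpu:1",  # request one GPU from Slurm
}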
import os
from pathlib import Path
from typing import Dict, Optional
import torch
import torch.nn as nn
import torch.optim as optim
from flytekit import task, workflow
from flytekit.types.file import FlyteFile
from flytekitplugins.slurm import SlurmFunction
from torch.utils.data import DataLoader
@task(
task_config=SlurmFunction(
ssh_config={
"host": "ec2-11-22-33-444.us-west-2.compute.amazonaws.com",
"username": "ubuntu",
},
sbatch_conf={
"partition": "debug",
"job-name": "process-data",
"output": "/home/ubuntu/dp.log",
},
script="""#!/bin/bash -i
echo "Process and build torch datasets..."
{task.fn}
""",
)
)
def process_data(raw_data_path: str) -> str:
# Download the MNIST dataset but ignore the torch training and validation datasets,
# which are built in the `train` function below
_ = get_dataset(download_path=raw_data_path)
proc_data_path = raw_data_path
return proc_data_path
@task(
task_config=SlurmFunction(
ssh_config={
"host": "ec2-11-22-33-444.us-west-2.compute.amazonaws.com",
"username": "ubuntu",
},
sbatch_conf={
"partition": "debug",
"job-name": "train-model",
"output": "/home/ubuntu/train.log",
},
script="""#!/bin/bash -i
echo "Training process..."
{task.fn}
""",
)
)
def train(
data_path: str,
epochs: int = 5,
batch_size: int = 32,
lr: float = 1e-3,
ckpt_path: Optional[str] = None,
debug: bool = False,
) -> FlyteFile:
# --------------------------------------------------------------------------------
# HARD-CODE CUDA DEVICE AND ASSERT IT'S AVAILABLE
# --------------------------------------------------------------------------------
device = torch.device("cuda")
assert torch.cuda.is_available(), "Requested GPU but no CUDA device found!"
print(f"[train] Using device: {device}")
ckpt_path = Path("./output") if ckpt_path is None else Path(ckpt_path)
ckpt_path.mkdir(exist_ok=True)
model_path = ckpt_path / "model.pth"
# Build dataloaders
tr_ds, val_ds = get_dataset(download_path=data_path)
tr_loader = DataLoader(tr_ds, batch_size=batch_size, shuffle=True, drop_last=True)
val_loader = DataLoader(val_ds, batch_size=batch_size * 4, shuffle=False)
# Build model
model = Model()
    # Build loss criterion
loss_fn = nn.CrossEntropyLoss()
# Build solvers
# Optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
# LR scheduler
# lr_skd = None
# Run training and evaluation
best_score = 1e16
for ep in range(epochs):
tr_loss = train_epoch(
tr_loader=tr_loader,
model=model,
loss_fn=loss_fn,
optimizer=optimizer,
debug=debug,
)
val_loss, acc = eval_epoch(eval_loader=val_loader, model=model, loss_fn=loss_fn, debug=debug)
# Save model ckpt
if val_loss < best_score:
best_score = val_loss
torch.save(model.state_dict(), model_path)
print(f"Epoch [{ep+1}/{epochs}] TRAIN LOSS {tr_loss:.4f} | VAL LOSS {val_loss:.4f} | ACC {acc:.4f}")
return FlyteFile(path=model_path)
@task(
task_config=SlurmFunction(
ssh_config={
"host": "ec2-11-22-33-444.us-west-2.compute.amazonaws.com",
"username": "ubuntu",
},
sbatch_conf={
"partition": "debug",
"job-name": "eval-model",
"output": "/home/ubuntu/eval.log",
"gres": "gpu:1",
},
script="""#!/bin/bash -i
echo "Evaluation process..."
{task.fn}
""",
)
)
@torch.no_grad()
def run_infer(data_path: str, model_path: FlyteFile) -> Dict[str, float]:
# Build validation dataloader
_, val_ds = get_dataset(download_path=data_path)
val_loader = DataLoader(val_ds, batch_size=2048, shuffle=False)
# Load model
model = Model()
model.load_state_dict(torch.load(model_path.download()))
y_true, y_pred = [], []
model.eval()
for i, batch_data in tqdm(enumerate(val_loader), total=len(val_loader)):
# Retrieve batched raw data
x, y = batch_data
inputs = {"x": x}
# Forward pass
logits = model(inputs)
# Record batched output
y_true.append(y.detach())
y_pred.append(logits.detach())
# Derive accuracy
y_true = torch.cat(y_true, dim=0)
y_pred = torch.cat(y_pred, dim=0).argmax(dim=1)
prf_report = {"acc": ((y_true == y_pred).sum() / len(y_true)).item()}
return prf_report
@workflow
def dl_wf(
raw_data_path: str,
epochs: int = 1,
debug: bool = True,
) -> Dict[str, float]:
proc_data_path = process_data(raw_data_path=raw_data_path)
output_path = train(data_path=proc_data_path, epochs=epochs, debug=debug)
prf_report = run_infer(data_path=proc_data_path, model_path=output_path)
return prf_report
Run the following code snippet and enjoy your training journey!
if __name__ == "__main__":
from click.testing import CliRunner
from flytekit.clis.sdk_in_container import pyflyte
runner = CliRunner()
path = os.path.realpath(__file__)
# Local run
print(">>> LOCAL EXEC <<<")
result = runner.invoke(
pyflyte.main,
[
"run",
"--raw-output-data-prefix",
"s3://my-flyte-slurm-agent",
path,
"dl_wf",
"--raw_data_path",
"/tmp/torch_data",
],
)
print(result.output)