Thomas Fan

Flyte and Weights & Biases Integration 

With Flyte’s latest plugin for Weights & Biases, you can run machine learning and AI workflows on Union and integrate them with Weights & Biases’ experiment-tracking capabilities. Union provides scalability, declarative infrastructure, and data lineage, allowing you to quickly iterate on and productionize AI or ML workflows. Weights & Biases helps customers build models faster, fine-tune LLMs, and develop GenAI applications with confidence, all in one system of record. In this blog post, we’ll learn how to use the Weights & Biases plugin on Union.

Flytekit’s Weights & Biases Plugin

Union considers both data and computing to be fundamental building blocks. You can train models using machine learning or AI libraries such as XGBoost or PyTorch and track those models with Union artifacts. Union's reactive workflows are triggered when the underlying data changes and scale up to train many models.

In this initial example, flytekit's `wandb_init` decorator configures the run in Weights & Biases, and the XGBoost callback automatically tracks the model's progress. After decorating your function, the body consists of the same code you'll find in the Weights & Biases documentation:

import pandas as pd
from flytekit import Secret, task
from flytekitplugins.wandb import wandb_init
from xgboost import XGBClassifier

wandb_secret = Secret(key="wandb-api-key")

@task(container_image=image, secret_requests=[wandb_secret])
@wandb_init(
    project=WANDB_PROJECT, entity=WANDB_ENTITY, secret=wandb_secret,
)
def train(data: pd.DataFrame) -> float:
    # Normal usage of wandb
    from wandb.integration.xgboost import WandbCallback
    import wandb

    # The callback logs the boosting progress to Weights & Biases
    bst = XGBClassifier(..., callbacks=[WandbCallback(log_model=True)])
    # ... fit the model and compute test_score (elided)

    wandb.run.log({"test_score": test_score})
    return test_score

The `wandb_secret` object refers to a Weights & Biases API key, which was created with Union’s CLI: `unionai create secret wandb-api-key`. The `wandb_init` decorator starts the run and configures Union's UI to show a link to it:

Clicking the link takes you to Weights & Biases, which shows all the tracking information about the model training execution. On Weights & Biases, the Flyte execution is linked back in the run’s description:

Reactive Workflows

With Union's artifacts, you can write workflows that automatically trigger when the data gets updated by another workflow. This enables workflows to be modular, where one team focuses on extracting data, and another focuses on modeling. You can declare an artifact with a Python typing annotation:

import pandas as pd
from flytekit import LaunchPlan, task, workflow
from flytekit.core.artifact import Artifact
from flytekitplugins.wandb import wandb_init
from unionai.artifacts import OnArtifact  # OnArtifact comes from the Union SDK

MyDataset = Artifact(name="my_dataset")

@task(...)
@wandb_init(...)
def train(data: pd.DataFrame) -> float:
    ...

# train_workflow will trigger when "my_dataset" gets updated
@workflow
def train_workflow(data: pd.DataFrame = MyDataset.query()):
    train(data=data)

trigger = LaunchPlan.create(
    "trigger_train_workflow",
    train_workflow,
    trigger=OnArtifact(trigger_on=MyDataset),
)

The `train_workflow` consumes the `"my_dataset"` artifact, which represents an upstream dataset. With the `wandb_init` decorator, Weights & Biases tracks the metrics and results of each new training run on the updated dataset, so you can observe changes in the model's performance as the dataset changes over time.
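
To see the upstream side, here's a minimal sketch, assuming a hypothetical `extract_data` task: annotating a task's return type with the artifact tells Flyte to publish each output as a new version of `"my_dataset"`. The task and workflow names, and the toy DataFrame, are illustrative:

import pandas as pd
from flytekit import task, workflow
from flytekit.core.artifact import Artifact
from typing_extensions import Annotated

MyDataset = Artifact(name="my_dataset")

# Annotating the return type tells Flyte to emit this output
# as a new version of the "my_dataset" artifact.
@task
def extract_data() -> Annotated[pd.DataFrame, MyDataset]:
    # Hypothetical extraction step; replace with your real data source
    return pd.DataFrame({"feature": [1.0, 2.0, 3.0], "label": [0, 1, 0]})

@workflow
def extract_workflow() -> pd.DataFrame:
    return extract_data()

Each run of `extract_workflow` publishes a new artifact version, which fires the `OnArtifact` trigger above and kicks off `train_workflow`.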

Scaling Out Experiments

With Flyte's dynamic workflows, you can quickly scale up to multiple training tasks, each with its own resources. In this example, you see how to use Flyte’s declarative infrastructure to train various models using PyTorch Lightning on GPUs. The function’s body consists of regular PyTorch Lightning code that you’ll find in their documentation.

import pandas as pd
from flytekit import Resources, dynamic, task
from flytekit.extras.accelerators import T4
from flytekitplugins.wandb import wandb_init
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import WandbLogger

@task(
    container_image=image,
    requests=Resources(gpu="1", cpu="2", mem="8Gi"),
    accelerator=T4,
)
@wandb_init(...)
def train_lightning_model(dataset: pd.DataFrame, n_layer: int) -> dict:
    wandb_logger = WandbLogger(log_model="all")

    model = MyLightningModule(n_layer_1=n_layer, n_layer_2=n_layer)
    trainer = Trainer(max_epochs=5, logger=wandb_logger)
    # training_loader and validation_loader are built from dataset (elided)
    trainer.fit(model, training_loader, validation_loader)
    ...

@dynamic(container_image=image)
def main(n_layers: list[int]):
    dataset = get_dataset()
    for n_layer in n_layers:
        train_lightning_model(dataset=dataset, n_layer=n_layer)
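
For example, invoking the dynamic workflow with a list of layer counts (the values here are illustrative) fans out one GPU-backed training task per entry:

# Each entry becomes its own train_lightning_model task on a T4 GPU.
main(n_layers=[2, 4, 8, 16])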

In the Union UI, the workflow dynamically scales out to multiple GPU-powered tasks:

PyTorch Lightning's WandbLogger automatically logs the metrics, hyperparameters, and checkpoints during model training. From the Weights & Biases platform, you can compare the different runs and evaluate each model’s performance.

Wrapping Up

Union's declarative infrastructure and scalable orchestration platform make it simple to scale up your machine learning or AI workflows and put them into production. With flytekit's Weights & Biases plugin, you can easily track your experiments, visualize results, and debug your models. Install the plugin with `pip install flytekitplugins-wandb`.

If you want to learn more about Union, get in touch with us at union.ai/demo.

Machine Learning
AI Orchestration