Streaming Data for BERT Training

This example demonstrates how to train a BERT model on a large Arabic text dataset using PyTorch Lightning and the streaming library from MosaicML.

Once you have a Union account, install union:

pip install union

Export the following environment variable to build and push images to your own container registry:

# replace with your registry name
export IMAGE_SPEC_REGISTRY="<your-container-registry>"

Then clone the examples repository and run the workflow:

$ git clone https://github.com/unionai/unionai-examples
$ cd unionai-examples
$ union run --remote <path/to/file.py> <workflow_name> <params>

The source code for this tutorial can be found in the unionai-examples repository cloned above.

The dataset is preprocessed into shards to enable efficient random access during training. The training job is distributed across multiple GPUs using the flytekitplugins-kfpytorch plugin, which leverages torchrun under the hood for multi-process training.

To get started, import the necessary libraries and set up the environment:

import os
from dataclasses import dataclass
from pathlib import Path
from typing import Annotated, Optional

import pytorch_lightning as pl
import torch
import union
from flytekit.extras.accelerators import T4
from flytekitplugins.kfpytorch.task import Elastic
from transformers import BertForSequenceClassification

Set the number of nodes and GPUs to be used for training.

NUM_NODES = "1"  # single training node
NUM_GPUS = "2"  # GPUs per node (two T4s)

Define the container image to be used for the tasks. This image includes all the necessary dependencies for training the BERT model.

image = union.ImageSpec(
    name="arabic-bert",
    builder="union",
    packages=[
        "union==0.1.173",
        "datasets==3.3.2",
        "flytekitplugins-kfpytorch==1.15.3",
        "mosaicml-streaming==0.11.0",
        "torch==2.6.0",
        "transformers==4.49.0",
        "wandb==0.19.8",
        "pytorch-lightning==2.5.1",
    ],
)

Define configuration parameters for both data streaming and model training.

  • The streaming configuration specifies the number of data loading workers, the number of retry attempts for downloading shards, whether to shuffle the data during training, and the batch size.
  • The training configuration defines key training hyperparameters such as the learning rate, the learning rate decay (gamma), and the number of training epochs.

@dataclass
class StreamingConfig:
    num_workers: int = 2
    download_retry: int = 2
    shuffle: bool = True
    batch_size: int = 8


@dataclass
class TrainConfig:
    lr: float = 1.0
    gamma: float = 0.7
    epochs: int = 2

Define the artifacts for the dataset and model. Artifacts name and version these task outputs so the dataset shards and model files can be tracked and reused in future runs.

DatasetArtifact = union.Artifact(name="arabic-reviews-shards")
ModelArtifact = union.Artifact(name="arabic-bert")

Set the secret for authenticating with the Weights & Biases API. Make sure to store your API key as a secret in Union before running the workflow.

WANDB_API_KEY = "wandb-api-key"
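
If you haven't created the secret yet, the Union CLI can store it. The command below is a sketch; the exact syntax may differ across union versions, so check union create secret --help:

union create secret wandb-api-key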

Define the custom collate function for the DataLoader. This function prepares each batch of data for training by converting NumPy arrays into PyTorch tensors. It also ensures that data is correctly formatted and writable before conversion, which is especially important when working with memory-mapped arrays or data streaming.

def collate_fn(batch):
    import torch

    collated_batch = {}
    for key in batch[0].keys():
        if key == "labels":
            collated_batch[key] = torch.tensor([item[key] for item in batch])
        else:
            # Ensure arrays are writable before conversion
            tensors = []
            for item in batch:
                value = item[key]
                if hasattr(value, "flags") and not value.flags.writeable:
                    value = value.copy()
                tensors.append(torch.tensor(value))
            collated_batch[key] = torch.stack(tensors)
    return collated_batch
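
As a quick sanity check, here is what collate_fn produces for a toy batch of two samples (the sequence length of 8 is illustrative; the real tokenizer pads to the model's maximum length):

import numpy as np

toy_batch = [
    {"input_ids": np.arange(8), "attention_mask": np.ones(8, dtype=np.int64), "labels": 0},
    {"input_ids": np.arange(8), "attention_mask": np.ones(8, dtype=np.int64), "labels": 2},
]
collated = collate_fn(toy_batch)
print(collated["input_ids"].shape)  # torch.Size([2, 8])
print(collated["labels"])  # tensor([0, 2])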

Define the tasks for downloading the model and dataset. The download_model task fetches a pretrained model from the Hugging Face Hub and caches it for use during training. The download_dataset task downloads the dataset containing 100,000 Arabic reviews, preprocesses it into streaming-compatible shards using MDSWriter, and saves it to a local directory. The dataset is then automatically uploaded to a remote blob store using FlyteDirectory for efficient access during training.

@union.task(cache=True, requests=union.Resources(mem="5Gi"), container_image=image)
def download_model(model_name: str) -> Annotated[union.FlyteDirectory, ModelArtifact]:
    from huggingface_hub import snapshot_download

    ctx = union.current_context()
    working_dir = Path(ctx.working_directory)
    cached_model_dir = working_dir / "cached_model"

    snapshot_download(model_name, local_dir=cached_model_dir)
    return cached_model_dir


@union.task(
    cache=True, container_image=image, requests=union.Resources(cpu="3", mem="3Gi")
)
def download_dataset(
    dataset: str, model_dir: union.FlyteDirectory
) -> Annotated[union.FlyteDirectory, DatasetArtifact]:
    from datasets import ClassLabel, load_dataset
    from streaming.base import MDSWriter
    from transformers import AutoTokenizer

    loaded_dataset = load_dataset(dataset, split="train")
    loaded_dataset = loaded_dataset.shuffle(seed=42)

    tokenizer = AutoTokenizer.from_pretrained(model_dir.download())

    def tokenize_function(examples):
        return tokenizer(examples["text"], padding="max_length", truncation=True)

    tokenized_dataset = loaded_dataset.map(tokenize_function, batched=True)

    tokenized_dataset = tokenized_dataset.cast_column(
        "label", ClassLabel(names=["Positive", "Negative", "Mixed"])
    )
    tokenized_dataset = tokenized_dataset.rename_column("label", "labels")
    tokenized_dataset = tokenized_dataset.remove_columns(["text"])

    tokenized_dataset.set_format("numpy")

    local_dir = os.path.join(union.current_context().working_directory, "mds_shards")
    os.makedirs(local_dir, exist_ok=True)

    # Use MDSWriter to write the dataset to local directory
    with MDSWriter(
        out=local_dir,
        columns={
            "input_ids": "ndarray",
            "attention_mask": "ndarray",
            "token_type_ids": "ndarray",
            "labels": "int64",
        },
        size_limit="100kb",
    ) as out:
        for i in range(len(tokenized_dataset)):
            out.write(
                {k: tokenized_dataset[i][k] for k in tokenized_dataset.column_names}
            )

    return union.FlyteDirectory(local_dir)
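
Once the task has run, you can sanity-check the shards by pointing StreamingDataset at a local copy of the output directory (a minimal sketch; the path is illustrative):

from streaming.base import StreamingDataset

ds = StreamingDataset(local="/tmp/mds_shards", shuffle=False)
print(len(ds))  # number of samples in the dataset
print(ds[0].keys())  # input_ids, attention_mask, token_type_ids, labels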

Define the BERT classifier model using PyTorch Lightning. This module wraps Hugging Face’s BertForSequenceClassification in a LightningModule. It supports multi-class classification and uses a StepLR scheduler that decays the learning rate by a factor of gamma after each epoch for training stability.

class BertClassifier(pl.LightningModule):
    def __init__(self, model_dir: str, learning_rate: float, gamma: float):
        super().__init__()
        self.model = BertForSequenceClassification.from_pretrained(
            model_dir, num_labels=3
        )
        self.learning_rate = learning_rate
        self.gamma = gamma
        self.save_hyperparameters()

    def forward(self, **batch):
        return self.model(**batch)

    def training_step(self, batch, batch_idx):
        output = self(**batch)
        loss = output.loss
        self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adadelta(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=1, gamma=self.gamma
        )
        return {"optimizer": optimizer, "lr_scheduler": scheduler}
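
With the defaults above (lr=1.0, gamma=0.7), StepLR with step_size=1 multiplies the learning rate by gamma after every epoch. A quick standalone check of the resulting schedule:

import torch

optimizer = torch.optim.Adadelta([torch.nn.Parameter(torch.zeros(1))], lr=1.0)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.7)
for epoch in range(3):
    print(epoch, scheduler.get_last_lr())  # 0 [1.0], 1 [0.7], 2 [~0.49]
    optimizer.step()
    scheduler.step()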

Set up a training task to fine-tune the BERT model using PyTorch Lightning. This task leverages the Elastic strategy to distribute training across 2 GPUs on a single node, and uses WandbLogger to log metrics to Weights & Biases for experiment tracking.

The training data is streamed from a remote blob store using the StreamingDataset class. The dataset is provided as a FlyteDirectory, which was created and uploaded in the earlier download_dataset task. The streaming library downloads shards on demand to local storage and feeds samples to the data loader as they are needed, enabling efficient training at scale.

@union.task(
    cache=True,
    container_image=image,
    task_config=Elastic(
        nnodes=int(NUM_NODES),
        nproc_per_node=int(NUM_GPUS),
        max_restarts=3,
        start_method="fork",
    ),
    requests=union.Resources(
        mem="40Gi", cpu="10", gpu=NUM_GPUS, ephemeral_storage="15Gi"
    ),
    secret_requests=[union.Secret(key=WANDB_API_KEY, env_var="WANDB_API_KEY")],
    accelerator=T4,
    environment={
        "NCCL_DEBUG": "WARN",
        "TORCH_DISTRIBUTED_DEBUG": "INFO",
    },
    shared_memory=True,
)
def train_bert(
    dataset_shards: union.FlyteDirectory,
    model_dir: union.FlyteDirectory,
    train_config: TrainConfig,
    wandb_entity: str,
    streaming_config: StreamingConfig,
) -> Annotated[Optional[union.FlyteFile], ModelArtifact]:
    import os

    import pytorch_lightning as pl
    import wandb
    from pytorch_lightning.loggers import WandbLogger
    from streaming.base import StreamingDataset
    from torch.utils.data import DataLoader

    local_model_dir = model_dir.download()
    model = BertClassifier(local_model_dir, train_config.lr, train_config.gamma)

    dataset = StreamingDataset(
        remote=dataset_shards.remote_source,
        batch_size=streaming_config.batch_size,
        download_retry=streaming_config.download_retry,
        shuffle=streaming_config.shuffle,
    )

    train_loader = DataLoader(
        dataset=dataset,
        batch_size=streaming_config.batch_size,
        collate_fn=collate_fn,
        num_workers=streaming_config.num_workers,
    )

    wandb_logger = WandbLogger(
        entity=wandb_entity,
        project="bert-training",
        name=f"bert-training-rank-{os.environ['RANK']}",
    )

    trainer = pl.Trainer(
        accelerator="gpu",
        strategy="ddp",
        devices="auto",
        max_epochs=train_config.epochs,
        logger=wandb_logger,
        use_distributed_sampler=False,
    )

    trainer.fit(model, train_loader)

    # Save model only from rank 0
    if int(os.environ["RANK"]) == 0:
        model_file = os.path.join(
            union.current_context().working_directory, "bert_uncased_gpu.pt"
        )
        torch.save(model.model.state_dict(), model_file)
        wandb.finish()
        return union.FlyteFile(model_file)

    return None

Define the workflow for downloading the model, dataset, and training the BERT model. The workflow orchestrates the execution of the tasks and ensures that the model and dataset are available for training.

@union.workflow
def finetune_bert_on_sharded_data(
    wandb_entity: str,
    dataset_name: str = "arbml/arabic_100k_reviews",
    model_name: str = "bert-base-uncased",
    train_config: TrainConfig = TrainConfig(),
    streaming_config: StreamingConfig = StreamingConfig(),
) -> Optional[union.FlyteFile]:
    model = download_model(model_name=model_name)
    dataset_shards = download_dataset(dataset=dataset_name, model_dir=model)
    return train_bert(
        dataset_shards=dataset_shards,
        model_dir=model,
        train_config=train_config,
        wandb_entity=wandb_entity,
        streaming_config=streaming_config,
    )
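
Assuming the workflow lives in a file named streaming_bert.py (the file name is illustrative; use the actual path from the repository), a remote run looks like:

$ union run --remote streaming_bert.py finetune_bert_on_sharded_data --wandb_entity <your-wandb-entity>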