Run Distributed TensorFlow Training
Once you have a Union account, install `union`:

```shell
pip install union
```
Export the following environment variable to build and push images to your own container registry:

```shell
# replace with your registry name
export IMAGE_SPEC_REGISTRY="<your-container-registry>"
```
Then run the following commands to execute the workflow:

```shell
git clone https://github.com/unionai/unionai-examples
cd unionai-examples
union run --remote tutorials/sentiment_classifier/sentiment_classifier.py main --model distilbert-base-uncased
```
The source code for this tutorial can be found in the [unionai-examples](https://github.com/unionai/unionai-examples) repository on GitHub.

TensorFlow provides `tf.distribute.Strategy` to distribute your training across multiple devices. Several strategies are available within this API, each suited to a different hardware setup. In this example, we employ `tf.distribute.MirroredStrategy` to train an MNIST model using a CNN. `MirroredStrategy` enables synchronous distributed training across multiple GPUs on a single machine. For a deeper understanding of distributed training with TensorFlow, refer to the [distributed training with TensorFlow](https://www.tensorflow.org/guide/distributed_training) guide in the TensorFlow documentation.
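As a quick illustration of the options in this API, the snippet below constructs the strategy used in this tutorial and notes two alternatives in comments. This is a minimal sketch; the alternatives shown are standard TensorFlow strategies, but which one fits depends on your hardware:

```python
import tensorflow as tf

# Synchronous training across multiple GPUs on a single machine (used in this tutorial)
strategy = tf.distribute.MirroredStrategy()

# Alternatives from the same API (construct one strategy per program):
# strategy = tf.distribute.MultiWorkerMirroredStrategy()  # synchronous training across machines
# strategy = tf.distribute.OneDeviceStrategy("/gpu:0")    # single device, useful for testing

print("Replicas in sync:", strategy.num_replicas_in_sync)
```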
To begin, import the required libraries.
```python
import os
from dataclasses import dataclass
from pathlib import Path
from typing import NamedTuple, Tuple

from dataclasses_json import dataclass_json
from flytekit import ImageSpec, Resources, task, workflow
from flytekit.types.directory import FlyteDirectory
```
Create an `ImageSpec` to encompass all the dependencies needed for the TensorFlow task.
```python
custom_image = ImageSpec(
    packages=["tensorflow", "tensorflow-datasets", "flytekitplugins-kftensorflow"],
    registry="ghcr.io/flyteorg",
)
```
Replace `ghcr.io/flyteorg` with a container registry you can publish to. To upload the image to the local registry in the demo cluster, set the registry to `localhost:30000`.
The following imports are required to configure the TensorFlow cluster in Flyte. You can load them on demand: the `is_container()` guard ensures they are imported only when the code runs inside the task's container image.
```python
if custom_image.is_container():
    import tensorflow as tf
    import tensorflow_datasets as tfds
    from flytekitplugins.kftensorflow import PS, Chief, TfJob, Worker
```
You can activate GPU support either by using a base image that includes the necessary GPU dependencies or by initializing the CUDA parameters within the `ImageSpec`.
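If you opt for the second approach, `ImageSpec` exposes `cuda` and `cudnn` parameters. The sketch below is illustrative, and the pinned versions are assumptions you should align with your TensorFlow build:

```python
# Hypothetical GPU-enabled image spec; the cuda/cudnn versions are placeholders
gpu_image = ImageSpec(
    packages=["tensorflow", "tensorflow-datasets", "flytekitplugins-kftensorflow"],
    registry="ghcr.io/flyteorg",  # replace with your registry
    cuda="11.8",  # assumed CUDA version; match your TensorFlow release
    cudnn="8",  # assumed cuDNN version
)
```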
For this example, we define the `MODEL_FILE_PATH` variable to indicate the storage location for the model file.

```python
MODEL_FILE_PATH = "saved_model/"
```
We initialize a data class to store the hyperparameters.
```python
@dataclass_json
@dataclass
class Hyperparameters(object):
    batch_size_per_replica: int = 64
    buffer_size: int = 10000
    epochs: int = 10
```
We use the MNIST dataset to train our model.
```python
def load_data(
    hyperparameters: Hyperparameters,
) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.distribute.Strategy]:
    datasets, _ = tfds.load(name="mnist", with_info=True, as_supervised=True)
    mnist_train, mnist_test = datasets["train"], datasets["test"]

    strategy = tf.distribute.MirroredStrategy()
    print("Number of devices: {}".format(strategy.num_replicas_in_sync))

    # Scale the global batch size by the number of replicas in sync
    # to utilize the extra compute power
    BATCH_SIZE = hyperparameters.batch_size_per_replica * strategy.num_replicas_in_sync

    def scale(image, label):
        image = tf.cast(image, tf.float32)
        image /= 255
        return image, label

    # Fetch train and evaluation datasets
    train_dataset = mnist_train.map(scale).shuffle(hyperparameters.buffer_size).batch(BATCH_SIZE)
    eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

    return train_dataset, eval_dataset, strategy
```
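For example, with the default `batch_size_per_replica` of 64 and four replicas in sync, the global batch size passed to `batch()` would be 64 × 4 = 256.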
We create and compile the model within the context of `strategy.scope()`.
```python
def get_compiled_model(strategy: tf.distribute.Strategy) -> tf.keras.Model:
    with strategy.scope():
        model = tf.keras.Sequential(
            [
                tf.keras.layers.Conv2D(32, 3, activation="relu", input_shape=(28, 28, 1)),
                tf.keras.layers.MaxPooling2D(),
                tf.keras.layers.Flatten(),
                tf.keras.layers.Dense(64, activation="relu"),
                tf.keras.layers.Dense(10),
            ]
        )
        model.compile(
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            optimizer=tf.keras.optimizers.Adam(),
            metrics=["accuracy"],
        )
    return model
```
We define a function for decaying the learning rate.
```python
def decay(epoch: int):
    if epoch < 3:
        return 1e-3
    elif epoch >= 3 and epoch < 7:
        return 1e-4
    else:
        return 1e-5
```
We define the `train_model` function to initiate model training with three callbacks:

- `tf.keras.callbacks.TensorBoard` to log the training metrics
- `tf.keras.callbacks.ModelCheckpoint` to save the model after every epoch
- `tf.keras.callbacks.LearningRateScheduler` to decay the learning rate
```python
def train_model(
    model: tf.keras.Model,
    train_dataset: tf.data.Dataset,
    hyperparameters: Hyperparameters,
) -> Tuple[tf.keras.Model, str]:
    # Define the checkpoint directory to store checkpoints
    checkpoint_dir = "./training_checkpoints"

    # Define the name of the checkpoint files
    checkpoint_prefix = str(Path(checkpoint_dir) / "ckpt_{epoch}")

    # Define a callback for printing the learning rate at the end of each epoch
    class PrintLR(tf.keras.callbacks.Callback):
        def on_epoch_end(self, epoch, logs=None):
            print("\nLearning rate for epoch {} is {}".format(epoch + 1, model.optimizer.lr.numpy()))

    # Put all the callbacks together
    callbacks = [
        tf.keras.callbacks.TensorBoard(log_dir="./logs"),
        tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix, save_weights_only=True),
        tf.keras.callbacks.LearningRateScheduler(decay),
        PrintLR(),
    ]

    # Train the model
    model.fit(train_dataset, epochs=hyperparameters.epochs, callbacks=callbacks)

    # Save the model
    model.save(MODEL_FILE_PATH, save_format="tf")

    return model, checkpoint_dir
```
We define the `test_model` function to evaluate loss and accuracy on the test dataset.
```python
def test_model(model: tf.keras.Model, checkpoint_dir: str, eval_dataset: tf.data.Dataset) -> Tuple[float, float]:
    model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
    eval_loss, eval_acc = model.evaluate(eval_dataset)
    return eval_loss, eval_acc
```
To create a TensorFlow task, add the `flytekitplugins.kftensorflow.TfJob` config to the Flyte task.
```python
training_outputs = NamedTuple("TrainingOutputs", accuracy=float, loss=float, model_state=FlyteDirectory)

if os.getenv("SANDBOX"):
    resources = Resources(gpu="0", mem="1000Mi", ephemeral_storage="500Mi")
else:
    resources = Resources(gpu="1", mem="10Gi", ephemeral_storage="500Mi")


@task(
    task_config=TfJob(worker=Worker(replicas=1), ps=PS(replicas=1), chief=Chief(replicas=1)),
    retries=2,
    cache=True,
    cache_version="2.2",
    requests=resources,
    limits=resources,
    container_image=custom_image,
)
def mnist_tensorflow_job(hyperparameters: Hyperparameters) -> training_outputs:
    train_dataset, eval_dataset, strategy = load_data(hyperparameters=hyperparameters)
    model = get_compiled_model(strategy=strategy)
    model, checkpoint_dir = train_model(model=model, train_dataset=train_dataset, hyperparameters=hyperparameters)
    eval_loss, eval_accuracy = test_model(model=model, checkpoint_dir=checkpoint_dir, eval_dataset=eval_dataset)
    return training_outputs(accuracy=eval_accuracy, loss=eval_loss, model_state=MODEL_FILE_PATH)
```
The task is initiated using `TfJob` with specific values configured:

- `worker`: specifies the number of worker replicas to be launched in the cluster for this job
- `ps`: determines the number of parameter server replicas to use
- `chief`: defines the number of chief replicas to employ

Since `MirroredStrategy` relies on an all-reduce algorithm to communicate variable updates across devices, the parameter server replicas hold no significance in our example.
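To scale out, you would raise the replica counts in the task config. A minimal sketch, using the plugin's `Worker`, `PS`, and `Chief` classes imported above; the counts are illustrative, not a tested configuration:

```python
# Illustrative scaled-out configuration: two workers and one chief.
# The parameter server count stays at one purely to mirror the task above;
# MirroredStrategy's all-reduce does not make use of it.
scaled_tfjob = TfJob(
    worker=Worker(replicas=2),
    ps=PS(replicas=1),
    chief=Chief(replicas=1),
)
```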
If you're interested in exploring the diverse TensorFlow strategies available for distributed training, you can find comprehensive information in the [types of strategies](https://www.tensorflow.org/guide/distributed_training#types_of_strategies) section of the TensorFlow documentation.

Lastly, we define a workflow to invoke the task.
```python
@workflow
def mnist_tensorflow_workflow(
    hyperparameters: Hyperparameters = Hyperparameters(batch_size_per_replica=64),
) -> training_outputs:
    return mnist_tensorflow_job(hyperparameters=hyperparameters)
```
You can run the code locally:

```python
if __name__ == "__main__":
    print(mnist_tensorflow_workflow())
```
In distributed training, note that the return values from different workers may vary. If you need to control which worker's return value is passed on to subsequent tasks in the workflow, you can raise an `IgnoreOutputs` exception for all the remaining ranks.
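A minimal sketch of that pattern follows, assuming the TFJob operator exposes each replica's role through the standard `TF_CONFIG` environment variable; the helper name is illustrative and not part of the tutorial's code:

```python
import json
import os

from flytekit.core.base_task import IgnoreOutputs


def suppress_non_chief_outputs():
    # TF_CONFIG is assumed to be set on each replica by the TFJob operator
    tf_config = json.loads(os.getenv("TF_CONFIG", "{}"))
    task_type = tf_config.get("task", {}).get("type", "chief")

    # Let only the chief's return value propagate to downstream tasks
    if task_type != "chief":
        raise IgnoreOutputs("Non-chief replica; suppressing outputs.")
```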