Memray Profiling Example

Once you have a Union account, install union:

pip install union

Export the following environment variable to build and push images to your own container registry:

# replace with your registry name
export IMAGE_SPEC_REGISTRY="<your-container-registry>"

Then clone the examples repository and run the workflow:

git clone https://github.com/unionai/unionai-examples
cd unionai-examples
union run --remote tutorials/sentiment_classifier/sentiment_classifier.py main --model distilbert-base-uncased

The source code for this tutorial can be found on GitHub.

Memray tracks and reports memory allocations, both in Python code and in compiled extension modules. The Memray profiling plugin enables memory tracking at the Flyte task level and renders a Memray HTML report (a flame graph by default) in Flyte Deck.

import time

from flytekit import ImageSpec, task, workflow
from flytekitplugins.memray import memray_profiling

First, we use ImageSpec to build a container image with the dependencies for the tasks we want to profile:

image = ImageSpec(
    name="memray_demo",
    packages=["flytekitplugins_memray"],
    registry="ghcr.io/flyteorg",  # Use your image registry
)
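If you would rather not hardcode the registry, one option is to read the IMAGE_SPEC_REGISTRY variable exported earlier and fall back to the value above. This is a minimal sketch: the helper name image_registry is hypothetical, and depending on your flytekit or union version the variable may already be picked up without it. The sketch simply makes the choice explicit.

```python
import os

# Hypothetical default, copied from the ImageSpec above.
DEFAULT_REGISTRY = "ghcr.io/flyteorg"


def image_registry(default: str = DEFAULT_REGISTRY) -> str:
    """Return the registry from IMAGE_SPEC_REGISTRY, or the default if it is unset or blank."""
    registry = os.environ.get("IMAGE_SPEC_REGISTRY", "").strip()
    return registry or default
```

You would then pass registry=image_registry() to ImageSpec in place of the hardcoded string.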

Next, we define a dummy function that keeps appending ~1 MB strings to a list, holding that memory until the function returns:

def generate_data(n: int):
    leak_list = []
    for _ in range(n):  # n iterations, each allocating ~1 MB
        large_data = " " * 10**6  # 1 MB string
        leak_list.append(large_data)  # Keeps appending without releasing
        time.sleep(0.1)  # Slow down the loop to observe memory changes
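Before profiling remotely, you can sanity-check the ~1 MB-per-iteration figure locally. The sketch below uses Python's standard-library tracemalloc module as a stand-in. It is not Memray, and it only sees allocations made through Python's allocators. It also runs a copy of generate_data with the sleep removed. The helper traced_memory is hypothetical.

```python
import tracemalloc


def generate_data(n: int) -> None:
    # Copy of the tutorial's generate_data, minus the sleep (which only slows the loop).
    leak_list = []
    for _ in range(n):
        leak_list.append(" " * 10**6)  # ~1 MB string, kept until the function returns


def traced_memory(n: int) -> tuple[int, int]:
    """Return (current, peak) bytes traced by tracemalloc while running generate_data(n)."""
    tracemalloc.start()
    try:
        generate_data(n)
        # Read the counters before tracing stops.
        return tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()
```

The peak should be at least n MB. The current figure drops back near zero once the function returns and its local list is freed.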

Profile the memory usage of generate_data() with the Memray table HTML reporter:

@task(container_image=image, enable_deck=True)
@memray_profiling(memray_html_reporter="table")
def memory_usage(n: int) -> str:
    generate_data(n=n)

    return "Well"
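The table reporter lists allocations in an HTML table together with where each one was made. For a rough local analogue, tracemalloc can group live allocations by source line with Snapshot.statistics("lineno"). This is a stdlib stand-in rather than Memray's output, and allocate and top_lines are hypothetical helpers. The list is returned and held so it is still alive when the snapshot is taken.

```python
import tracemalloc


def allocate(n: int) -> list[str]:
    # Same ~1 MB-per-item pattern as generate_data, but the list is returned so it stays alive.
    return [" " * 10**6 for _ in range(n)]


def top_lines(n: int, limit: int = 3) -> list[tracemalloc.Statistic]:
    """Snapshot live allocations while allocate(n)'s result is held, grouped by source line."""
    tracemalloc.start()
    data = allocate(n)  # keep a reference so the strings are live at snapshot time
    snapshot = tracemalloc.take_snapshot()
    tracemalloc.stop()
    del data
    # statistics() sorts by total size, largest first.
    return snapshot.statistics("lineno")[:limit]
```

For example, for stat in top_lines(4): print(stat) prints one line per source location, with the string-building line at the top.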

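Profile the memory leakage of generate_data() with the Memray flamegraph HTML reporter (the default, since no reporter is set). The --leaks argument is passed to the reporter, and trace_python_allocators=True makes Memray trace Python's own allocators so the leak report is more accurate: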

@task(container_image=image, enable_deck=True)
@memray_profiling(trace_python_allocators=True, memray_reporter_args=["--leaks"])
def memory_leakage(n: int) -> str:
    generate_data(n=n)

    return "Well"

Put everything together in a workflow.

@workflow
def wf(n: int = 500):
    memory_usage(n=n)
    memory_leakage(n=n)