# NVIDIA DGX agent
You can run workflows on the NVIDIA DGX platform with the DGX agent.
## Installation
To install the DGX agent and have it enabled in your deployment, contact the Union.ai team.
## Example usage

The following example runs inference with the Mixtral-8x7B-Instruct model on a DGX A100 node, generating several variations of a response to a single prompt.

First, define a container image that bundles the CUDA-enabled dependencies the tasks need:

```python
from typing import List

import union
from flytekitplugins.dgx import DGXConfig

dgx_image_spec = union.ImageSpec(
    base_image="my-image/dgx:v24",
    packages=["torch", "transformers", "accelerate", "bitsandbytes"],
    registry="my-registry",
)
```
Next, define a Jinja chat template that formats the system and user messages into a single prompt string:

```python
DEFAULT_CHAT_TEMPLATE = """
{% for message in messages %}
{% if message['role'] == 'user' %}
{{ '<<|user|>> ' + message['content'].strip() + ' <</s>>' }}
{% elif message['role'] == 'system' %}
{{ '<<|system|>>\\n' + message['content'].strip() + '\\n<</s>>\\n\\n' }}
{% endif %}
{% endfor %}
{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}
""".strip()
```
A lightweight task assembles the message list from the user prompt and the system message:

```python
@union.task(container_image=dgx_image_spec, cache_version="1.0", cache=True)
def form_prompt(prompt: str, system_message: str) -> List[dict]:
    return [
        {"role": "system", "content": system_message},
        {"role": "user", "content": prompt},
    ]
```
To run a task on DGX, set its `task_config` to a `DGXConfig` that names the instance type. Here, `dgxa100.80g.8.norm` requests a DGX A100 node with eight 80 GB GPUs:

```python
@union.task(
    task_config=DGXConfig(instance="dgxa100.80g.8.norm"),
    container_image=dgx_image_spec,
)
def inference(messages: List[dict], n_variations: int) -> List[str]:
    import torch
    import transformers
    from transformers import AutoTokenizer

    print(f"gpu is available: {torch.cuda.is_available()}")

    model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
    tokenizer = AutoTokenizer.from_pretrained(model)

    # Load the model in 4-bit precision (via bitsandbytes) to reduce GPU
    # memory usage.
    pipeline = transformers.pipeline(
        "text-generation",
        tokenizer=tokenizer,
        model=model,
        model_kwargs={"torch_dtype": torch.float16, "load_in_4bit": True},
    )

    print(f"{messages=}")

    # Format the messages into a single prompt string using the template above.
    prompt = pipeline.tokenizer.apply_chat_template(
        messages,
        chat_template=DEFAULT_CHAT_TEMPLATE,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Sample n_variations independent completions of the same prompt.
    outputs = pipeline(
        prompt,
        num_return_sequences=n_variations,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
        return_full_text=False,
    )
    print(f'generated text={outputs[0]["generated_text"]}')
    return [output["generated_text"] for output in outputs]
```
Finally, a workflow chains the two tasks together:

```python
@union.workflow
def wf(
    prompt: str = "Explain what a Mixture of Experts is in less than 100 words.",
    n_variations: int = 8,
    system_message: str = "You are a helpful and polite bot.",
) -> List[str]:
    messages = form_prompt(prompt=prompt, system_message=system_message)
    return inference(messages=messages, n_variations=n_variations)
```
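The sketch below shows one way to trigger the workflow programmatically, assuming your SDK is configured against a deployment with the DGX agent enabled and that `UnionRemote` is available in your `union` version; registering and running via the `union` CLI works as well:

```python
# Minimal sketch, assuming a configured Union deployment.
import union

remote = union.UnionRemote()
execution = remote.execute(
    wf,
    inputs={"prompt": "Explain what a Mixture of Experts is in less than 100 words."},
    wait=True,
)
print(execution.outputs)
```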