Pydantic BaseModel
flytekit version >= 1.14 natively supports the JSON format that Pydantic BaseModel produces, enhancing the interoperability of Pydantic BaseModels with the Flyte type system.
Pydantic BaseModel V2 only works when you are using flytekit version >= v1.14.0.
With the 1.14 release, flytekit adopted MessagePack as the serialization format for Pydantic BaseModel, overcoming a major limitation of earlier versions, which serialized into a JSON string within a Protobuf struct datatype: to store int types, Protobuf’s struct converts them to float, forcing users to write boilerplate code to work around this issue.
By default, flytekit >= 1.14 will produce msgpack bytes literals when serializing, preserving the types defined in your BaseModel class.
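To make the int-versus-float point concrete, here is a small stdlib-only sketch. The `to_struct_like` helper is hypothetical and is not part of flytekit or Protobuf; it only mimics the fact that Protobuf's struct has a single numeric kind (a double), so integers come back as floats after a round trip.

```python
from typing import Any


def to_struct_like(value: Any) -> Any:
    """Hypothetical stand-in for a round trip through a Protobuf struct.

    Every number is returned as a float. bool is checked first because
    bool is a subclass of int in Python.
    """
    if isinstance(value, bool):
        return value
    if isinstance(value, (int, float)):
        return float(value)
    if isinstance(value, dict):
        return {k: to_struct_like(v) for k, v in value.items()}
    if isinstance(value, list):
        return [to_struct_like(v) for v in value]
    return value


payload = {"x": 1, "y": "1", "tags": [1, 2]}
restored = to_struct_like(payload)

# The values still compare equal, but the int type is gone.
print(type(payload["x"]).__name__, "->", type(restored["x"]).__name__)
```

A type-preserving format avoids this conversion, so a field declared as int is still an int on the receiving side.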
If you’re serializing BaseModel using flytekit version >= v1.14.0 and you want to produce a Protobuf struct literal instead, you can set the environment variable FLYTE_USE_OLD_DC_FORMAT to true.
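As a hedged illustration of this opt-out, the sketch below sets the variable from Python for a local run. The `old_format_requested` helper is hypothetical: it shows one plausible way such a flag could be interpreted and is not flytekit's actual implementation. For remote executions, the variable would need to be present in the environment where the task runs.

```python
import os

# Set the flag before any serialization takes place in this process.
os.environ["FLYTE_USE_OLD_DC_FORMAT"] = "true"


def old_format_requested() -> bool:
    """Hypothetical helper: treat the value 'true' (any case) as enabled."""
    return os.environ.get("FLYTE_USE_OLD_DC_FORMAT", "false").lower() == "true"


print(old_format_requested())
```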
For more details, refer to the MessagePack IDL RFC: https://github.com/flyteorg/flyte/blob/master/rfc/system/5741-binary-idl-with-message-pack.md
To clone and run the example code on this page, see the Flytesnacks repo.
You can use dataclasses and Flyte types (FlyteFile, FlyteDirectory, FlyteSchema, and StructuredDataset) as fields in a Pydantic BaseModel.
To begin, import the necessary dependencies:
import os
import tempfile

import pandas as pd
import flytekit as fl
from flytekit.types.structured import StructuredDataset
from pydantic import BaseModel
Build your custom image with ImageSpec:
image_spec = fl.ImageSpec(
    registry="ghcr.io/flyteorg",
    packages=["pandas", "pyarrow", "pydantic"],
)
Python types
We define a Pydantic BaseModel with int, str and dict as the data types.
class Datum(BaseModel):
    x: int
    y: str
    z: dict[int, str]
You can send a Pydantic BaseModel between different tasks written in various languages, and input it through the Flyte console as raw JSON.
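The JSON form of such a model can be sketched with Pydantic alone. This example assumes Pydantic V2 (`model_validate_json` and `model_dump_json`) and redefines `Datum` so that it is self-contained. The exact input the Flyte console expects is not shown in this page, so the `raw` string below only illustrates Pydantic's own JSON parsing.

```python
from pydantic import BaseModel


class Datum(BaseModel):
    x: int
    y: str
    z: dict[int, str]


# JSON object keys are always strings; Pydantic converts them back to int
# because the field is annotated as dict[int, str].
raw = '{"x": 1, "y": "1", "z": {"1": "one"}}'
datum = Datum.model_validate_json(raw)

# Serializing and parsing again yields an equal model.
round_tripped = Datum.model_validate_json(datum.model_dump_json())
print(datum.z, round_tripped == datum)
```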
All fields in a BaseModel should be annotated with their type. Failure to do so will result in an error.
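The annotation requirement can be observed with Pydantic on its own. This is a minimal sketch assuming Pydantic V2, where the error is raised as `PydanticUserError` at class-definition time; the `Untyped` class is a hypothetical example and not part of the code on this page.

```python
from pydantic import BaseModel, PydanticUserError

try:

    class Untyped(BaseModel):
        x = 1  # no type annotation, so Pydantic rejects the class

except PydanticUserError as exc:
    error_message = str(exc)
else:
    error_message = ""

# Print the first line of the error, if one was raised.
print(error_message.splitlines()[0] if error_message else "no error")
```

Annotating the attribute (for example `x: int = 1`) makes the class definition succeed.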
Once declared, a BaseModel can be returned as an output or accepted as an input.
@fl.task(container_image=image_spec)
def stringify(s: int) -> Datum:
    """
    A Pydantic model return will be treated as a single complex JSON return.
    """
    return Datum(x=s, y=str(s), z={s: str(s)})


@fl.task(container_image=image_spec)
def add(x: Datum, y: Datum) -> Datum:
    x.z.update(y.z)
    return Datum(x=x.x + y.x, y=x.y + y.y, z=x.z)
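To see what these two tasks compute without a Flyte installation, the bodies can be run as plain functions. This is only a sketch: `stringify_plain` and `add_plain` are hypothetical undecorated copies of the task bodies, and `Datum` is redefined so the block is self-contained.

```python
from pydantic import BaseModel


class Datum(BaseModel):
    x: int
    y: str
    z: dict[int, str]


def stringify_plain(s: int) -> Datum:
    # Same body as the stringify task, without the Flyte decorator.
    return Datum(x=s, y=str(s), z={s: str(s)})


def add_plain(x: Datum, y: Datum) -> Datum:
    # Same body as the add task: merge the dicts, add the ints,
    # and concatenate the strings.
    x.z.update(y.z)
    return Datum(x=x.x + y.x, y=x.y + y.y, z=x.z)


result = add_plain(stringify_plain(1), stringify_plain(2))
print(result)
```

Note that the add body updates `x.z` in place before building the result, so the first argument's dict also ends up containing the second argument's entries.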
Flyte types
We also define a BaseModel that accepts StructuredDataset, FlyteFile and FlyteDirectory.
class FlyteTypes(BaseModel):
    dataframe: StructuredDataset
    file: fl.FlyteFile
    directory: fl.FlyteDirectory
@fl.task(container_image=image_spec)
def upload_data() -> FlyteTypes:
    df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})

    temp_dir = tempfile.mkdtemp(prefix="flyte-")
    df.to_parquet(os.path.join(temp_dir, "df.parquet"))

    file_path = tempfile.NamedTemporaryFile(delete=False)
    file_path.write(b"Hello, World!")
    file_path.close()

    fs = FlyteTypes(
        dataframe=StructuredDataset(dataframe=df),
        file=fl.FlyteFile(file_path.name),
        directory=fl.FlyteDirectory(temp_dir),
    )
    return fs


@fl.task(container_image=image_spec)
def download_data(res: FlyteTypes):
    expected_df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})
    actual_df = res.dataframe.open(pd.DataFrame).all()
    assert expected_df.equals(actual_df), "DataFrames do not match!"

    with open(res.file, "r") as f:
        assert f.read() == "Hello, World!", "File contents do not match!"

    assert os.listdir(res.directory) == ["df.parquet"], "Directory contents do not match!"
A BaseModel supports the use of Python types, data classes, FlyteFile, FlyteDirectory and StructuredDataset as field types.
We define a workflow that calls the tasks created above.
@fl.workflow
def basemodel_wf(x: int, y: int) -> (Datum, FlyteTypes):
    o1 = add(x=stringify(s=x), y=stringify(s=y))
    o2 = upload_data()
    download_data(res=o2)
    return o1, o2
To run the workflow locally with pyflyte run, pass the workflow inputs as command-line flags:
$ pyflyte run pydantic_basemodel.py basemodel_wf --x 1 --y 2
Alternatively, you can run the example directly from the Flytesnacks repository:
$ pyflyte run \
    https://raw.githubusercontent.com/flyteorg/flytesnacks/b71e01d45037cea883883f33d8d93f258b9a5023/examples/data_types_and_io/data_types_and_io/pydantic_basemodel.py \
    basemodel_wf --x 1 --y 2