Managing machine learning models with Dagster
This guide reviews ways to manage and maintain your machine learning (ML) models in Dagster.
Machine learning models depend heavily on the data available at a point in time, so they must be managed to ensure they keep producing the results you saw during development. In this guide, you'll learn how to:
- Automate training of your model when new data is available or when you want to use your model for predictions
- Integrate metadata about your model into the Dagster UI to display info about your model's performance
Before proceeding, we recommend reviewing "Building machine learning pipelines with Dagster", which provides background on using Dagster's assets for machine learning.
Machine learning operations (MLOps)
You might have thought about your data sources, feature sets, and the best model for your use case. Inevitably, you start thinking about how to make this process sustainable and operational, and how to deploy it to production. You want the machine learning pipeline to be self-sufficient, and you want confidence that the model you built is performing the way you expect. Machine learning operations, or MLOps, is the practice of making your model maintainable and repeatable for a production use case.
Automating ML model maintenance
Whether you have a large or small model, Dagster can help automate data refreshes and model training based on your business needs.
Declarative Automation can be used to update a machine learning model when the upstream data is updated. This can be done by setting the AutomationCondition to eager, which means that the machine learning model asset will be refreshed anytime the data asset is updated.
from dagster import AutomationCondition, asset
@asset
def my_data(): ...
@asset(automation_condition=AutomationCondition.eager())
def my_ml_model(my_data): ...
Some machine learning models are more cumbersome to retrain, and it may be less important to update them as soon as new data arrives. For these, the on_cron condition can be used, which causes the asset to be updated on a given cron schedule, but only after all of its upstream dependencies have been updated.
from dagster import AutomationCondition, asset
@asset
def my_other_data(): ...
@asset(automation_condition=AutomationCondition.on_cron("0 9 * * *"))
def my_other_ml_model(my_other_data): ...
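To make the on_cron rule concrete, the following is a simplified, plain-Python model of the decision it describes. It is not Dagster's implementation: the function name, its parameters, and the reading that upstream assets must have updated since the latest cron tick are assumptions made for illustration.

```python
from datetime import datetime
from typing import Iterable, Optional


def should_refresh_on_cron(
    latest_tick: datetime,
    asset_updated_at: Optional[datetime],
    upstream_updated_at: Iterable[Optional[datetime]],
) -> bool:
    """Hypothetical helper: decide whether an asset should refresh."""
    # The asset is due if it has never run or last ran before the latest tick.
    asset_is_due = asset_updated_at is None or asset_updated_at < latest_tick
    # Every upstream asset must have been updated at or after that tick.
    upstreams_ready = all(
        ts is not None and ts >= latest_tick for ts in upstream_updated_at
    )
    return asset_is_due and upstreams_ready
```

Under this model, a 9:00 tick alone does not trigger retraining. The model asset waits until its data assets have also refreshed after 9:00.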
Monitoring
Integrating your machine learning models into Dagster allows you to see when the model and its data dependencies were refreshed, or when a refresh process has failed. By using Dagster to monitor performance changes and process failures on your ML model, it becomes possible to set up remediation paths, such as automated model retraining, that can help resolve issues like model drift.
In this example, the model is being evaluated against the previous model’s accuracy. If the model’s accuracy has improved, the model is returned for use in downstream steps, such as inference or deploying to production.
from sklearn import linear_model
from dagster import asset, Output, AssetKey, AssetExecutionContext
import numpy as np
from sklearn.model_selection import train_test_split
@asset(output_required=False)
def conditional_machine_learning_model(context: AssetExecutionContext):
    X, y = np.random.randint(5000, size=(5000, 2)), range(5000)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.33, random_state=42
    )
    reg = linear_model.LinearRegression()
    reg.fit(X_train, y_train)
    # Get the model accuracy from the metadata of the previous materialization of this machine learning model
    instance = context.instance
    materialization = instance.get_latest_materialization_event(
        AssetKey(["conditional_machine_learning_model"])
    )
    if materialization is None:
        # First materialization: there is nothing to compare against
        yield Output(reg, metadata={"model_accuracy": float(reg.score(X_test, y_test))})
    else:
        previous_model_accuracy = None
        if materialization.asset_materialization and isinstance(
            materialization.asset_materialization.metadata["model_accuracy"].value,
            float,
        ):
            previous_model_accuracy = float(
                materialization.asset_materialization.metadata["model_accuracy"].value
            )
        new_model_accuracy = reg.score(X_test, y_test)
        # Only emit the new model if it beats the previous one
        if (
            previous_model_accuracy is None
            or new_model_accuracy > previous_model_accuracy
        ):
            yield Output(reg, metadata={"model_accuracy": float(new_model_accuracy)})
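The comparison at the heart of this asset can be pulled out into a small pure function, which makes the rule easy to unit test without running a materialization. This is only a sketch: the helper name should_emit_model is hypothetical and not part of Dagster.

```python
from typing import Optional


def should_emit_model(
    previous_accuracy: Optional[float], new_accuracy: float
) -> bool:
    """Hypothetical helper mirroring the condition used in the asset above."""
    # With no previous score there is nothing to compare against,
    # so the new model is always emitted.
    if previous_accuracy is None:
        return True
    # Otherwise emit only on a strict improvement.
    return new_accuracy > previous_accuracy
```

Because the asset is declared with output_required=False, yielding no output when this check fails leaves the previously materialized model in place.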
A sensor can be set up that triggers if an asset fails to materialize. Alerts can be customized and sent through e-mail or natively through Slack. In this example, a Slack message is sent anytime the ml_job fails.
import os
from dagster import define_asset_job
from dagster_slack import make_slack_on_run_failure_sensor
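# ml_model is the model asset to monitor, assumed to be defined elsewhere
ml_job = define_asset_job("ml_training_job", selection=[ml_model])
slack_on_run_failure = make_slack_on_run_failure_sensor(
    channel="#ml_monitor_channel",
    # Read the token from the environment; the variable name is an assumption
    slack_token=os.getenv("MY_SLACK_TOKEN"),
    monitored_jobs=[ml_job],
)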
Enhancing the Dagster UI with metadata
Understanding the performance of your ML model is critical to both the model development process and production. Metadata can significantly enhance the usability of the Dagster UI by showing what’s going on in a specific asset. Metadata in Dagster is flexible: it can be used to track evaluation metrics and to view the training accuracy over training iterations as a graph.
One of the easiest ways to utilize Dagster’s metadata is by using a dictionary to track different metrics that are relevant for an ML model.
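As a minimal sketch of the dictionary approach, the helper below collects metrics into a plain dictionary. The function name, the metric names, and the num_training_rows key are hypothetical. In an asset, the returned dictionary would be passed as the metadata argument of Output, as the other examples in this guide do.

```python
from typing import Dict, Mapping, Union

Number = Union[int, float]


def build_model_metadata(
    metrics: Mapping[str, Number], num_training_rows: int
) -> Dict[str, Number]:
    """Hypothetical helper: gather model metrics into one metadata dict."""
    # Coerce every metric to a built-in float so that values such as
    # NumPy scalars are stored as plain numbers.
    metadata: Dict[str, Number] = {
        name: float(value) for name, value in metrics.items()
    }
    # Record the training set size alongside the metrics.
    metadata["num_training_rows"] = int(num_training_rows)
    return metadata
```

Converting each metric to a float mirrors the float(...) calls in the guide's own examples.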
Another way is to store relevant data for a single training iteration as a graph that you can view directly from the Dagster UI. In this example, a function is defined that uses data produced by a machine learning model to plot an evaluation metric as the model goes through the training process and render that in the Dagster UI.
Dagster’s MetadataValue types support rich values such as tables, URLs, notebooks, and Markdown. In the following example, the Markdown metadata type is used to render plots. Each plot shows a specific evaluation metric’s performance across the training iterations, also known as epochs, of the training cycle.
from dagster import MetadataValue
import seaborn
import matplotlib.pyplot as plt
import base64
from io import BytesIO
def make_plot(eval_metric):
    # Clear any figure left over from a previous call
    plt.clf()
    training_plot = seaborn.lineplot(eval_metric)
    fig = training_plot.get_figure()
    # Render the figure to an in-memory PNG and embed it in Markdown as a base64 data URI
    buffer = BytesIO()
    fig.savefig(buffer)
    image_data = base64.b64encode(buffer.getvalue())
    return MetadataValue.md(f"![img](data:image/png;base64,{image_data.decode()})")
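The embedding step in make_plot can be illustrated on its own with only the standard library. This sketch assumes the image is already available as PNG bytes. The helper name png_bytes_to_markdown is hypothetical, and the string it returns is what make_plot passes to MetadataValue.md.

```python
import base64


def png_bytes_to_markdown(png_bytes: bytes, alt_text: str = "img") -> str:
    """Hypothetical helper: embed PNG bytes in a Markdown image tag."""
    # Base64 turns the binary image into ASCII text that is safe to inline.
    encoded = base64.b64encode(png_bytes).decode("ascii")
    # A data URI lets the Markdown renderer show the image without a file or URL.
    return f"![{alt_text}](data:image/png;base64,{encoded})"
```

Inlining the image this way keeps the plot inside the metadata itself, so no separate image hosting is needed. The trade-off is that large figures produce large metadata values.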
In the next example, a dictionary called metadata stores the Markdown plots and the score value in Dagster.
from dagster import Output, asset
import xgboost as xgb
from sklearn.metrics import mean_absolute_error
@asset
def xgboost_comments_model(transformed_training_data, transformed_test_data):
    transformed_X_train, transformed_y_train = transformed_training_data
    transformed_X_test, transformed_y_test = transformed_test_data
    # Train XGBoost model, which is a highly efficient and flexible model
    xgb_r = xgb.XGBRegressor(
        objective="reg:squarederror", eval_metric=mean_absolute_error, n_estimators=20
    )
    xgb_r.fit(
        transformed_X_train,
        transformed_y_train,
        eval_set=[(transformed_X_test, transformed_y_test)],
    )
    # Evaluation history for the first (and only) eval_set entry
    eval_results = xgb_r.evals_result()["validation_0"]
    # Plot the mean absolute error values as the training progressed
    metadata = {}
    for eval_metric in eval_results.keys():
        metadata[f"{eval_metric} plot"] = make_plot(eval_results[eval_metric])
    # Keep track of the score from the final training iteration
    metadata["score (mean_absolute_error)"] = eval_results["mean_absolute_error"][-1]
    return Output(xgb_r, metadata=metadata)
In the Dagster UI, the xgboost_comments_model asset has the metadata rendered. Numerical values, such as the score (mean_absolute_error), are logged and plotted for each materialization, which helps you understand how a machine learning model's score changes over time.
The Markdown plots are also available for inspecting the evaluation metrics during the training cycle by clicking [Show Markdown].
Tracking model history
Viewing previous versions of a machine learning model can be useful for understanding the evaluation history or for referencing a model that was used for inference. Using Dagster will enable you to understand:
- What data was used to train the model
- When the model was refreshed
- Which code version and ML model version were used to generate predictions
In Dagster, each time an asset is materialized, the metadata and model are stored. Dagster registers the code version, data version, and source data for each asset, so you can trace which data was used to train a model.
The screenshot below shows each materialization of xgboost_comments_model and the path where each iteration of the model is stored.
Any plots generated through the asset's metadata can be viewed in the metadata section. In this example, the plots of score (mean_absolute_error) are available for analysis.