Co-Authored by Michael Marzec and Anton Dereventsov.
Introduction
At Klaviyo, like many data-driven companies, we face the ongoing challenge of balancing compute costs and processing efficiency across ever-growing (and already massive) datasets. Recently, we've explored Ray, an open-source framework for scalable data processing, training, and optimization.
In this post, we'll walk through how we use Ray's three core libraries (Ray Data, Ray Train, and Ray Tune) to ingest and preprocess large datasets, train machine learning models efficiently, and optimize hyperparameters at scale. Specifically, we'll cover how these libraries augment our broader Data Science team's ability to train, optimize, and deploy models, ranging from in-app, customer-facing models to models that support internal business functions.
Ray Data
At Klaviyo, we use Ray Data to ingest and preprocess datasets that are too large to fit into memory on a single machine (e.g., 100GB+ tables with over 100 million rows). Rather than spinning up Spark jobs for "last-mile" ETL, we can launch a Ray cluster across multiple EC2 instances and run Python-native pipelines that integrate seamlessly with our ML code.
Under the hood, Ray Data uses Apache Arrow, which enables familiar method chaining with .map(), .filter(), and .repartition() calls without any JVM shenanigans or extra serialization.
import ray
from ray.data import read_parquet
from ray.data.preprocessors import SimpleImputer
ray.init(address="auto") # connect to our Ray cluster
dataset = read_parquet("s3://path/to/data/")
# filter out rows with missing values
dataset = dataset.filter(lambda r: r["timestamp"] is not None)
# impute missing open rate with 0
imputer = SimpleImputer(columns=["open_rate"], strategy="constant", fill_value=0)
dataset = imputer.fit_transform(dataset)
# repartition the dataset for better performance
dataset = dataset.repartition(64)
Ray Data takes advantage of all available resources in the cluster, striking a balance between throughput and computational cost efficiency. In one job, we used 200 m5.8xlarge instances (4,000+ CPUs) and observed near-linear scaling on billions of records! On a typical day, we load and preprocess ~100GB of data in minutes using roughly 16 instances and can subsequently train various ML models.
Because Ray Data automatically balances work across all available resources, we achieve high throughput without paying for idle resources, and can adjust the cluster size based on necessity without modifying pipelines.
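To illustrate, here's a minimal sketch of how the same pipeline code can move from a laptop to a cluster; only the ray.init call changes, and the open_rate / click_rate columns and the derived engagement_score are hypothetical:
import ray
import pandas as pd
# local development; on a cluster we connect with ray.init(address="auto") instead
ray.init()
dataset = ray.data.read_parquet("s3://path/to/data/")
def add_engagement_score(batch: pd.DataFrame) -> pd.DataFrame:
    # hypothetical derived feature for illustration
    batch["engagement_score"] = batch["open_rate"] * batch["click_rate"]
    return batch
# vectorized preprocessing on pandas batches; Ray Data schedules the work
# across however many nodes happen to be available
dataset = dataset.map_batches(add_engagement_score, batch_format="pandas")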
Ray Train
To complement Ray Data, we use Ray Train to train machine learning models on the cleaned data. Ray Train supports many frameworks, but we primarily use PyTorch.
With Ray Train, we write a standard PyTorch training loop, and Ray Train handles worker orchestration, data distribution, gradient synchronization (via DDP), and resource management, eliminating much of the complexity behind distributed training.
Here's a simplified example:
import ray
from ray.train import ScalingConfig, get_dataset_shard
from ray.train.torch import TorchTrainer
# this code will be run on every Ray worker
def train_loop(config):
    model = setup_model(config)  # use any pytorch model!
    model = ray.train.torch.prepare_model(model)  # move to this worker's device and wrap for DDP
    # each worker receives its own shard of the dataset
    dataset_shards = {
        "train": get_dataset_shard("train"),
        "test": get_dataset_shard("test"),
    }
    run_training(model, dataset_shards)  # your training loop for the model
# load data -- doesn't necessarily have to come from Ray Data
datasets = {
    "train": ray.data.read_parquet("train_data_path"),
    "test": ray.data.read_parquet("test_data_path"),
}
# let Ray Train take care of the distributed setup and synchronization
trainer = TorchTrainer(
    train_loop_per_worker=train_loop,
    datasets=datasets,
    scaling_config=ScalingConfig(
        num_workers=4,
        resources_per_worker={"CPU": 16},
    ),
)
result = trainer.fit()
In the code above, model can be any PyTorch model with its own forward and training logic; Ray Train only decides how many workers to launch and what resources to use (as specified by num_workers and resources_per_worker in the ScalingConfig). Under the hood, the TorchTrainer takes care of launching workers, setting up DDP, syncing gradients, and tearing everything down.
In practice, we start with Ray Data to preprocess our data and then use Ray Train to launch training jobs. We typically use 4 to 16 workers on dev clusters for experimentation and scale up to 64+ workers in production without requiring any changes to the model setup or training code.
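For example, a production-scale trainer might look like the following sketch, which assumes the train_loop and datasets defined above; the worker count and per-worker resources are illustrative:
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
# only the scaling parameters change; the train_loop and model code stay the same
prod_trainer = TorchTrainer(
    train_loop_per_worker=train_loop,
    datasets=datasets,
    scaling_config=ScalingConfig(num_workers=64, resources_per_worker={"CPU": 16}),
)
result = prod_trainer.fit()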
Ray Tune
The final piece of our workflow is Ray Tune, which we use for scalable hyperparameter tuning.
Ray Tune provides a built-in approach for executing large-scale search experiments using popular optimization libraries like Optuna and Nevergrad. It also integrates with schedulers (e.g., HyperBand), custom experimentation tools, and reporting layers (e.g., MLflow), and supports various modeling frameworks, including scikit-learn, XGBoost, PyTorch, TensorFlow, and Keras.
At Klaviyo, we've released an internal template to reduce the friction between model development and deployment. Using our Ray Serve infrastructure (see here for more details on that!), data scientists can plug their models into a utils directory (requiring only minor adjustments for code such as schedulers and search spaces), configure their GPUs/CPUs, replicas, and output settings, and deploy the models for training. We also have configurations that write results to a centralized MLflow experiment.
The template leverages roughly the following pattern. Here's a simplified example using Optuna with Ray Tune to optimize ARIMA model parameters:
from typing import Any, Dict
from ray import tune
from ray.tune.search.optuna import OptunaSearch
# ARIMA, data_prep, and arima_utils are simplified stand-ins for our internal forecasting utilities
def objective(config: Dict[str, Any]) -> Dict[str, Any]:
    """Objective function for Ray Tune optimization.

    Args:
        config: Dictionary with hyperparameters from Ray Tune.

    Returns:
        Dictionary of optimized metrics.
    """
    # hyperparameters
    p, d, q = config["p"], config["d"], config["q"]
    arima_model = ARIMA(order=(p, d, q))
    # example functions for training and testing the model
    train_data, test_data = data_prep.split_train_test()
    fx_df = arima_utils.forecast(arima_model, train_data)
    mae, median_ae, mean_error = arima_utils.evaluate(fx_df, test_data)
    return {
        "mae": mae,
        "median_ae": median_ae,
        "mean_error": mean_error,
        "fx_df": fx_df,
    }
search_space = {
    "p": tune.randint(0, 6),
    "d": tune.randint(0, 3),
    "q": tune.randint(0, 6),
}
# metric and mode are passed to the searcher via tune.run below
search_algo = OptunaSearch()
num_trials = 1000
analysis = tune.run(
    objective,
    config=search_space,
    search_alg=search_algo,
    num_samples=num_trials,
    resources_per_trial={"cpu": 2, "memory": 4 * 1024 * 1024 * 1024},
    metric="mae",
    mode="min",
    verbose=2,
    name="ray_tune_example",
)
best_params = analysis.best_config
best_result = analysis.best_result
This same structure can be reused across projects with minimal changes, allowing teams to more easily scale their tuning pipelines.
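As a sketch of the scheduler and reporting integrations mentioned earlier (not our production configuration), the same objective and search space can be run with an ASHA scheduler, an asynchronous HyperBand variant, and an MLflow callback; note that the MLflow callback's import path varies across Ray versions:
from ray import tune
from ray.tune.schedulers import ASHAScheduler
from ray.air.integrations.mlflow import MLflowLoggerCallback  # path may differ in older Ray versions
# early-stop unpromising trials (most useful when the objective reports intermediate results)
# and log every trial to a centralized MLflow experiment
analysis = tune.run(
    objective,
    config=search_space,
    scheduler=ASHAScheduler(),  # picks up the metric/mode passed to tune.run
    callbacks=[MLflowLoggerCallback(experiment_name="ray_tune_example")],
    num_samples=100,
    metric="mae",
    mode="min",
)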
For example, building on the ARIMA example above, we have used this approach to tune an MSTL model with an ARIMA trend forecaster that predicts our support demand at thirty-minute intervals. While the sample code here is intentionally minimal, it is easy to extend with nuances such as outlier detection and handling, weekend forecasting, and cross-validation, to name a few.
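For a rough sense of that setup (a hypothetical sketch, not our production code), one way to structure it is to decompose the half-hourly series with statsmodels' MSTL and fit an ARIMA on the trend component, with the (p, d, q) order supplied by Ray Tune; the seasonal periods of 48 half-hours per day and 336 per week, and the function below, are assumptions for illustration:
import pandas as pd
from statsmodels.tsa.seasonal import MSTL
from statsmodels.tsa.arima.model import ARIMA
def forecast_trend(series: pd.Series, horizon: int, order: tuple) -> pd.Series:
    """Hypothetical sketch: decompose half-hourly demand, then forecast the trend with ARIMA."""
    # decompose into daily (48) and weekly (336) seasonal components plus a trend
    decomposition = MSTL(series, periods=(48, 336)).fit()
    # fit ARIMA on the trend with the tuned (p, d, q) order; the seasonal components
    # (decomposition.seasonal) would be added back before scoring inside the Tune objective
    trend_model = ARIMA(decomposition.trend, order=order).fit()
    return trend_model.forecast(steps=horizon)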
Conclusion
By leveraging Ray Data, Ray Train, and Ray Tune, Klaviyo has built more scalable, cost-efficient, and production-ready machine learning workflows. These tools enable the seamless handling of massive datasets, distributed model training, and efficient hyperparameter tuning, all with minimal overhead.
With Ray's ecosystem, we can scale from local development to multi-node clusters, accelerating our ML development cycle, controlling costs, and improving the performance of our deployed models.