# ML Pipelines & Orchestration
In production, ML is not a single script — it's a pipeline of interconnected steps: data ingestion, validation, preprocessing, training, evaluation, and deployment. Pipeline orchestrators automate and manage these steps reliably.
## Why Pipelines?
| Problem | How Pipelines Solve It |
|---|---|
| "Works on my machine" | Containerized steps with explicit dependencies |
| Manual re-runs | Automated triggers (schedule, data change, drift alert) |
| Partial failures | Retry logic, checkpointing, resume from failure |
| No audit trail | Every run logged with inputs, outputs, and metadata |
| Team coordination | Shared definitions, version-controlled DAGs |
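The "Partial failures" row is worth making concrete. Below is a minimal, orchestrator-free sketch of retry logic plus a checkpoint that lets a rerun resume from the failed step; the helper names (`run_with_retries`, `run_step`) are illustrative, not any real orchestrator's API:

```python
import time

def run_with_retries(fn, retries=2, delay=0.0):
    """Call fn, retrying up to `retries` more times on failure."""
    for attempt in range(retries + 1):
        try:
            return fn()
        except Exception:
            if attempt == retries:
                raise
            time.sleep(delay)  # back off before the next attempt

completed = set()  # in a real orchestrator this checkpoint lives in a database

def run_step(name, fn):
    """Skip steps that already succeeded, so a rerun resumes after a failure."""
    if name in completed:
        return
    run_with_retries(fn)
    completed.add(name)
```

Real orchestrators add persistence, per-task retry policies, and alerting on top of exactly this pattern.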
Orchestrators model a pipeline as a DAG (Directed Acyclic Graph): each node is a step, each edge is a dependency, and the absence of cycles guarantees a well-defined execution order.
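To make "acyclic" and "dependency order" concrete, here is a tiny hand-rolled sketch using Python's standard-library `graphlib` (not any real orchestrator's API). The step names mirror the pipeline stages above:

```python
from graphlib import TopologicalSorter

# Each key maps a step to the steps it depends on.
deps = {
    "validate": {"extract"},
    "train": {"validate"},
    "evaluate": {"train"},
    "deploy": {"evaluate"},
}

order = list(TopologicalSorter(deps).static_order())
print(order)  # → ['extract', 'validate', 'train', 'evaluate', 'deploy']
```

If the graph contained a cycle, `static_order()` would raise `graphlib.CycleError`, which is exactly why orchestrators insist on the "acyclic" part: a cyclic dependency graph has no valid execution order.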
## Apache Airflow

Airflow is the most widely used workflow orchestrator. Originally built by Airbnb, it's now an Apache project used by thousands of companies.
Core concepts (DAG definition, tasks via operators, dependency chaining) in one minimal daily-retraining pipeline:
```python
from airflow import DAG
from airflow.operators.python import PythonOperator
from datetime import datetime, timedelta

# Define the DAG
default_args = {
    "owner": "ml-team",
    "retries": 2,
    "retry_delay": timedelta(minutes=5),
    "email_on_failure": True,
    "email": ["ml-alerts@company.com"],
}

dag = DAG(
    dag_id="ml_training_pipeline",
    default_args=default_args,
    description="Daily model retraining pipeline",
    schedule_interval="0 2 * * *",  # Run at 2 AM daily
    start_date=datetime(2024, 1, 1),
    catchup=False,
    tags=["ml", "production"],
)

# Define tasks
def extract_data(**context):
    """Pull latest data from the data warehouse."""
    print("Extracting data from BigQuery...")
    # In production: query data warehouse, save to GCS/S3
    return {"rows": 50000, "path": "gs://bucket/data/2024-01-15.parquet"}

def validate_data(**context):
    """Check data quality and schema."""
    ti = context['ti']
    data_info = ti.xcom_pull(task_ids='extract_data')
    print(f"Validating {data_info['rows']} rows...")
    # In production: run Great Expectations or TFX data validation
    if data_info['rows'] < 1000:
        raise ValueError("Too few rows — possible data issue!")

def train_model(**context):
    """Train the model on validated data."""
    print("Training model...")
    # In production: load data, train, log to MLflow
    return {"accuracy": 0.94, "model_path": "gs://bucket/models/v12"}

def evaluate_model(**context):
    """Compare new model against production baseline."""
    ti = context['ti']
    results = ti.xcom_pull(task_ids='train_model')
    print(f"New model accuracy: {results['accuracy']}")
    if results['accuracy'] < 0.90:
        raise ValueError("Model below threshold — not deploying!")

def deploy_model(**context):
    """Push model to serving infrastructure."""
    print("Deploying model to TF Serving...")
    # In production: update model version in TF Serving / SageMaker

# Build the DAG
t1 = PythonOperator(task_id='extract_data', python_callable=extract_data, dag=dag)
t2 = PythonOperator(task_id='validate_data', python_callable=validate_data, dag=dag)
t3 = PythonOperator(task_id='train_model', python_callable=train_model, dag=dag)
t4 = PythonOperator(task_id='evaluate_model', python_callable=evaluate_model, dag=dag)
t5 = PythonOperator(task_id='deploy_model', python_callable=deploy_model, dag=dag)

# Define dependencies
t1 >> t2 >> t3 >> t4 >> t5
```

## Kubeflow Pipelines
Kubeflow Pipelines runs on Kubernetes and is designed specifically for ML workloads. Each step runs in its own container, making it easy to use different environments (GPU for training, CPU for preprocessing).
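Because each step runs in its own container, artifacts move between steps as files whose paths the orchestrator wires together. Before reading the real KFP code below, here is a toy, KFP-free sketch of that pattern; all names (`preprocess`, `train`, the stand-in "model") are illustrative:

```python
import json
import tempfile
from pathlib import Path

workdir = Path(tempfile.mkdtemp())

def preprocess(input_rows, output_path: Path):
    # "Component" 1: writes its output artifact to a path it was handed.
    scaled = [x / max(input_rows) for x in input_rows]
    output_path.write_text(json.dumps(scaled))

def train(dataset_path: Path, model_path: Path):
    # "Component" 2: reads the upstream artifact, writes its own.
    data = json.loads(dataset_path.read_text())
    model = {"mean": sum(data) / len(data)}  # stand-in for real training
    model_path.write_text(json.dumps(model))

# The "orchestrator" connects outputs to inputs via file paths,
# which is what KFP's Output[...] / Input[...] annotations automate.
dataset = workdir / "dataset.json"
model = workdir / "model.json"
preprocess([1, 2, 3, 4], dataset)
train(dataset, model)
print(json.loads(model.read_text()))  # → {'mean': 0.625}
```

In real KFP the paths point into object storage and each function body runs inside its declared `base_image`, which is what lets preprocessing use a CPU image while training uses a GPU one.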
```python
from kfp import dsl, compiler

@dsl.component(base_image="python:3.10",
               packages_to_install=["pandas", "scikit-learn"])
def preprocess_data(input_path: str, output_path: dsl.Output[dsl.Dataset]):
    """Load and preprocess the dataset."""
    import pandas as pd
    from sklearn.preprocessing import StandardScaler

    df = pd.read_csv(input_path)
    scaler = StandardScaler()
    df[df.columns] = scaler.fit_transform(df)
    df.to_csv(output_path.path, index=False)

@dsl.component(base_image="tensorflow/tensorflow:2.15.0",
               packages_to_install=["pandas"])
def train_model(dataset: dsl.Input[dsl.Dataset],
                model: dsl.Output[dsl.Model],
                epochs: int = 50):
    """Train a TensorFlow model."""
    import tensorflow as tf
    import pandas as pd

    df = pd.read_csv(dataset.path)
    # ... training logic that builds and fits tf_model ...
    tf_model.save(model.path)

@dsl.component(base_image="python:3.10")
def evaluate_model(model: dsl.Input[dsl.Model]) -> float:
    """Evaluate model and return accuracy."""
    # ... evaluation logic ...
    return 0.94

@dsl.pipeline(name="ML Training Pipeline",
              description="End-to-end training pipeline")
def ml_pipeline(data_path: str = "gs://bucket/data.csv"):
    preprocess_task = preprocess_data(input_path=data_path)
    train_task = train_model(
        dataset=preprocess_task.outputs["output_path"],
        epochs=50
    )
    eval_task = evaluate_model(model=train_task.outputs["model"])

# Compile the pipeline to a YAML spec that a KFP backend can run
compiler.Compiler().compile(ml_pipeline, "pipeline.yaml")
```

## TFX Pipeline Components
TFX provides a complete set of production ML components:
```
ExampleGen → StatisticsGen → SchemaGen → ExampleValidator
                                 ↓
                   Transform → Trainer → Tuner
                                 ↓
                       Evaluator → Pusher
```
Each component:
1. Reads input artifacts from the metadata store
2. Executes its logic
3. Writes output artifacts back to the metadata store
4. Registers everything for lineage tracking
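The four responsibilities above can be mimicked with a toy in-memory metadata store. This is a sketch of the pattern only, not ML Metadata's actual API; all names here are illustrative:

```python
store = {"artifacts": {}, "lineage": []}

def run_component(name, fn, inputs):
    """Read input artifacts, execute, write outputs, record lineage."""
    in_values = {k: store["artifacts"][k] for k in inputs}  # 1. read inputs
    outputs = fn(**in_values)                               # 2. execute
    store["artifacts"].update(outputs)                      # 3. write outputs
    store["lineage"].append(                                # 4. track lineage
        {"component": name, "inputs": inputs, "outputs": list(outputs)}
    )

store["artifacts"]["examples"] = [1, 2, 2, 3]
run_component("StatisticsGen",
              lambda examples: {"statistics": {"n": len(examples)}},
              inputs=["examples"])
run_component("SchemaGen",
              lambda statistics: {"schema": {"min_rows": statistics["n"]}},
              inputs=["statistics"])
print(store["lineage"])
```

Because every read and write goes through the store, the lineage log can answer "which examples produced this schema?", which is exactly what TFX's metadata store provides at production scale.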
```python
from tfx import v1 as tfx  # exposes tfx.proto
from tfx.components import (
    CsvExampleGen, StatisticsGen, SchemaGen, ExampleValidator,
    Transform, Trainer, Evaluator, Pusher
)
from tfx.orchestration.experimental.interactive.interactive_context \
    import InteractiveContext

# Create a TFX pipeline context
context = InteractiveContext()

# 1. Ingest data
example_gen = CsvExampleGen(input_base='data/')
context.run(example_gen)

# 2. Compute statistics
statistics_gen = StatisticsGen(
    examples=example_gen.outputs['examples']
)
context.run(statistics_gen)

# 3. Infer schema
schema_gen = SchemaGen(
    statistics=statistics_gen.outputs['statistics']
)
context.run(schema_gen)

# 4. Validate data against schema
example_validator = ExampleValidator(
    statistics=statistics_gen.outputs['statistics'],
    schema=schema_gen.outputs['schema']
)
context.run(example_validator)

# 5. Feature engineering
transform = Transform(
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    module_file='transform_module.py'  # Contains preprocessing_fn
)
context.run(transform)

# 6. Train the model
trainer = Trainer(
    module_file='trainer_module.py',  # Contains run_fn
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    schema=schema_gen.outputs['schema'],
    train_args=tfx.proto.TrainArgs(num_steps=1000),
    eval_args=tfx.proto.EvalArgs(num_steps=200)
)
context.run(trainer)
```

## CI/CD for ML
CI/CD for ML extends traditional software CI/CD with ML-specific stages:
```
Code Change → Lint/Test → Train → Evaluate → Validate → Deploy
     ↑                                                    ↓
     └──── Monitor ← Serve ← Push ←───────────────────────┘
```
Key differences from software CI/CD:
- Data is tested as well as code: schema and quality checks gate the pipeline, not just unit tests
- A build can fail on model quality: evaluation against a metric threshold acts as a release gate
- Artifacts include datasets and trained models, which must be versioned alongside code
- Training may need special hardware (GPU runners) and can take hours, so runs are often triggered on a schedule or on data change rather than on every commit
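A quality gate for the Evaluate/Validate stages can be a plain function that CI calls and that blocks deployment when any metric regresses. A minimal sketch (function name and threshold values are illustrative):

```python
def passes_quality_gate(metrics: dict, thresholds: dict) -> bool:
    """Return True only if every tracked metric meets its floor.

    A metric missing from `metrics` counts as a failure, so adding a
    new threshold can never silently pass.
    """
    return all(metrics.get(name, float("-inf")) >= floor
               for name, floor in thresholds.items())

# A CI job would call this and exit nonzero on failure to block deployment.
print(passes_quality_gate({"accuracy": 0.94, "auc": 0.91},
                          {"accuracy": 0.90, "auc": 0.85}))  # → True
```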
## GitHub Actions for ML
You can use GitHub Actions to automate ML workflows:
```yaml
# .github/workflows/ml-pipeline.yml
name: ML Training Pipeline

on:
  push:
    branches: [main]
    paths:
      - 'src/model/**'
      - 'data/config.yaml'
  schedule:
    - cron: '0 6 * * 1'  # Weekly on Monday at 6 AM

jobs:
  validate-data:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: '3.10'
      - run: pip install -r requirements.txt
      - run: python scripts/validate_data.py
      - uses: actions/upload-artifact@v4
        with:
          name: validation-report
          path: reports/data_validation.html

  train:
    needs: validate-data
    runs-on: ubuntu-latest  # Or self-hosted GPU runner
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: '3.10'
      - run: pip install -r requirements.txt
      - run: python scripts/train.py --config config/prod.yaml
      - uses: actions/upload-artifact@v4
        with:
          name: trained-model
          path: models/

  evaluate:
    needs: train
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/download-artifact@v4
        with:
          name: trained-model
          path: models/
      - run: pip install -r requirements.txt
      - run: python scripts/evaluate.py --threshold 0.90
      # Fails the job if accuracy < threshold

  deploy:
    needs: evaluate
    if: github.ref == 'refs/heads/main'
    runs-on: ubuntu-latest
    steps:
      - run: echo "Deploying model to production..."
      # Push to model registry / update serving endpoint
```
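The `evaluate` job relies on the script's exit code: a nonzero exit fails the job and stops the `deploy` job from running. Here is a minimal sketch of what `scripts/evaluate.py` could look like; the hard-coded accuracy is a stand-in for loading the real evaluation report:

```python
import argparse

def main(argv=None) -> int:
    parser = argparse.ArgumentParser()
    parser.add_argument("--threshold", type=float, required=True)
    args = parser.parse_args(argv)

    accuracy = 0.94  # stand-in: load this from the evaluation report
    if accuracy < args.threshold:
        print(f"FAIL: accuracy {accuracy} below threshold {args.threshold}")
        return 1  # nonzero exit code fails the CI job
    print(f"PASS: accuracy {accuracy} meets threshold {args.threshold}")
    return 0

# The real script would end with: sys.exit(main())
```

Returning an exit code from `main()` (rather than calling `sys.exit` deep inside) keeps the gate logic unit-testable while still letting the workflow treat any nonzero code as a failed stage.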