
Scaling AI Systems

Distributed training, inference optimization, and cost management



As models and datasets grow, single-machine training becomes impractical. Modern AI systems require distributed training, optimized inference, and careful cost management. In this lesson, you'll learn the patterns for scaling every stage of the ML lifecycle.

Distributed Training

When a model or dataset is too large for a single GPU, you distribute the work across multiple GPUs or machines. There are three main strategies:

Data Parallelism

The same model is replicated on each GPU. Each GPU processes a different slice of the batch and computes gradients, which are then averaged across all GPUs so every copy applies the same update.

GPU 0: Model Copy → Batch 0 → Gradients 0 ─┐
GPU 1: Model Copy → Batch 1 → Gradients 1 ─┤─► Average → Update All Copies
GPU 2: Model Copy → Batch 2 → Gradients 2 ─┤
GPU 3: Model Copy → Batch 3 → Gradients 3 ─┘

Best for: most training scenarios, where the model fits on a single GPU but training is too slow.
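The averaging step can be sketched in plain numpy. This is a toy linear model with a hypothetical four-way split, showing that the average of the per-replica gradients equals the full-batch gradient:

```python
import numpy as np

# Toy sketch of data parallelism: 4 "GPUs" each compute gradients on
# their own shard of the batch; the averaged gradient equals the
# gradient of the full batch, so every replica applies the same update.
rng = np.random.default_rng(0)
w = rng.normal(size=3)                       # shared model weights
X = rng.normal(size=(8, 3))                  # global batch of 8 examples
y = X @ np.array([1.0, -2.0, 0.5])           # linear targets

def gradient(w, X_shard, y_shard):
    # Gradient of mean squared error for a linear model
    err = X_shard @ w - y_shard
    return 2 * X_shard.T @ err / len(y_shard)

# Split the global batch across 4 replicas (data parallelism)
shards = zip(np.split(X, 4), np.split(y, 4))
per_replica_grads = [gradient(w, Xs, ys) for Xs, ys in shards]
avg_grad = np.mean(per_replica_grads, axis=0)    # the "all-reduce" step

# Equivalent single-device gradient over the full batch
full_grad = gradient(w, X, y)
print(np.allclose(avg_grad, full_grad))
```

Note that the equality holds because the shards are equal-sized; unequal shards would need a weighted average.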

Model Parallelism

The model is split across GPUs. Each GPU holds a portion of the model's layers or parameters.

GPU 0: Layers 1-8    → Activations → GPU 1
GPU 1: Layers 9-16   → Activations → GPU 2
GPU 2: Layers 17-24  → Output

Best for: Models too large to fit on a single GPU (e.g., GPT-scale LLMs).
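As a toy sketch of the split above, here is the 24-layer partition in plain Python. The device names and "layers" are stand-ins, not a real framework API; a real forward pass would transfer activation tensors between devices at each stage boundary:

```python
# Hypothetical sketch of model parallelism: a 24-layer model is split
# into three stages, each owned by a different device, and a forward
# pass hands activations from one stage to the next.

def make_layer(i):
    # Stand-in for a real layer: adds its index to the activation
    return lambda x: x + i

layers = [make_layer(i) for i in range(1, 25)]   # layers 1..24

# Partition the layers across three hypothetical devices
stages = {
    "gpu:0": layers[0:8],     # layers 1-8
    "gpu:1": layers[8:16],    # layers 9-16
    "gpu:2": layers[16:24],   # layers 17-24
}

def forward(x):
    # Each device runs only its own slice; activations flow onward
    for device, stage in stages.items():
        for layer in stage:
            x = layer(x)
    return x

out = forward(0)   # sum of 1..24 = 300
```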

Pipeline Parallelism

Combines data and model parallelism. The model is split across GPUs (model parallelism), and micro-batches are pipelined through the stages to keep all GPUs busy.

Time →  t0    t1    t2    t3    t4    t5
GPU 0: [B0]  [B1]  [B2]  [B3]
GPU 1:       [B0]  [B1]  [B2]  [B3]
GPU 2:             [B0]  [B1]  [B2]  [B3]

Best for: Very large models with large datasets (maximum throughput).
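The timeline above follows a simple rule: at time step t, stage s works on micro-batch t − s, if one exists. A minimal simulation of that schedule (assuming equal per-stage compute and ignoring the backward pass):

```python
# Sketch of the forward-pass pipeline schedule from the diagram above.
def pipeline_schedule(num_stages, num_microbatches):
    total_steps = num_stages + num_microbatches - 1
    schedule = []
    for t in range(total_steps):
        row = {}
        for stage in range(num_stages):
            mb = t - stage          # micro-batch this stage works on
            if 0 <= mb < num_microbatches:
                row[f"GPU {stage}"] = f"B{mb}"
        schedule.append(row)
    return schedule

for t, row in enumerate(pipeline_schedule(3, 4)):
    print(f"t{t}: {row}")
```

The early and late steps where some stages sit idle are the pipeline "bubble"; more micro-batches per model batch shrink its relative cost.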

tf.distribute Strategies

TensorFlow provides built-in distribution strategies:

| Strategy | GPUs | Machines | Type | Use Case |
| --- | --- | --- | --- | --- |
| MirroredStrategy | Multiple | 1 | Data parallelism | Most common: multi-GPU on one machine |
| MultiWorkerMirroredStrategy | Multiple | Multiple | Data parallelism | Distributed across machines |
| TPUStrategy | N/A | TPU pod | Data parallelism | Google Cloud TPUs |
| ParameterServerStrategy | Multiple | Multiple | Async data parallelism | Large-scale with parameter servers |

python
import tensorflow as tf

# --- MirroredStrategy: Multi-GPU on a single machine ---
strategy = tf.distribute.MirroredStrategy()
print(f"Number of devices: {strategy.num_replicas_in_sync}")

# Wrap model creation and compilation inside the strategy scope
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(512, activation='relu', input_shape=(784,)),
        tf.keras.layers.Dropout(0.3),
        tf.keras.layers.Dense(256, activation='relu'),
        tf.keras.layers.Dropout(0.3),
        tf.keras.layers.Dense(10, activation='softmax')
    ])

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

# Training happens normally — tf.distribute handles the rest
# The global batch size is split across GPUs
# E.g., batch_size=256 with 4 GPUs → 64 per GPU
GLOBAL_BATCH_SIZE = 256

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32') / 255.0
x_test = x_test.reshape(-1, 784).astype('float32') / 255.0

model.fit(
    x_train, y_train,
    batch_size=GLOBAL_BATCH_SIZE,
    epochs=10,
    validation_data=(x_test, y_test)
)
python
# --- MultiWorkerMirroredStrategy: Multiple machines ---
import json
import os

import tensorflow as tf

# Each worker needs a TF_CONFIG environment variable,
# set before the strategy is constructed.
# Worker 0 (on machine 1):
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ['machine1:12345', 'machine2:12345']
    },
    'task': {'type': 'worker', 'index': 0}
})

strategy = tf.distribute.MultiWorkerMirroredStrategy()

with strategy.scope():
    model = tf.keras.Sequential([...])
    model.compile(...)

# Each worker runs the same script with different TF_CONFIG
model.fit(x_train, y_train, batch_size=GLOBAL_BATCH_SIZE, epochs=10)

Batch vs Real-Time Inference Architecture

Batch Inference

Process large volumes of data at scheduled intervals.

Scheduler (Airflow)
     │
     ▼
Data Warehouse ─► Feature Pipeline ─► Model ─► Prediction Store
(BigQuery)        (Spark/Beam)         (GPU)     (Redis/Postgres)
                                                       │
                                                       ▼
                                                  Downstream Apps

Advantages: Cost-efficient (spot instances, scale to zero), high throughput. Disadvantages: High latency (hours), stale predictions.
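The batch pattern can be sketched in a few lines of plain Python. The names here are illustrative stand-ins (a dict for the prediction store, a lambda for the model), not a real pipeline:

```python
# Sketch of batch inference: a scheduled job scores every row and
# writes to a prediction store; serving is then a cheap key lookup,
# at the price of staleness between runs.
prediction_store = {}   # stands in for Redis/Postgres

def scheduled_batch_job(rows, model):
    # Runs e.g. nightly via a scheduler: score everything, store results
    for user_id, features in rows.items():
        prediction_store[user_id] = model(features)

def serve(user_id):
    # Downstream apps read precomputed (possibly stale) predictions
    return prediction_store.get(user_id)

toy_model = lambda features: sum(features) > 1.0     # stand-in model
rows = {"u1": [0.9, 0.4], "u2": [0.1, 0.2]}
scheduled_batch_job(rows, toy_model)
print(serve("u1"))   # True — no model compute at request time
```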

Real-Time Inference

Serve predictions on-demand with low latency.

Client Request
     │
     ▼
Load Balancer ─► Model Server (TF Serving / TorchServe)
                      │
                      ├─► Feature Store (online — Redis)
                      │
                      ▼
                 Prediction ─► Response (< 100ms)
                      │
                      ▼
                 Log to Monitoring

Advantages: Real-time, fresh predictions. Disadvantages: Always-on infrastructure, higher cost.

Auto-Scaling ML Serving

Auto-scaling adjusts the number of model serving replicas based on demand:

Key Metrics for Scaling Decisions

| Metric | Scale Up When | Scale Down When |
| --- | --- | --- |
| Request latency (p95) | > 200ms | < 50ms |
| Queue depth | > 100 pending | < 10 pending |
| GPU utilization | > 80% | < 20% |
| CPU utilization | > 70% | < 30% |
| Requests per second | > capacity × 0.8 | < capacity × 0.3 |
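These thresholds fold naturally into a decision rule. The function below is a hedged sketch with illustrative names; the numbers come from the table:

```python
# Toy scaling policy built from the thresholds above.
def scaling_decision(latency_p95_ms, queue_depth, gpu_util, cpu_util):
    # Any single breached metric triggers scale-up
    scale_up = (latency_p95_ms > 200 or queue_depth > 100
                or gpu_util > 0.80 or cpu_util > 0.70)
    # Scale down only when every metric is quiet
    scale_down = (latency_p95_ms < 50 and queue_depth < 10
                  and gpu_util < 0.20 and cpu_util < 0.30)
    if scale_up:
        return "scale_up"
    if scale_down:
        return "scale_down"
    return "hold"

print(scaling_decision(250, 40, 0.6, 0.5))   # scale_up (latency breach)
print(scaling_decision(30, 5, 0.1, 0.2))     # scale_down (all quiet)
print(scaling_decision(120, 50, 0.5, 0.5))   # hold
```

Mirroring Kubernetes HPA semantics, one breached metric is enough to scale up, while scale-down requires all metrics to be quiet; this asymmetry keeps the system from flapping.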

python
# Kubernetes HPA (Horizontal Pod Autoscaler) for ML serving
# This would be a YAML file in practice — shown as Python dict for clarity

hpa_config = {
    "apiVersion": "autoscaling/v2",
    "kind": "HorizontalPodAutoscaler",
    "metadata": {"name": "ml-model-server-hpa"},
    "spec": {
        "scaleTargetRef": {
            "apiVersion": "apps/v1",
            "kind": "Deployment",
            "name": "fraud-detector-serving"
        },
        "minReplicas": 2,       # Always at least 2 for availability
        "maxReplicas": 20,      # Cap to control costs
        "metrics": [
            {
                "type": "Resource",
                "resource": {
                    "name": "cpu",
                    "target": {"type": "Utilization", "averageUtilization": 70}
                }
            },
            {
                "type": "Pods",
                "pods": {
                    "metric": {"name": "prediction_latency_p95"},
                    "target": {"type": "AverageValue", "averageValue": "200m"}
                }
            }
        ],
        "behavior": {
            "scaleUp": {
                "stabilizationWindowSeconds": 60,  # Wait 1min before scaling up
                "policies": [{"type": "Percent", "value": 50, "periodSeconds": 60}]
            },
            "scaleDown": {
                "stabilizationWindowSeconds": 300,  # Wait 5min before scaling down
                "policies": [{"type": "Percent", "value": 20, "periodSeconds": 60}]
            }
        }
    }
}

import json
print(json.dumps(hpa_config, indent=2))

Caching Strategies for ML

Caching can dramatically reduce inference costs and latency:

| Strategy | What's Cached | Hit Rate | Best For |
| --- | --- | --- | --- |
| Exact match | Same input → same output | Low-Medium | Repeated queries (search, recommendations) |
| Embedding similarity | Similar inputs → same output | Medium-High | Semantic search, Q&A |
| Feature caching | Pre-computed features | High | Feature store (online serving) |
| Prediction caching | Pre-computed predictions | Very High | Batch + cache pattern |
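The exact-match row is the simplest to implement: hash the input features and reuse the stored prediction on a hit. A minimal sketch (the dict cache and stand-in model are illustrative; a real deployment would use Redis and add a TTL for freshness):

```python
import hashlib
import json

cache = {}
stats = {"hits": 0, "misses": 0}

def cache_key(features):
    # Stable key: serialize features deterministically, then hash
    blob = json.dumps(features, sort_keys=True).encode()
    return hashlib.sha256(blob).hexdigest()

def cached_predict(model, features):
    key = cache_key(features)
    if key in cache:
        stats["hits"] += 1
        return cache[key]
    stats["misses"] += 1
    cache[key] = model(features)   # only compute on a miss
    return cache[key]

model = lambda f: f["a"] + f["b"]            # stand-in model
cached_predict(model, {"a": 1, "b": 2})      # miss — runs the model
cached_predict(model, {"b": 2, "a": 1})      # hit — same features, reordered
print(stats)   # {'hits': 1, 'misses': 1}
```

Sorting the keys before hashing is what makes the reordered second request a hit; exact-match caching assumes the model is deterministic for a given input.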

Load Balancing for ML

ML serving has unique load balancing needs:

  • Model-aware routing: Route to replicas with the correct model version
  • GPU-aware scheduling: Balance across GPUs, not just CPUs
  • Sticky sessions: For stateful models (sequence models, chatbots)
  • Weighted routing: For A/B tests and canary deployments
Cost Optimization

ML infrastructure can be expensive. Key strategies:

1. Spot/preemptible instances for training (60-90% cheaper)
2. Right-size GPU types: don't use A100s when T4s will do
3. Mixed precision training: float16 compute is roughly 2x faster on modern GPUs
4. Model distillation: train a smaller model to mimic a larger one
5. Quantization: reduce serving costs with INT8 inference
6. Batch inference + cache: pre-compute when possible
7. Scale to zero: shut down when there's no traffic
8. Spot instances + checkpointing: save training state to resume after preemption

python
import tensorflow as tf

# --- Mixed Precision Training (float16 compute, float32 storage) ---
# This can nearly double training speed on GPUs with Tensor Cores (V100, A100, T4)
policy = tf.keras.mixed_precision.Policy('mixed_float16')
tf.keras.mixed_precision.set_global_policy(policy)

print(f"Compute dtype: {policy.compute_dtype}")    # float16
print(f"Variable dtype: {policy.variable_dtype}")  # float32

model = tf.keras.Sequential([
    tf.keras.layers.Dense(512, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(256, activation='relu'),
    # Output layer must be float32 for numerical stability
    tf.keras.layers.Dense(10, activation='softmax', dtype='float32')
])

# Under this policy, Keras automatically wraps the optimizer in a
# LossScaleOptimizer, which scales the loss to prevent float16 underflow
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
model.compile(optimizer=optimizer,
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Training proceeds normally — mixed precision is handled transparently
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32') / 255.0
model.fit(x_train, y_train, epochs=10, batch_size=256)

    The Cost Hierarchy

    From cheapest to most expensive inference: 1. **Cached prediction** (no compute) 2. **Quantized batch inference** on CPU 3. **Standard batch inference** on GPU 4. **Quantized real-time** on CPU 5. **Standard real-time** on GPU 6. **Large model real-time** on multiple GPUs Always ask: can we move workload UP this ladder?