Scaling AI Systems
As models and datasets grow, single-machine training becomes impractical. Modern AI systems require distributed training, optimized inference, and careful cost management. In this lesson, you'll learn the patterns for scaling every stage of the ML lifecycle.
Distributed Training
When a model or dataset is too large for a single GPU, you distribute the work across multiple GPUs or machines. There are three main strategies:
Data Parallelism
The same model is replicated on each GPU. Each GPU processes a different batch of data and computes gradients; the gradients are then averaged across all GPUs, and every copy applies the same update.

```
GPU 0: Model Copy → Batch 0 → Gradients 0 ─┐
GPU 1: Model Copy → Batch 1 → Gradients 1 ─┤─► Average → Update All Copies
GPU 2: Model Copy → Batch 2 → Gradients 2 ─┤
GPU 3: Model Copy → Batch 3 → Gradients 3 ─┘
```
Best for: most training scenarios, where the model fits on a single GPU but training on one device is too slow.
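The gradient-averaging step can be sketched without any GPUs. The toy NumPy script below (the "gradient" is a made-up stand-in, not a real loss gradient) shows why every replica stays identical after each synchronized update:

```python
import numpy as np

rng = np.random.default_rng(0)

# One weight vector, replicated on every "GPU" (here: plain arrays)
weights = rng.normal(size=4)
replicas = [weights.copy() for _ in range(4)]

# Each replica sees a different shard of the global batch
global_batch = rng.normal(size=(256, 4))
shards = np.split(global_batch, 4)  # 64 examples per replica

# Each replica computes a local gradient (a toy stand-in for illustration)
local_grads = [2 * shard.mean(axis=0) for shard in shards]

# All-reduce: average the gradients, then apply the same update everywhere
avg_grad = np.mean(local_grads, axis=0)
lr = 0.1
replicas = [w - lr * avg_grad for w in replicas]

# All replicas remain bit-identical after the synchronized update
assert all(np.array_equal(replicas[0], r) for r in replicas)
```

Because every replica applies the identical averaged gradient, the model copies never drift apart, which is exactly the invariant MirroredStrategy maintains for you.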
Model Parallelism
The model is split across GPUs, with each GPU holding a portion of the model's layers or parameters.

```
GPU 0: Layers 1-8   → Activations → GPU 1
GPU 1: Layers 9-16  → Activations → GPU 2
GPU 2: Layers 17-24 → Output
```
Best for: Models too large to fit on a single GPU (e.g., GPT-scale LLMs).
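A minimal sketch of the idea, with plain NumPy arrays standing in for two hypothetical devices (all shapes and weights here are illustrative): each "device" holds only its own layers, and only the activations cross the boundary.

```python
import numpy as np

rng = np.random.default_rng(0)

# A 4-layer MLP split across two hypothetical devices:
# "device 0" holds layers 1-2, "device 1" holds layers 3-4.
device0 = [rng.normal(size=(8, 16)), rng.normal(size=(16, 16))]
device1 = [rng.normal(size=(16, 16)), rng.normal(size=(16, 4))]

def forward_on(device_layers, x):
    for w in device_layers:
        x = np.maximum(x @ w, 0.0)  # Dense + ReLU
    return x

x = rng.normal(size=(32, 8))     # batch of 32 examples
acts = forward_on(device0, x)    # runs on device 0
out = forward_on(device1, acts)  # activations shipped to device 1
print(out.shape)  # (32, 4)
```

Note that neither device ever sees the other's weights; the inter-device traffic is just the (batch, hidden) activation tensor, which is why interconnect bandwidth matters so much for model parallelism.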
Pipeline Parallelism
Combines data and model parallelism: the model is split across GPUs (model parallelism), and the batch is divided into micro-batches that are pipelined through the stages so that all GPUs stay busy.

```
Time →   t0    t1    t2    t3    t4    t5
GPU 0:  [B0]  [B1]  [B2]  [B3]
GPU 1:        [B0]  [B1]  [B2]  [B3]
GPU 2:              [B0]  [B1]  [B2]  [B3]
```
Best for: Very large models with large datasets (maximum throughput).
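The staggered schedule above can be computed in a few lines of Python; this sketch assumes the simple rule that stage `s` processes micro-batch `b` at time step `s + b`:

```python
def pipeline_schedule(num_stages, num_microbatches):
    """Return, per time step, which micro-batch each stage processes (None = idle)."""
    total_steps = num_stages + num_microbatches - 1
    schedule = []
    for t in range(total_steps):
        step = []
        for stage in range(num_stages):
            mb = t - stage  # stage s sees micro-batch b at time s + b
            step.append(mb if 0 <= mb < num_microbatches else None)
        schedule.append(step)
    return schedule

for t, step in enumerate(pipeline_schedule(3, 4)):
    row = "  ".join(f"GPU{s}:{'B' + str(m) if m is not None else '--'}"
                    for s, m in enumerate(step))
    print(f"t{t}: {row}")
```

The `None` entries at the start and end are the "pipeline bubble": with 3 stages and 4 micro-batches, 6 time steps are needed instead of 4, and more micro-batches shrink the bubble's relative cost.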
tf.distribute Strategies
TensorFlow provides built-in distribution strategies:
| Strategy | GPUs | Machines | Type | Use Case |
|---|---|---|---|---|
| MirroredStrategy | Multiple | 1 | Data parallelism | Most common — multi-GPU on one machine |
| MultiWorkerMirroredStrategy | Multiple | Multiple | Data parallelism | Distributed across machines |
| TPUStrategy | N/A | TPU pod | Data parallelism | Google Cloud TPUs |
| ParameterServerStrategy | Multiple | Multiple | Async data parallelism | Large-scale with parameter servers |
```python
import tensorflow as tf

# --- MirroredStrategy: Multi-GPU on a single machine ---
strategy = tf.distribute.MirroredStrategy()
print(f"Number of devices: {strategy.num_replicas_in_sync}")

# Wrap model creation and compilation inside the strategy scope
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(512, activation='relu', input_shape=(784,)),
        tf.keras.layers.Dropout(0.3),
        tf.keras.layers.Dense(256, activation='relu'),
        tf.keras.layers.Dropout(0.3),
        tf.keras.layers.Dense(10, activation='softmax')
    ])

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

# Training happens normally — tf.distribute handles the rest
# The global batch size is split across GPUs
# E.g., batch_size=256 with 4 GPUs → 64 per GPU
GLOBAL_BATCH_SIZE = 256

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32') / 255.0
x_test = x_test.reshape(-1, 784).astype('float32') / 255.0

model.fit(
    x_train, y_train,
    batch_size=GLOBAL_BATCH_SIZE,
    epochs=10,
    validation_data=(x_test, y_test)
)
```

```python
# --- MultiWorkerMirroredStrategy: Multiple machines ---
import json
import os

import tensorflow as tf

# Each worker needs a TF_CONFIG environment variable
# Worker 0 (on machine 1):
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ['machine1:12345', 'machine2:12345']
    },
    'task': {'type': 'worker', 'index': 0}
})

strategy = tf.distribute.MultiWorkerMirroredStrategy()

with strategy.scope():
    model = tf.keras.Sequential([...])
    model.compile(...)

# Each worker runs the same script with different TF_CONFIG
model.fit(x_train, y_train, batch_size=GLOBAL_BATCH_SIZE, epochs=10)
```

Batch vs Real-Time Inference Architecture
Batch Inference
Process large volumes of data at scheduled intervals.

```
     Scheduler (Airflow)
             │
             ▼
Data Warehouse ─► Feature Pipeline ─► Model ─► Prediction Store
  (BigQuery)       (Spark/Beam)      (GPU)    (Redis/Postgres)
                                                     │
                                                     ▼
                                              Downstream Apps
```
Advantages: Cost-efficient (spot instances, scale to zero), high throughput. Disadvantages: High latency (hours), stale predictions.
Real-Time Inference
Serve predictions on-demand with low latency.

```
Client Request
      │
      ▼
Load Balancer ─► Model Server (TF Serving / TorchServe)
                      │
                      ├─► Feature Store (online — Redis)
                      │
                      ▼
                Prediction ─► Response (< 100ms)
                      │
                      ▼
               Log to Monitoring
```
Advantages: Real-time, fresh predictions. Disadvantages: Always-on infrastructure, higher cost.
Auto-Scaling ML Serving
Auto-scaling adjusts the number of model serving replicas based on demand:
Key Metrics for Scaling Decisions
| Metric | Scale Up When | Scale Down When |
|---|---|---|
| Request latency (p95) | > 200ms | < 50ms |
| Queue depth | > 100 pending | < 10 pending |
| GPU utilization | > 80% | < 20% |
| CPU utilization | > 70% | < 30% |
| Requests per second | > capacity * 0.8 | < capacity * 0.3 |
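Under the hood, the Kubernetes HPA sizes a deployment with a simple proportional formula, desiredReplicas = ceil(currentReplicas × currentMetric / targetMetric), clamped to the configured min/max bounds. A sketch (the bound values match the example config in this lesson):

```python
import math

def desired_replicas(current_replicas, current_metric, target_metric,
                     min_replicas=2, max_replicas=20):
    """Kubernetes HPA core formula: scale proportionally to metric pressure."""
    desired = math.ceil(current_replicas * current_metric / target_metric)
    return max(min_replicas, min(max_replicas, desired))

# 4 replicas at 90% CPU against a 70% target → scale up
print(desired_replicas(4, 90, 70))  # 6
# 4 replicas at 20% CPU against a 70% target → scale down (clamped at the min)
print(desired_replicas(4, 20, 70))  # 2
```

The clamping is what the `minReplicas`/`maxReplicas` fields control: the floor keeps availability, the ceiling caps cost even under a traffic spike.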
```python
# Kubernetes HPA (Horizontal Pod Autoscaler) for ML serving
# This would be a YAML file in practice — shown as a Python dict for clarity

import json

hpa_config = {
    "apiVersion": "autoscaling/v2",
    "kind": "HorizontalPodAutoscaler",
    "metadata": {"name": "ml-model-server-hpa"},
    "spec": {
        "scaleTargetRef": {
            "apiVersion": "apps/v1",
            "kind": "Deployment",
            "name": "fraud-detector-serving"
        },
        "minReplicas": 2,   # Always at least 2 for availability
        "maxReplicas": 20,  # Cap to control costs
        "metrics": [
            {
                "type": "Resource",
                "resource": {
                    "name": "cpu",
                    "target": {"type": "Utilization", "averageUtilization": 70}
                }
            },
            {
                "type": "Pods",
                "pods": {
                    "metric": {"name": "prediction_latency_p95"},
                    "target": {"type": "AverageValue", "averageValue": "200m"}  # 0.2s
                }
            }
        ],
        "behavior": {
            "scaleUp": {
                "stabilizationWindowSeconds": 60,   # Wait 1 min before scaling up
                "policies": [{"type": "Percent", "value": 50, "periodSeconds": 60}]
            },
            "scaleDown": {
                "stabilizationWindowSeconds": 300,  # Wait 5 min before scaling down
                "policies": [{"type": "Percent", "value": 20, "periodSeconds": 60}]
            }
        }
    }
}

print(json.dumps(hpa_config, indent=2))
```

Caching Strategies for ML
Caching can dramatically reduce inference costs and latency:
| Strategy | What's Cached | Hit Rate | Best For |
|---|---|---|---|
| Exact match | Same input → same output | Low-Medium | Repeated queries (search, recommendations) |
| Embedding similarity | Similar inputs → same output | Medium-High | Semantic search, Q&A |
| Feature caching | Pre-computed features | High | Feature store (online serving) |
| Prediction caching | Pre-computed predictions | Very High | Batch + cache pattern |
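An exact-match cache is a few lines of code. This sketch (the model function and feature names are hypothetical) hashes a canonical JSON encoding of the features, so key order doesn't matter; in production the store would be Redis rather than an in-process dict:

```python
import hashlib
import json

class PredictionCache:
    """Exact-match cache: identical input → cached prediction."""

    def __init__(self, model_fn):
        self.model_fn = model_fn
        self.store = {}  # stand-in for Redis
        self.hits = 0
        self.misses = 0

    def _key(self, features):
        # sort_keys makes the key independent of dict insertion order
        payload = json.dumps(features, sort_keys=True)
        return hashlib.sha256(payload.encode()).hexdigest()

    def predict(self, features):
        key = self._key(features)
        if key in self.store:
            self.hits += 1
            return self.store[key]
        self.misses += 1
        result = self.model_fn(features)  # the expensive call we want to avoid
        self.store[key] = result
        return result

# Hypothetical "model": a cheap stand-in scoring function
cache = PredictionCache(lambda f: {"score": f["amount"] / 1000})
cache.predict({"amount": 120, "country": "DE"})
cache.predict({"country": "DE", "amount": 120})  # same features, new order → hit
print(cache.hits, cache.misses)  # 1 1
```

A real deployment would also add a TTL so cached predictions expire before the underlying features or model go stale.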
Load Balancing for ML
ML serving has unique load-balancing needs: request costs vary widely (input size, sequence length), servers often batch requests to keep accelerators efficient, and a GPU-backed replica saturates at far fewer concurrent requests than a typical stateless web service.
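One common pattern that handles variable request costs better than round-robin is least-outstanding-requests routing. A minimal sketch (replica names are hypothetical):

```python
class LeastOutstandingRequests:
    """Route each request to the replica with the fewest in-flight requests,
    so slow, expensive requests don't pile up on one server."""

    def __init__(self, replicas):
        self.in_flight = {r: 0 for r in replicas}

    def acquire(self):
        # Pick the least-loaded replica (ties broken by insertion order)
        replica = min(self.in_flight, key=self.in_flight.get)
        self.in_flight[replica] += 1
        return replica

    def release(self, replica):
        self.in_flight[replica] -= 1

lb = LeastOutstandingRequests(["gpu-0", "gpu-1", "gpu-2"])
first = lb.acquire()   # gpu-0 (all tied)
second = lb.acquire()  # gpu-1
lb.release(first)      # gpu-0 finishes its request
third = lb.acquire()   # gpu-0 again, since it just freed up
print(first, second, third)  # gpu-0 gpu-1 gpu-0
```

Round-robin would have sent the third request to gpu-2 regardless of load; least-outstanding-requests adapts to how long each prediction actually takes.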
Cost Optimization
ML infrastructure can be expensive. Key strategies:
1. Spot/preemptible instances for training (60-90% cheaper)
2. Right-size GPU types: Don't use A100s when T4s will do
3. Mixed precision training: float16 computes are 2x faster on modern GPUs
4. Model distillation: Train a smaller model to mimic a larger one
5. Quantization: Reduce serving costs with INT8 inference
6. Batch inference + cache: Pre-compute when possible
7. Scale to zero: Shut down when there's no traffic
8. Spot instances + checkpointing: Save training state to resume after preemption
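The quantization strategy can be illustrated with symmetric post-training quantization in NumPy: map the float32 weight range onto int8 [-127, 127] with a single scale factor. The weight matrix here is random and purely illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)
weights = rng.normal(scale=0.1, size=(512, 256)).astype(np.float32)

# Symmetric quantization: one scale maps the float32 range to int8 [-127, 127]
scale = np.abs(weights).max() / 127.0
q = np.round(weights / scale).astype(np.int8)

# Dequantize to measure the rounding error we introduced
restored = q.astype(np.float32) * scale

print(f"float32 size: {weights.nbytes} bytes")
print(f"int8 size:    {q.nbytes} bytes (4x smaller)")
print(f"max abs error: {np.abs(weights - restored).max():.6f}")
```

The 4x memory saving translates directly into smaller serving instances and better cache behavior; real INT8 inference (e.g. via TensorFlow Lite) also uses integer matrix kernels for speed, which this sketch does not show.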
```python
import tensorflow as tf

# --- Mixed Precision Training (float16 compute, float32 storage) ---
# This can nearly double training speed on GPUs with Tensor Cores (V100, A100, T4)
policy = tf.keras.mixed_precision.Policy('mixed_float16')
tf.keras.mixed_precision.set_global_policy(policy)

print(f"Compute dtype: {policy.compute_dtype}")    # float16
print(f"Variable dtype: {policy.variable_dtype}")  # float32

model = tf.keras.Sequential([
    tf.keras.layers.Dense(512, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(256, activation='relu'),
    # Output layer must be float32 for numerical stability
    tf.keras.layers.Dense(10, activation='softmax', dtype='float32')
])

# Under the mixed_float16 policy, model.compile automatically wraps the
# optimizer in a LossScaleOptimizer to prevent gradient underflow
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
model.compile(optimizer=optimizer,
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Training proceeds normally — mixed precision is handled transparently
model.fit(x_train, y_train, epochs=10, batch_size=256)
```
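The distillation strategy from the list above trains a small student to match the teacher's temperature-softened output distribution. A toy NumPy sketch of the distillation loss for one example (the logits and temperature values are made up):

```python
import numpy as np

def softmax(z, T=1.0):
    z = z / T                               # temperature scaling
    z = z - z.max(axis=-1, keepdims=True)   # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

# Hypothetical logits for one example with 4 classes
teacher_logits = np.array([4.0, 1.0, 0.5, 0.2])
student_logits = np.array([2.0, 1.5, 0.3, 0.1])

T = 4.0  # higher temperature softens the teacher's distribution
teacher_soft = softmax(teacher_logits, T)
student_soft = softmax(student_logits, T)

# Distillation loss: cross-entropy against the soft targets, scaled by T^2
# so its gradient magnitude stays comparable to the hard-label loss
distill_loss = -np.sum(teacher_soft * np.log(student_soft)) * T**2
print(f"soft targets: {np.round(teacher_soft, 3)}")
print(f"distillation loss: {distill_loss:.3f}")
```

In practice this soft-target term is combined with the ordinary hard-label cross-entropy, and the student is trained on both; the softened targets carry the teacher's "dark knowledge" about how classes relate to each other.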