PyTorch Lightning & Modern Tooling
Raw PyTorch gives you full control, but the training loop boilerplate can become repetitive and error-prone. PyTorch Lightning is a lightweight wrapper that handles the engineering while you focus on the science.
Why Lightning?
Lightning eliminates boilerplate for:
- device placement (.to(device) calls, GPU/MPS/CPU selection)
- the training and validation loop scaffolding
- mixed-precision training
- gradient accumulation and clipping
- checkpointing and early stopping
- logging and experiment tracking
- distributed and multi-GPU training

You still write pure PyTorch — Lightning just organizes it.
LightningModule
Instead of a raw training loop, you define a LightningModule that encapsulates model + training logic:
import torch
import torch.nn as nn
import torch.nn.functional as F
import lightning as L
from torchmetrics import Accuracy

class MNISTClassifier(L.LightningModule):
    def __init__(self, hidden_size=256, lr=1e-3):
        super().__init__()
        self.save_hyperparameters()  # Saves all __init__ args to self.hparams

        # Model architecture (same as before)
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(784, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_size, 10),
        )

        # Metrics
        self.train_acc = Accuracy(task="multiclass", num_classes=10)
        self.val_acc = Accuracy(task="multiclass", num_classes=10)

    def forward(self, x):
        """Used for inference: model(x)"""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Replaces the inner training loop."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)

        # Log metrics (automatically handles accumulation and logging)
        preds = logits.argmax(dim=1)
        self.train_acc(preds, y)
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", self.train_acc, prog_bar=True)

        return loss  # Lightning handles backward() and optimizer.step()

    def validation_step(self, batch, batch_idx):
        """Replaces the validation loop."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)

        preds = logits.argmax(dim=1)
        self.val_acc(preds, y)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.val_acc, prog_bar=True)

    def configure_optimizers(self):
        """Define optimizer and (optional) scheduler."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
        return [optimizer], [scheduler]

What LightningModule Replaces
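For reference, here is an illustrative sketch (not Lightning's actual internals) of the raw-PyTorch scaffolding that training_step, validation_step, and configure_optimizers absorb. The loop constants are placeholders, and train_loader/val_loader are the loaders defined in the Trainer example below:

import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
model = MNISTClassifier().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(20):
    model.train()                            # mode switching: Lightning does this for you
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)    # device placement: automatic in Lightning
        optimizer.zero_grad()                # automatic
        loss = F.cross_entropy(model(x), y)  # this line is your training_step
        loss.backward()                      # automatic
        optimizer.step()                     # automatic

    model.eval()                             # validation_step runs in eval/no-grad mode
    with torch.no_grad():
        for x, y in val_loader:
            val_loss = F.cross_entropy(model(x.to(device)), y.to(device))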
The Trainer
The Trainer handles everything outside the model: hardware, logging, callbacks, and distributed training.
import lightning as L
from lightning.pytorch.callbacks import (
    ModelCheckpoint,
    EarlyStopping,
    LearningRateMonitor,
    RichProgressBar,
)
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger

# --- Callbacks ---
checkpoint_cb = ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    save_top_k=3,  # Keep top 3 models
    filename="{epoch}-{val_loss:.2f}",
)

early_stop_cb = EarlyStopping(
    monitor="val_loss",
    patience=5,  # Stop if val_loss doesn't improve for 5 epochs
    mode="min",
)

lr_monitor = LearningRateMonitor(logging_interval="epoch")

# --- Logger ---
tb_logger = TensorBoardLogger("logs/", name="mnist")
# wandb_logger = WandbLogger(project="mnist", name="run-01")

# --- Create Trainer ---
trainer = L.Trainer(
    max_epochs=20,
    accelerator="auto",         # Automatically use GPU/MPS/CPU
    devices="auto",             # Use all available GPUs
    precision="16-mixed",       # Mixed precision training
    callbacks=[checkpoint_cb, early_stop_cb, lr_monitor],
    logger=tb_logger,
    gradient_clip_val=1.0,      # Gradient clipping
    accumulate_grad_batches=4,  # Simulate 4x larger batch size
    log_every_n_steps=10,
)

# --- Train ---
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_ds = datasets.MNIST("./data", train=True, download=True, transform=transform)
val_ds = datasets.MNIST("./data", train=False, transform=transform)

train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=256, num_workers=4)

model = MNISTClassifier(hidden_size=256, lr=1e-3)
trainer.fit(model, train_loader, val_loader)
# --- Evaluate ---
# (trainer.test() would require a test_step(); this model only defines validation_step)
trainer.validate(model, val_loader)

# --- Load best checkpoint ---
best_model = MNISTClassifier.load_from_checkpoint(checkpoint_cb.best_model_path)

LightningDataModule
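The loaders above were wired up by hand. A LightningDataModule packages all data handling — download, splits, transforms, loaders — into one reusable class. Here is a minimal sketch for the same MNIST setup (the 55,000/5,000 train/val split is an arbitrary choice for illustration):

from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import lightning as L

class MNISTDataModule(L.LightningDataModule):
    def __init__(self, data_dir="./data", batch_size=64):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])

    def prepare_data(self):
        # Called once per node: download only, don't assign state here
        datasets.MNIST(self.data_dir, train=True, download=True)
        datasets.MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Called on every process: build datasets and splits
        full = datasets.MNIST(self.data_dir, train=True, transform=self.transform)
        self.train_ds, self.val_ds = random_split(full, [55000, 5000])

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.batch_size,
                          shuffle=True, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=256, num_workers=4)

The Trainer then consumes it directly: trainer.fit(model, datamodule=MNISTDataModule()).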
Lightning Fabric
If you want some Lightning benefits without the full LightningModule structure, Fabric is a lighter-weight alternative:
import lightning as L
import torch
import torch.nn as nn

# Fabric: keeps your raw training loop but handles hardware
fabric = L.Fabric(accelerator="auto", devices="auto", precision="16-mixed")
fabric.launch()

model = nn.Sequential(nn.Flatten(), nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Fabric wraps model, optimizer, and dataloaders
model, optimizer = fabric.setup(model, optimizer)
train_loader = fabric.setup_dataloaders(train_loader)  # train_loader from the Trainer example

# Your normal training loop — but it works on any hardware!
for epoch in range(10):
    model.train()
    for batch in train_loader:
        images, labels = batch
        optimizer.zero_grad()
        output = model(images)
        loss = nn.functional.cross_entropy(output, labels)
        fabric.backward(loss)  # Use fabric.backward instead of loss.backward
        optimizer.step()

TorchMetrics
Lightning integrates with TorchMetrics for metric computation that works correctly across distributed training:
import torch
import torch.nn.functional as F
import lightning as L
from torchmetrics import Accuracy, Precision, Recall, F1Score, AUROC
from torchmetrics import MetricCollection

# Individual metrics
accuracy = Accuracy(task="multiclass", num_classes=10)
precision = Precision(task="multiclass", num_classes=10, average="macro")
recall = Recall(task="multiclass", num_classes=10, average="macro")
f1 = F1Score(task="multiclass", num_classes=10, average="macro")

# MetricCollection: compute all at once
metrics = MetricCollection({
    "accuracy": Accuracy(task="multiclass", num_classes=10),
    "precision": Precision(task="multiclass", num_classes=10, average="macro"),
    "recall": Recall(task="multiclass", num_classes=10, average="macro"),
    "f1": F1Score(task="multiclass", num_classes=10, average="macro"),
})
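
# Metrics are stateful: update() accumulates statistics across batches,
# compute() aggregates everything seen so far, and reset() clears the state.
# (Illustrative call on random tensors, just to show the cycle.)
preds = torch.randint(0, 10, (32,))   # hypothetical predictions
target = torch.randint(0, 10, (32,))  # hypothetical labels
accuracy.update(preds, target)
print(accuracy.compute())  # aggregated accuracy over all update() calls
accuracy.reset()           # start fresh for the next epoch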

# Usage in a LightningModule
class MyModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        # (layers and forward() omitted for brevity)
        self.train_metrics = metrics.clone(prefix="train_")
        self.val_metrics = metrics.clone(prefix="val_")

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.train_metrics(logits.argmax(1), y)
        self.log_dict(self.train_metrics, prog_bar=True)
        return loss

ONNX Export
Export your PyTorch model to ONNX (Open Neural Network Exchange) for deployment in other runtimes:
import torch
import torch.onnx

model = MNISTClassifier(hidden_size=256)  # in practice, load trained weights first
model.eval()

# Create dummy input with the correct shape
dummy_input = torch.randn(1, 1, 28, 28)

# Export to ONNX
torch.onnx.export(
    model,                  # Model to export
    dummy_input,            # Example input
    "mnist_model.onnx",     # Output file
    export_params=True,     # Store trained weights
    opset_version=17,       # ONNX opset version
    input_names=["image"],  # Input tensor name
    output_names=["logits"],  # Output tensor name
    dynamic_axes={          # Variable-length axes (for batching)
        "image": {0: "batch_size"},
        "logits": {0: "batch_size"},
    },
)

# Verify the exported model
import onnx
onnx_model = onnx.load("mnist_model.onnx")
onnx.checker.check_model(onnx_model)
print("ONNX model is valid!")

# Run inference with ONNX Runtime
import onnxruntime as ort
session = ort.InferenceSession("mnist_model.onnx")
input_data = dummy_input.numpy()
result = session.run(None, {"image": input_data})
36print(f"ONNX output shape: {result[0].shape}")Why Export to ONNX?
Integration with W&B (Weights & Biases)
Weights & Biases provides experiment tracking, visualization, and collaboration:
# pip install wandb
import wandb
import lightning as L
from lightning.pytorch.loggers import WandbLogger

# Option 1: Lightning integration
wandb_logger = WandbLogger(
    project="my-project",
    name="experiment-01",
    log_model=True,  # Log model checkpoints to W&B
)

trainer = L.Trainer(
    logger=wandb_logger,
    max_epochs=20,
)

# Option 2: Raw PyTorch integration
wandb.init(project="my-project", config={"lr": 1e-3, "epochs": 20})
for epoch in range(20):
    train_loss = train_one_epoch(...)
    val_loss = evaluate(...)
    wandb.log({
        "train_loss": train_loss,
        "val_loss": val_loss,
        "epoch": epoch,
    })
wandb.finish()