Image Segmentation
While object detection draws bounding boxes around objects, segmentation provides pixel-level understanding of the image.
Types of Segmentation
Semantic Segmentation
Assigns a class label to every pixel in the image. All pixels belonging to "car" get the same label, regardless of which individual car they belong to.

Instance Segmentation
Detects each object instance and provides a pixel mask for each one. Two different cars get different masks.

Panoptic Segmentation
Combines semantic and instance segmentation: every pixel receives a class label, and pixels of countable objects additionally receive an instance identity.

U-Net Architecture
U-Net is the foundational architecture for segmentation, originally developed for biomedical image segmentation (2015).
Encoder (Contracting Path)
Decoder (Expanding Path)
Skip Connections
The critical innovation: feature maps from the encoder are concatenated with the corresponding decoder feature maps. This provides the decoder with fine-grained spatial detail that would otherwise be lost during downsampling.

U-Net Skip Connections vs ResNet Skip Connections
Mask R-CNN
Mask R-CNN extends Faster R-CNN by adding a segmentation branch: 1. Backbone CNN extracts features (typically ResNet + FPN) 2. Region Proposal Network generates proposals 3. RoI Align (improved RoI Pooling without quantization) extracts features 4. Three parallel heads: - Classification head (what class?) - Box regression head (where exactly?) - Mask head (pixel-level mask for each instance)
Key insight: The mask head is a small FCN (Fully Convolutional Network) applied to each RoI independently, predicting a binary mask per class.
Segment Anything Model (SAM)
Meta's SAM (2023) is a foundation model for segmentation:
SAM 2 (2024)
Extends SAM to video with temporal consistency.

Loss Functions for Segmentation
Pixel-wise Cross-Entropy Loss
Standard classification loss applied to each pixel independently: $$L = -\frac{1}{N}\sum_{i=1}^{N}\sum_{c=1}^{C} y_{i,c} \log(p_{i,c})$$

Dice Loss
Based on the Dice coefficient (F1 score for sets): $$\text{Dice} = \frac{2|A \cap B|}{|A| + |B|} = \frac{2\sum p_i g_i}{\sum p_i + \sum g_i}$$ $$L_{\text{dice}} = 1 - \text{Dice}$$

Focal Loss
Downweights easy examples to focus learning on hard ones: $$L_{\text{focal}} = -\alpha_t (1 - p_t)^\gamma \log(p_t)$$

Evaluation Metrics
| Metric | Description |
|---|---|
| Pixel Accuracy | % of pixels correctly classified (misleading with imbalance) |
| IoU (per class) | Intersection / Union for each class |
| mIoU | Mean IoU across all classes — the standard metric |
| Dice Score | 2 * intersection / (pred + GT) — equivalent to F1 |
| Boundary F1 | F1 score computed only on boundary pixels |
1# ==============================================================
2# Simple U-Net implementation in PyTorch
3# ==============================================================
4import torch
5import torch.nn as nn
6
class DoubleConv(nn.Module):
    """(Conv3x3 -> BatchNorm -> ReLU) applied twice.

    The standard U-Net building block: two padded 3x3 convolutions,
    each followed by batch normalization and an in-place ReLU, so the
    spatial size is preserved while channels go in_ch -> out_ch.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        channels = in_ch
        # Build the two identical conv -> BN -> ReLU stages; only the
        # first stage changes the channel count.
        for _ in range(2):
            layers.extend([
                nn.Conv2d(channels, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
            channels = out_ch
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both conv stages to a [B, C, H, W] tensor."""
        return self.block(x)
22
23
class UNet(nn.Module):
    """Classic 4-level U-Net with concatenating skip connections.

    Input  : [B, in_channels, H, W] (H and W must be divisible by 16)
    Output : [B, num_classes, H, W] per-pixel class logits
    """

    def __init__(self, in_channels=3, num_classes=1):
        super().__init__()
        # Contracting path: channels double at every level.
        self.enc1 = DoubleConv(in_channels, 64)
        self.enc2 = DoubleConv(64, 128)
        self.enc3 = DoubleConv(128, 256)
        self.enc4 = DoubleConv(256, 512)

        # Bottleneck at 1/16 of the input resolution.
        self.bottleneck = DoubleConv(512, 1024)

        # Expanding path: each ConvTranspose2d doubles the spatial size,
        # and each dec* consumes upsampled + skip features concatenated
        # on the channel axis (hence the doubled input channels).
        self.up4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.dec4 = DoubleConv(1024, 512)  # 512 from up + 512 from skip
        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = DoubleConv(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = DoubleConv(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = DoubleConv(128, 64)

        # 1x1 conv maps the final 64 channels to class logits.
        self.out_conv = nn.Conv2d(64, num_classes, 1)
        # Shared 2x2 max-pool used between encoder stages.
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        """Encoder, bottleneck, then skip-connected decoder."""

        def fuse(up, dec, feat, skip):
            # Upsample `feat`, concatenate the matching encoder skip on
            # the channel axis, and refine with a DoubleConv.
            return dec(torch.cat([up(feat), skip], dim=1))

        # Contracting path (resolutions: H, H/2, H/4, H/8).
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.pool(skip1))
        skip3 = self.enc3(self.pool(skip2))
        skip4 = self.enc4(self.pool(skip3))

        # Bottleneck at H/16.
        feat = self.bottleneck(self.pool(skip4))

        # Expanding path, reusing encoder features at every level.
        feat = fuse(self.up4, self.dec4, feat, skip4)
        feat = fuse(self.up3, self.dec3, feat, skip3)
        feat = fuse(self.up2, self.dec2, feat, skip2)
        feat = fuse(self.up1, self.dec1, feat, skip1)

        return self.out_conv(feat)
67
# Test the architecture.
# Fix: the final print was fused onto the next section's banner comment
# by a bad paste; also run the sanity forward pass under no_grad so the
# demo doesn't build an autograd graph it never uses.
model = UNet(in_channels=3, num_classes=21)  # 21 classes for Pascal VOC
x = torch.randn(2, 3, 256, 256)
with torch.no_grad():
    out = model(x)
print(f"Input shape: {x.shape}")    # [2, 3, 256, 256]
print(f"Output shape: {out.shape}") # [2, 21, 256, 256]
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

# ==============================================================
2# Dice Loss and combined loss implementation
3# ==============================================================
4import torch
5import torch.nn as nn
6import torch.nn.functional as F
7
class DiceLoss(nn.Module):
    """Soft Dice loss for multi-class segmentation.

    Expects raw logits `pred` of shape [B, C, H, W] and integer class
    targets of shape [B, H, W]. Returns 1 minus the mean soft Dice
    coefficient across classes; `smooth` guards against division by
    zero when a class is absent from both prediction and target.
    """

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        probs = F.softmax(pred, dim=1)      # [B, C, H, W] probabilities
        n_cls = probs.shape[1]

        # Targets as one-hot, channel-first, to align with `probs`.
        onehot = F.one_hot(target, n_cls).permute(0, 3, 1, 2).float()

        spatial = (2, 3)
        overlap = (probs * onehot).sum(dim=spatial)
        totals = probs.sum(dim=spatial) + onehot.sum(dim=spatial)

        # Per-(batch, class) soft Dice, then average and invert.
        per_class = (2 * overlap + self.smooth) / (totals + self.smooth)
        return 1 - per_class.mean()
30
31
class CombinedLoss(nn.Module):
    """Weighted sum of pixel-wise cross-entropy and Dice loss.

    CE optimizes per-pixel classification while Dice directly targets
    region overlap; blending both is a common segmentation objective.
    """

    def __init__(self, ce_weight=0.5, dice_weight=0.5):
        super().__init__()
        self.ce = nn.CrossEntropyLoss()
        self.dice = DiceLoss()
        self.ce_weight = ce_weight
        self.dice_weight = dice_weight

    def forward(self, pred, target):
        """pred: [B, C, H, W] logits; target: [B, H, W] class indices."""
        ce_term = self.ce(pred, target)
        dice_term = self.dice(pred, target)
        return self.ce_weight * ce_term + self.dice_weight * dice_term
44
# Usage example for the combined loss.
# Fix: the final print was fused onto the next section's banner comment
# by a bad paste; split back onto its own line.
criterion = CombinedLoss(ce_weight=0.5, dice_weight=0.5)
pred = torch.randn(4, 21, 256, 256)           # predictions for 4 images, 21 classes
target = torch.randint(0, 21, (4, 256, 256))  # ground truth labels
loss = criterion(pred, target)
print(f"Combined loss: {loss.item():.4f}")

# ==============================================================
# ==============================================================
# Using Segment Anything Model (SAM)
# pip install segment-anything
# ==============================================================
from segment_anything import sam_model_registry, SamPredictor
import numpy as np
import torch  # fix: torch was used below but never imported -> NameError
import cv2
import matplotlib.pyplot as plt

# Load SAM model (ViT-H checkpoint from Meta's release).
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
sam.to("cuda" if torch.cuda.is_available() else "cpu")
predictor = SamPredictor(sam)

# Load and set image. cv2.imread returns None (no exception) on a
# missing/unreadable file, so fail loudly instead of crashing later.
image = cv2.imread("example.jpg")
if image is None:
    raise FileNotFoundError("example.jpg could not be read")
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV is BGR; SAM expects RGB
predictor.set_image(image_rgb)

# Segment with a point prompt:
# (x, y) point and label (1 = foreground, 0 = background).
input_point = np.array([[500, 375]])  # click location
input_label = np.array([1])           # foreground

masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,  # returns 3 masks at different granularities
)

# Visualize the highest-scoring mask.
best_idx = np.argmax(scores)
best_mask = masks[best_idx]

fig, axes = plt.subplots(1, 3, figsize=(18, 6))
axes[0].imshow(image_rgb)
axes[0].set_title("Original Image")

axes[1].imshow(image_rgb)
axes[1].imshow(best_mask, alpha=0.5, cmap="jet")
axes[1].scatter(*input_point[0], c="red", s=200, marker="*")
axes[1].set_title(f"Best Mask (score: {scores[best_idx]:.3f})")

# Segment with a box prompt.
input_box = np.array([100, 100, 600, 500])  # [x1, y1, x2, y2]
masks_box, scores_box, _ = predictor.predict(box=input_box, multimask_output=False)

axes[2].imshow(image_rgb)
axes[2].imshow(masks_box[0], alpha=0.5, cmap="jet")
axes[2].set_title("Box-Prompted Segmentation")

for ax in axes:
    ax.axis("off")
plt.tight_layout()
plt.show()