🏭 Caso de Uso

Segmentación de Imagen Médica con U-Net

Implementación desde cero de U-Net en PyTorch y entrenamiento sobre el dataset Kvasir-SEG para segmentación de pólipos.

🐍 Python 📓 Jupyter Notebook

Segmentación de Imagen Médica con U-Net

Implementación desde cero en PyTorch y entrenamiento sobre el dataset Kvasir-SEG

Presentación

La segmentación semántica es la tarea de visión por computador que asigna una etiqueta de clase a cada píxel de una imagen. En el ámbito de la imagen médica, esta capacidad es absolutamente crítica: permite delimitar con precisión tumores, órganos, vasos sanguíneos, lesiones o estructuras celulares, proporcionando a los clínicos herramientas cuantitativas de diagnóstico y planificación quirúrgica.

Objetivos de este notebook

  1. Comprender la arquitectura U-Net en profundidad: su diseño encoder-decoder, la importancia de las skip connections, y por qué es especialmente adecuada para imagen médica.
  2. Implementar U-Net desde cero en PyTorch, bloque a bloque (DoubleConv, Encoder, Bottleneck, Decoder, capa final).
  3. Entrenar el modelo sobre el dataset Kvasir-SEG de pólipos gastrointestinales con una combinación de pérdida BCE + Dice Loss.
  4. Evaluar cuantitativamente con métricas estándar de segmentación: Dice Score, IoU, Precision y Recall.
  5. Visualizar y analizar las predicciones, los feature maps del encoder y la distribución de rendimiento por imagen.

Bases teóricas

¿Por qué U-Net?

U-Net (Ronneberger, Fischer & Brox, 2015) fue diseñada específicamente para segmentación de imagen biomédica y se ha convertido en la arquitectura de referencia en este campo. Sus ventajas clave:

  • Funciona bien con pocos datos de entrenamiento (decenas a cientos de imágenes), algo habitual en medicina donde la anotación es costosa y requiere expertos.
  • La estructura de skip connections permite combinar información de bajo nivel (bordes, texturas) con contexto semántico de alto nivel, produciendo segmentaciones con bordes nítidos.
  • Su diseño simétrico en forma de U es elegante e intuitivo.

Arquitectura en detalle

La U-Net tiene dos caminos simétricos:

  1. Encoder (camino de contracción): secuencia de bloques convolucionales + max pooling que reduce progresivamente la resolución espacial ($H \times W \to \frac{H}{2} \times \frac{W}{2} \to \ldots$) mientras aumenta el número de canales. Extrae qué hay en la imagen.

  2. Bottleneck: bloque convolucional en la resolución más baja que captura el contexto semántico más abstracto.

  3. Decoder (camino de expansión): convoluciones transpuestas (upsampling) que recuperan la resolución original. En cada nivel, se concatenan los feature maps del encoder (skip connections), aportando los detalles espaciales perdidos. Recupera dónde está cada estructura.

Formalmente, si la entrada es $X \in \mathbb{R}^{H \times W \times C}$, el encoder genera feature maps $E_l$ a resolución $\frac{H}{2^l} \times \frac{W}{2^l}$ en el nivel $l$. El decoder genera $D_l$ combinando:

$$ D_l = \text{Conv}\left(\text{Concat}\left(\text{Up}(D_{l+1}),; E_l\right)\right) $$

Las skip connections $E_l$ son el ingrediente clave: sin ellas, el decoder tendría que reconstruir los detalles espaciales solo a partir de la representación abstracta del bottleneck, produciendo segmentaciones borrosas.

Funciones de pérdida para segmentación

Para segmentación binaria se usan habitualmente:

Binary Cross-Entropy (BCE):

$$ \mathcal{L}{BCE} = -\frac{1}{N} \sum{i=1}^{N} \left[ y_i \log(\hat{y}_i) + (1 - y_i) \log(1 - \hat{y}_i) \right] $$

Dice Loss (especialmente útil con clases desbalanceadas):

$$ \mathcal{L}_{Dice} = 1 - \frac{2 \sum_i y_i \hat{y}_i + \epsilon}{\sum_i y_i + \sum_i \hat{y}_i + \epsilon} $$

En imagen médica, el desbalanceo de clases es la norma: el fondo domina sobre la estructura de interés (un pólipo puede ocupar el 5% de la imagen). La Dice Loss aborda esto directamente al medir el solapamiento relativo, independientemente del tamaño absoluto.

Dataset: Kvasir-SEG

Kvasir-SEG (Jha et al., 2020) es uno de los datasets de referencia en segmentación de imagen médica gastrointestinal:

Característica Valor
Imágenes 1000 imágenes de colonoscopia
Anotaciones Máscaras binarias pixel-level de pólipos
Anotadores Gastroenterólogos expertos del Hospital Vestre Viken (Noruega)
Resolución Variable (desde 332×487 hasta 1920×1072)
Desafío clínico Los pólipos son precursores del cáncer colorrectal; su detección precoz reduce la mortalidad en ~50%

Herramientas utilizadas

  • PyTorch para la implementación del modelo y el entrenamiento.
  • torchvision.transforms.functional para preprocesamiento y data augmentation.
  • scikit-learn para la división train/val/test.
  • matplotlib y OpenCV para visualización.
[1]
# Librerías y configuración

import os
import glob
import zipfile
import urllib.request
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as TF

# Selección de dispositivo
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Dispositivo: {device}')
print(f'PyTorch version: {torch.__version__}')

plt.rcParams['figure.figsize'] = (12, 7)
Dispositivo: cuda
PyTorch version: 2.10.0+cu128

1) Descarga y preparación del dataset

El dataset Kvasir-SEG contiene 1000 imágenes de pólipos gastrointestinales con sus máscaras de segmentación ground-truth generadas por gastroenterólogos expertos. Los pólipos son precursores del cáncer colorrectal, por lo que su detección precoz es fundamental.

Descargaremos el dataset y lo organizaremos en conjuntos de entrenamiento y test.

[2]
# Descarga del dataset Kvasir-SEG
import ssl

DATA_DIR = '/tmp/kvasir_seg'
os.makedirs(DATA_DIR, exist_ok=True)

ZIP_URL = 'https://datasets.simula.no/downloads/kvasir-seg.zip'
ZIP_PATH = os.path.join(DATA_DIR, 'kvasir-seg.zip')

# Crear contexto SSL sin verificación (necesario en algunos entornos)
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE

if not os.path.exists(os.path.join(DATA_DIR, 'Kvasir-SEG')):
    if not os.path.exists(ZIP_PATH):
        print('Descargando Kvasir-SEG dataset (~46 MB)...')
        with urllib.request.urlopen(ZIP_URL, context=ssl_context) as response, \
             open(ZIP_PATH, 'wb') as out_file:
            out_file.write(response.read())
        print('Descarga completada.')
    
    print('Extrayendo archivos...')
    with zipfile.ZipFile(ZIP_PATH, 'r') as zip_ref:
        zip_ref.extractall(DATA_DIR)
    print('Extracción completada.')
else:
    print('Dataset ya descargado.')

# Rutas de imágenes y máscaras
IMAGES_DIR = os.path.join(DATA_DIR, 'Kvasir-SEG', 'images')
MASKS_DIR = os.path.join(DATA_DIR, 'Kvasir-SEG', 'masks')

image_files = sorted(glob.glob(os.path.join(IMAGES_DIR, '*.jpg')))
mask_files = sorted(glob.glob(os.path.join(MASKS_DIR, '*.jpg')))

print(f'\nImágenes encontradas: {len(image_files)}')
print(f'Máscaras encontradas: {len(mask_files)}')
Dataset ya descargado.

Imágenes encontradas: 1000
Máscaras encontradas: 1000
[3]
# Visualizamos algunas muestras del dataset
fig, axes = plt.subplots(3, 4, figsize=(16, 12))

for i in range(3):
    idx = i * 100  # Muestreamos cada 100 imágenes
    img = np.array(Image.open(image_files[idx]))
    mask = np.array(Image.open(mask_files[idx]).convert('L'))
    
    axes[i, 0].imshow(img)
    axes[i, 0].set_title('Imagen original' if i == 0 else '')
    axes[i, 0].axis('off')
    
    axes[i, 1].imshow(mask, cmap='gray')
    axes[i, 1].set_title('Máscara GT' if i == 0 else '')
    axes[i, 1].axis('off')
    
    # Superposición
    axes[i, 2].imshow(img)
    axes[i, 2].imshow(mask, cmap='Reds', alpha=0.4)
    axes[i, 2].set_title('Superposición' if i == 0 else '')
    axes[i, 2].axis('off')
    
    # Contorno
    contour = np.zeros_like(mask)
    import cv2
    contours, _ = cv2.findContours((mask > 127).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(contour, contours, -1, 255, 2)
    axes[i, 3].imshow(img)
    axes[i, 3].imshow(contour, cmap='Greens', alpha=0.7)
    axes[i, 3].set_title('Contorno del pólipo' if i == 0 else '')
    axes[i, 3].axis('off')

plt.suptitle('Dataset Kvasir-SEG: segmentación de pólipos gastrointestinales', fontsize=14)
plt.tight_layout()
plt.show()
Output

2) Dataset y DataLoader en PyTorch

Creamos un Dataset personalizado que:

  • Redimensiona las imágenes a $256 \times 256$ píxeles (eficiencia computacional).
  • Normaliza la imagen al rango $[0, 1]$.
  • Convierte la máscara a binaria (pólipo vs. fondo).
  • Aplica data augmentation (flips y rotaciones) durante el entrenamiento.
[4]
IMG_SIZE = 256


class KvasirSegDataset(Dataset):
    """Dataset para segmentación de pólipos Kvasir-SEG."""
    
    def __init__(self, image_paths, mask_paths, augment=False):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.augment = augment
    
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        # Cargamos imagen y máscara
        image = Image.open(self.image_paths[idx]).convert('RGB')
        mask = Image.open(self.mask_paths[idx]).convert('L')
        
        # Redimensionamos
        image = TF.resize(image, [IMG_SIZE, IMG_SIZE])
        mask = TF.resize(mask, [IMG_SIZE, IMG_SIZE], interpolation=Image.NEAREST)
        
        # Data augmentation
        if self.augment:
            if torch.rand(1) > 0.5:
                image = TF.hflip(image)
                mask = TF.hflip(mask)
            if torch.rand(1) > 0.5:
                image = TF.vflip(image)
                mask = TF.vflip(mask)
            if torch.rand(1) > 0.5:
                angle = float(torch.randint(-30, 30, (1,)))
                image = TF.rotate(image, angle)
                mask = TF.rotate(mask, angle)
        
        # Convertimos a tensores
        image = TF.to_tensor(image)  # [3, H, W], rango [0, 1]
        mask = TF.to_tensor(mask)    # [1, H, W], rango [0, 1]
        mask = (mask > 0.5).float()  # Binarizamos
        
        return image, mask


# División train/val/test (70/15/15)
train_imgs, test_imgs, train_masks, test_masks = train_test_split(
    image_files, mask_files, test_size=0.15, random_state=42
)
train_imgs, val_imgs, train_masks, val_masks = train_test_split(
    train_imgs, train_masks, test_size=0.176, random_state=42  # 0.176 de 0.85 ≈ 0.15 total
)

train_dataset = KvasirSegDataset(train_imgs, train_masks, augment=True)
val_dataset = KvasirSegDataset(val_imgs, val_masks, augment=False)
test_dataset = KvasirSegDataset(test_imgs, test_masks, augment=False)

BATCH_SIZE = 8
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

print(f'Conjunto de entrenamiento: {len(train_dataset)} imágenes')
print(f'Conjunto de validación: {len(val_dataset)} imágenes')
print(f'Conjunto de test: {len(test_dataset)} imágenes')

# Verificamos un batch
imgs, masks = next(iter(train_loader))
print(f'\nBatch shape - imagen: {imgs.shape}, máscara: {masks.shape}')
print(f'Rango imagen: [{imgs.min():.2f}, {imgs.max():.2f}]')
print(f'Valores únicos máscara: {masks.unique().tolist()}')
Conjunto de entrenamiento: 700 imágenes
Conjunto de validación: 150 imágenes
Conjunto de test: 150 imágenes

Batch shape - imagen: torch.Size([8, 3, 256, 256]), máscara: torch.Size([8, 1, 256, 256])
Rango imagen: [0.00, 1.00]
Valores únicos máscara: [0.0, 1.0]

3) Implementación de U-Net en PyTorch

Implementamos la arquitectura U-Net completa. La estructura clave es:

  • DoubleConv: bloque básico = Conv3×3 → BatchNorm → ReLU → Conv3×3 → BatchNorm → ReLU
  • Down: DoubleConv + MaxPool2d (reduce resolución a la mitad)
  • Up: ConvTranspose2d (duplica resolución) + concatenación con skip + DoubleConv
  • Capa final: Conv1×1 que reduce los canales al número de clases (1 para segmentación binaria)

La implementación sigue fielmente la arquitectura original del paper.

[5]
class DoubleConv(nn.Module):
    """Bloque básico: (Conv3x3 -> BN -> ReLU) x 2"""
    
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        return self.double_conv(x)


class UNet(nn.Module):
    """
    U-Net para segmentación binaria.
    
    Arquitectura:
        Encoder: 4 niveles de downsampling
        Bottleneck: bloque convolucional central
        Decoder: 4 niveles de upsampling con skip connections
        Output: Conv 1x1 → Sigmoid
    """
    
    def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]):
        super().__init__()
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        
        # Encoder (camino de contracción)
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature
        
        # Bottleneck
        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
        
        # Decoder (camino de expansión)
        for feature in reversed(features):
            # ConvTranspose para upsampling
            self.ups.append(
                nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2)
            )
            # DoubleConv después de concatenar
            self.ups.append(DoubleConv(feature * 2, feature))
        
        # Capa final: 1x1 conv
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)
    
    def forward(self, x):
        skip_connections = []
        
        # Encoder
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)
        
        # Bottleneck
        x = self.bottleneck(x)
        
        # Decoder (skip connections en orden inverso)
        skip_connections = skip_connections[::-1]
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)  # Upsampling
            skip = skip_connections[idx // 2]
            
            # Ajuste de tamaño si es necesario (por resoluciones impares)
            if x.shape != skip.shape:
                x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=True)
            
            x = torch.cat([skip, x], dim=1)  # Concatenación (skip connection)
            x = self.ups[idx + 1](x)         # DoubleConv
        
        return torch.sigmoid(self.final_conv(x))


# Instanciamos el modelo
model = UNet(in_channels=3, out_channels=1).to(device)

# Resumen del modelo
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Parámetros totales: {total_params:,}')
print(f'Parámetros entrenables: {trainable_params:,}')

# Verificamos con una entrada dummy
dummy = torch.randn(1, 3, IMG_SIZE, IMG_SIZE).to(device)
output = model(dummy)
print(f'\nEntrada: {dummy.shape}')
print(f'Salida: {output.shape}')
print(f'Rango salida: [{output.min():.4f}, {output.max():.4f}]')
Parámetros totales: 31,037,633
Parámetros entrenables: 31,037,633

Entrada: torch.Size([1, 3, 256, 256])
Salida: torch.Size([1, 1, 256, 256])
Rango salida: [0.2053, 0.8298]

4) Funciones de pérdida y métricas

Definimos:

  • Dice Loss: penaliza directamente la discrepancia entre predicción y ground truth, ignorando el desbalanceo de clases.
  • BCE + Dice Loss combinada: aprovecha las ventajas de ambas.
  • Métricas de evaluación: Dice Score (equivalente al F1), IoU (Intersection over Union), Precision, Recall.
[6]
class DiceLoss(nn.Module):
    """Dice Loss para segmentación binaria."""
    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth
    
    def forward(self, pred, target):
        pred_flat = pred.view(-1)
        target_flat = target.view(-1)
        intersection = (pred_flat * target_flat).sum()
        return 1 - (2. * intersection + self.smooth) / (
            pred_flat.sum() + target_flat.sum() + self.smooth
        )


class BCEDiceLoss(nn.Module):
    """Combinación de BCE y Dice Loss."""
    def __init__(self, bce_weight=0.5):
        super().__init__()
        self.bce = nn.BCELoss()
        self.dice = DiceLoss()
        self.bce_weight = bce_weight
    
    def forward(self, pred, target):
        return self.bce_weight * self.bce(pred, target) + (1 - self.bce_weight) * self.dice(pred, target)


def compute_metrics(pred, target, threshold=0.5):
    """Calcula métricas de segmentación binaria."""
    pred_bin = (pred > threshold).float()
    
    # Flatten
    pred_flat = pred_bin.view(-1)
    target_flat = target.view(-1)
    
    tp = (pred_flat * target_flat).sum()
    fp = (pred_flat * (1 - target_flat)).sum()
    fn = ((1 - pred_flat) * target_flat).sum()
    
    precision = tp / (tp + fp + 1e-8)
    recall = tp / (tp + fn + 1e-8)
    dice = 2 * tp / (2 * tp + fp + fn + 1e-8)
    iou = tp / (tp + fp + fn + 1e-8)
    
    return {
        'dice': dice.item(),
        'iou': iou.item(),
        'precision': precision.item(),
        'recall': recall.item()
    }

print('Funciones de pérdida y métricas definidas.')
Funciones de pérdida y métricas definidas.

5) Entrenamiento del modelo

Entrenamos la U-Net con los siguientes hiperparámetros:

  • Optimizador: Adam con learning rate $10^{-4}$
  • Scheduler: ReduceLROnPlateau (reduce el LR si la pérdida de validación no mejora)
  • Épocas: 25
  • Loss: BCE + Dice Loss combinada

Monitorizamos el Dice Score en validación para early stopping.

[7]
# Configuración del entrenamiento
NUM_EPOCHS = 10
LEARNING_RATE = 1e-4

criterion = BCEDiceLoss(bce_weight=0.5)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.5)

# Almacenamos métricas
history = {
    'train_loss': [], 'val_loss': [],
    'train_dice': [], 'val_dice': [],
    'train_iou': [], 'val_iou': []
}

best_val_dice = 0.0
best_model_state = None

print(f'Entrenando U-Net durante {NUM_EPOCHS} épocas...')
print(f'Dispositivo: {device}')
print(f'Tamaño del batch: {BATCH_SIZE}')
print(f'Learning rate inicial: {LEARNING_RATE}')
print('-' * 70)

for epoch in range(NUM_EPOCHS):
    # ---------- ENTRENAMIENTO ----------
    model.train()
    train_loss = 0.0
    train_metrics = {'dice': 0, 'iou': 0}
    
    for images, masks in train_loader:
        images, masks = images.to(device), masks.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        metrics = compute_metrics(outputs, masks)
        train_metrics['dice'] += metrics['dice']
        train_metrics['iou'] += metrics['iou']
    
    n_train = len(train_loader)
    train_loss /= n_train
    train_metrics = {k: v / n_train for k, v in train_metrics.items()}
    
    # ---------- VALIDACIÓN ----------
    model.eval()
    val_loss = 0.0
    val_metrics = {'dice': 0, 'iou': 0}
    
    with torch.no_grad():
        for images, masks in val_loader:
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            loss = criterion(outputs, masks)
            
            val_loss += loss.item()
            metrics = compute_metrics(outputs, masks)
            val_metrics['dice'] += metrics['dice']
            val_metrics['iou'] += metrics['iou']
    
    n_val = len(val_loader)
    val_loss /= n_val
    val_metrics = {k: v / n_val for k, v in val_metrics.items()}
    
    # Scheduler
    scheduler.step(val_loss)
    
    # Guardamos métricas
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_dice'].append(train_metrics['dice'])
    history['val_dice'].append(val_metrics['dice'])
    history['train_iou'].append(train_metrics['iou'])
    history['val_iou'].append(val_metrics['iou'])
    
    # Guardamos el mejor modelo
    if val_metrics['dice'] > best_val_dice:
        best_val_dice = val_metrics['dice']
        best_model_state = model.state_dict().copy()
    
    # Log cada 5 épocas
    if (epoch + 1) % 5 == 0 or epoch == 0:
        current_lr = optimizer.param_groups[0]['lr']
        print(f'Época {epoch+1:3d}/{NUM_EPOCHS} | '
              f'Loss: train={train_loss:.4f} val={val_loss:.4f} | '
              f'Dice: train={train_metrics["dice"]:.4f} val={val_metrics["dice"]:.4f} | '
              f'IoU: train={train_metrics["iou"]:.4f} val={val_metrics["iou"]:.4f} | '
              f'LR={current_lr:.2e}')

print('-' * 70)
print(f'Mejor Dice en validación: {best_val_dice:.4f}')

# Cargamos el mejor modelo
if best_model_state is not None:
    model.load_state_dict(best_model_state)
Entrenando U-Net durante 10 épocas...
Dispositivo: cuda
Tamaño del batch: 8
Learning rate inicial: 0.0001
----------------------------------------------------------------------
Época   1/10 | Loss: train=0.5944 val=0.5610 | Dice: train=0.4645 val=0.4908 | IoU: train=0.3086 val=0.3287 | LR=1.00e-04
Época   5/10 | Loss: train=0.4155 val=0.4361 | Dice: train=0.6305 val=0.5586 | IoU: train=0.4670 val=0.3991 | LR=1.00e-04
Época  10/10 | Loss: train=0.3057 val=0.3456 | Dice: train=0.7292 val=0.6500 | IoU: train=0.5787 val=0.4963 | LR=1.00e-04
----------------------------------------------------------------------
Mejor Dice en validación: 0.6500

6) Curvas de entrenamiento

Visualizamos la evolución de la pérdida, Dice Score e IoU a lo largo del entrenamiento. Un buen entrenamiento muestra:

  • Pérdida decreciente en ambos conjuntos (sin divergencia = sin overfitting severo).
  • Dice Score e IoU crecientes y estabilizándose.
[8]
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
epochs_range = range(1, NUM_EPOCHS + 1)

# Loss
axes[0].plot(epochs_range, history['train_loss'], 'b-', label='Train', linewidth=2)
axes[0].plot(epochs_range, history['val_loss'], 'r-', label='Validación', linewidth=2)
axes[0].set_title('Pérdida (BCE + Dice)', fontsize=13)
axes[0].set_xlabel('Época')
axes[0].set_ylabel('Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Dice Score
axes[1].plot(epochs_range, history['train_dice'], 'b-', label='Train', linewidth=2)
axes[1].plot(epochs_range, history['val_dice'], 'r-', label='Validación', linewidth=2)
axes[1].set_title('Dice Score', fontsize=13)
axes[1].set_xlabel('Época')
axes[1].set_ylabel('Dice')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# IoU
axes[2].plot(epochs_range, history['train_iou'], 'b-', label='Train', linewidth=2)
axes[2].plot(epochs_range, history['val_iou'], 'r-', label='Validación', linewidth=2)
axes[2].set_title('IoU (Intersection over Union)', fontsize=13)
axes[2].set_xlabel('Época')
axes[2].set_ylabel('IoU')
axes[2].legend()
axes[2].grid(True, alpha=0.3)

plt.suptitle('Curvas de entrenamiento U-Net', fontsize=15)
plt.tight_layout()
plt.show()
Output

7) Evaluación en el conjunto de test

Evaluamos el modelo con los datos que no ha visto durante el entrenamiento. Calculamos todas las métricas de segmentación relevantes.

[9]
# Evaluación final en test
model.eval()
test_metrics_accum = {'dice': 0, 'iou': 0, 'precision': 0, 'recall': 0}
n_batches = 0

all_preds = []
all_targets = []
all_images = []

with torch.no_grad():
    for images, masks in test_loader:
        images, masks = images.to(device), masks.to(device)
        outputs = model(images)
        
        metrics = compute_metrics(outputs, masks)
        for k in test_metrics_accum:
            test_metrics_accum[k] += metrics[k]
        n_batches += 1
        
        all_images.append(images.cpu())
        all_preds.append(outputs.cpu())
        all_targets.append(masks.cpu())

test_metrics = {k: v / n_batches for k, v in test_metrics_accum.items()}

print('Métricas en el conjunto de TEST:')
print(f'  Dice Score:  {test_metrics["dice"]:.4f}')
print(f'  IoU:         {test_metrics["iou"]:.4f}')
print(f'  Precision:   {test_metrics["precision"]:.4f}')
print(f'  Recall:      {test_metrics["recall"]:.4f}')

# Concatenamos todo
all_images = torch.cat(all_images, dim=0)
all_preds = torch.cat(all_preds, dim=0)
all_targets = torch.cat(all_targets, dim=0)
Métricas en el conjunto de TEST:
  Dice Score:  0.6858
  IoU:         0.5281
  Precision:   0.7964
  Recall:      0.6252

8) Visualización de predicciones

Comparamos las predicciones del modelo con el ground truth para varias imágenes del test set. Cada fila muestra: imagen original → máscara ground truth → predicción del modelo → superposición de la predicción sobre la imagen.

[10]
# Visualizamos predicciones en test
n_samples = 6
indices = np.linspace(0, len(all_images) - 1, n_samples, dtype=int)

fig, axes = plt.subplots(n_samples, 4, figsize=(16, n_samples * 3.5))
col_titles = ['Imagen original', 'Ground Truth', 'Predicción U-Net', 'Superposición']

for row, idx in enumerate(indices):
    img = all_images[idx].permute(1, 2, 0).numpy()
    gt = all_targets[idx, 0].numpy()
    pred = all_preds[idx, 0].numpy()
    pred_bin = (pred > 0.5).astype(float)
    
    # Dice individual
    dice_i = 2 * (pred_bin * gt).sum() / (pred_bin.sum() + gt.sum() + 1e-8)
    
    axes[row, 0].imshow(img)
    axes[row, 0].axis('off')
    
    axes[row, 1].imshow(gt, cmap='gray')
    axes[row, 1].axis('off')
    
    axes[row, 2].imshow(pred_bin, cmap='gray')
    axes[row, 2].set_title(f'Dice={dice_i:.3f}', fontsize=10)
    axes[row, 2].axis('off')
    
    # Superposición: verde=TP, rojo=FP, azul=FN
    overlay = img.copy()
    tp_mask = (pred_bin == 1) & (gt == 1)
    fp_mask = (pred_bin == 1) & (gt == 0)
    fn_mask = (pred_bin == 0) & (gt == 1)
    overlay[tp_mask] = overlay[tp_mask] * 0.5 + np.array([0, 1, 0]) * 0.5  # Verde: acierto
    overlay[fp_mask] = overlay[fp_mask] * 0.5 + np.array([1, 0, 0]) * 0.5  # Rojo: falso positivo
    overlay[fn_mask] = overlay[fn_mask] * 0.5 + np.array([0, 0, 1]) * 0.5  # Azul: falso negativo
    axes[row, 3].imshow(overlay)
    axes[row, 3].axis('off')
    
    if row == 0:
        for col, title in enumerate(col_titles):
            axes[row, col].set_title(title, fontsize=12)

plt.suptitle('Predicciones U-Net en el conjunto de test\nVerde=TP, Rojo=FP, Azul=FN', fontsize=14)
plt.tight_layout()
plt.show()
Output

9) Análisis del efecto de las skip connections

Las skip connections son el ingrediente fundamental de U-Net. Veamos qué ocurre cuando visualizamos los feature maps del encoder y cómo la información se transfiere al decoder.

Registramos los feature maps intermedios en una pasada forward.

[11]
# Visualizamos feature maps del encoder
model.eval()

# Tomamos una imagen de ejemplo
sample_img = all_images[0:1].to(device)

# Extraemos feature maps del encoder
encoder_features = []
x = sample_img
with torch.no_grad():
    for down in model.downs:
        x = down(x)
        encoder_features.append(x.cpu())
        x = model.pool(x)

# Visualizamos el primer canal de cada nivel del encoder
fig, axes = plt.subplots(1, 5, figsize=(20, 4))

axes[0].imshow(sample_img[0].cpu().permute(1, 2, 0))
axes[0].set_title(f'Imagen original\n{sample_img.shape[2]}×{sample_img.shape[3]}', fontsize=11)
axes[0].axis('off')

for i, feat in enumerate(encoder_features):
    # Promedio de los canales para visualizar
    feat_map = feat[0].mean(dim=0).numpy()
    axes[i + 1].imshow(feat_map, cmap='viridis')
    axes[i + 1].set_title(f'Encoder nivel {i+1}\n{feat.shape[2]}×{feat.shape[3]}, {feat.shape[1]} canales', fontsize=11)
    axes[i + 1].axis('off')

plt.suptitle('Feature maps del encoder: de resolución alta a baja', fontsize=14)
plt.tight_layout()
plt.show()

print('Los niveles más profundos capturan semántica (qué es un pólipo).')
print('Los niveles superficiales capturan detalles (bordes, texturas).')
print('Las skip connections combinan ambos → segmentación precisa.')
Output
Los niveles más profundos capturan semántica (qué es un pólipo).
Los niveles superficiales capturan detalles (bordes, texturas).
Las skip connections combinan ambos → segmentación precisa.

10) Distribución de Dice Scores por imagen

Analizamos cómo varía la calidad de segmentación entre imágenes individuales del test set.

[12]
# Calculamos Dice score por imagen individual
individual_dices = []

for i in range(len(all_preds)):
    pred_i = (all_preds[i, 0] > 0.5).float().numpy()
    gt_i = all_targets[i, 0].numpy()
    dice_i = 2 * (pred_i * gt_i).sum() / (pred_i.sum() + gt_i.sum() + 1e-8)
    individual_dices.append(dice_i)

individual_dices = np.array(individual_dices)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Histograma
axes[0].hist(individual_dices, bins=25, color='steelblue', edgecolor='white', alpha=0.8)
axes[0].axvline(x=individual_dices.mean(), color='red', linestyle='--', linewidth=2,
                label=f'Media = {individual_dices.mean():.3f}')
axes[0].axvline(x=np.median(individual_dices), color='orange', linestyle='--', linewidth=2,
                label=f'Mediana = {np.median(individual_dices):.3f}')
axes[0].set_title('Distribución de Dice Score por imagen', fontsize=12)
axes[0].set_xlabel('Dice Score')
axes[0].set_ylabel('Número de imágenes')
axes[0].legend()

# Sorted plot
sorted_dices = np.sort(individual_dices)
colors = ['crimson' if d < 0.5 else 'orange' if d < 0.7 else 'forestgreen' for d in sorted_dices]
axes[1].bar(range(len(sorted_dices)), sorted_dices, color=colors, width=1)
axes[1].axhline(y=0.5, color='gray', linestyle='--', alpha=0.5)
axes[1].axhline(y=0.7, color='gray', linestyle=':', alpha=0.5)
axes[1].set_title('Dice Score ordenado por imagen', fontsize=12)
axes[1].set_xlabel('Imagen (ordenada)')
axes[1].set_ylabel('Dice Score')

plt.tight_layout()
plt.show()

print(f'\nEstadísticas del Dice Score por imagen:')
print(f'  Media:    {individual_dices.mean():.4f}')
print(f'  Mediana:  {np.median(individual_dices):.4f}')
print(f'  Std:      {individual_dices.std():.4f}')
print(f'  Min:      {individual_dices.min():.4f}')
print(f'  Max:      {individual_dices.max():.4f}')
print(f'  % > 0.7:  {(individual_dices > 0.7).mean():.1%}')
print(f'  % > 0.5:  {(individual_dices > 0.5).mean():.1%}')
Output
Estadísticas del Dice Score por imagen:
  Media:    0.6633
  Mediana:  0.7497
  Std:      0.2809
  Min:      0.0000
  Max:      0.9638
  % > 0.7:  56.0%
  % > 0.5:  78.0%

Conclusiones

Resumen del experimento

  1. U-Net es altamente efectiva para segmentación de imagen médica, incluso con una implementación estándar sin técnicas avanzadas. Su diseño con skip connections permite segmentaciones con bordes nítidos.

  2. El dataset Kvasir-SEG presenta desafíos reales de segmentación médica: pólipos de diferentes tamaños, formas irregulares, variabilidad de iluminación y contraste con el tejido circundante.

  3. La pérdida Dice + BCE es más estable que usar solo BCE, ya que Dice Loss aborda directamente el desbalanceo entre fondo y primer plano (los pólipos suelen ser pequeños respecto a la imagen total).

  4. Las skip connections permiten que el decoder combine información semántica de alto nivel ("esto es un pólipo") con detalles espaciales de bajo nivel ("el borde exacto está aquí"). Sin ellas, la segmentación sería mucho más borrosa.

  5. La variabilidad entre imágenes es notable: algunos pólipos son fáciles de segmentar (bien definidos, gran contraste) y otros son difíciles (planos, poco contraste, parcialmente ocultos).

Posibles mejoras

  • Aumentar la resolución de entrada (512×512) para capturar más detalles finos.
  • Usar un encoder preentrenado (ResNet, EfficientNet) en lugar del encoder desde cero.
  • Aplicar más data augmentation: elastic deformations, color jitter, cutout.
  • Probar arquitecturas más avanzadas: Attention U-Net, U-Net++, TransUNet.
  • Usar test-time augmentation (TTA) para predicciones más robustas.

La segmentación de imagen médica sigue siendo una de las aplicaciones más impactantes del deep learning. U-Net demostró que con la arquitectura correcta, se puede lograr segmentación precisa incluso con conjuntos de datos relativamente pequeños.