BERT — NER en textos científicos y mapas de atención
Reconocimiento de entidades biomédicas con BERT, evaluación con gold set, tests de robustez e inspección visual de mapas de self-attention por capa y cabeza.
BERT para NER en textos científicos: interpretación con mapas de atención
Objetivo del notebook
Este notebook tiene un doble propósito didáctico dentro del submódulo LLMs del módulo de IA Generativa:
- Aplicar un modelo tipo BERT a una tarea real de NLP: Reconocimiento de Entidades Nombradas (NER) en texto científico (en este caso, biomédico).
- Entender cómo “piensa” el modelo analizando sus mapas de atención y conectando los resultados con los fundamentos teóricos vistos en el submódulo (Transformer encoder, bidireccionalidad, preentrenamiento/fine-tuning y atención multi-cabeza).
La idea central es pasar de una visión “caja negra” a una visión más interpretativa: no solo queremos ver qué etiqueta predice el modelo, sino también dónde pone atención cuando toma esa decisión.
Modelos y datasets que vamos a usar
- Modelo de NER:
d4data/biomedical-ner-all- Arquitectura base: BERT-like encoder para token classification.
- Dominio: textos biomédicos/científicos.
- Tokenizer: tokenizer asociado al modelo (subpalabras tipo WordPiece/BPE según checkpoint).
- Textos de ejemplo: varios fragmentos científicos (resúmenes y frases biomédicas).
- Mini conjunto de evaluación manual: pequeño gold set creado en el notebook para medir precisión/recobrado/F1 de forma pedagógica.
Nota: en producción conviene evaluar sobre benchmarks estandarizados (BC5CDR, NCBI Disease, JNLPBA, etc.) y definir claramente esquema de etiquetas y criterio de matching de spans.
Fundamentos matemáticos y computacionales (resumen riguroso y accesible)
1) Tokenización en subpalabras
Dado un texto (x), el tokenizer lo divide en tokens/subtokens: [ x \rightarrow (t_1, t_2, \dots, t_n) ] Esto ayuda a manejar vocabularios abiertos y términos científicos raros (por ejemplo, interleukin-6 o SARS-CoV-2).
2) Embeddings de entrada en BERT
Para cada posición (i), BERT suma: [ \mathbf{e}_i = \mathbf{e}^{token}_i + \mathbf{e}^{position}_i + \mathbf{e}^{segment}_i ] Luego aplica capas Transformer encoder.
3) Auto-atención escalada (self-attention)
En cada cabeza de atención: [ \mathrm{Attention}(Q,K,V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V ]
- (Q): queries (qué busco)
- (K): keys (qué ofrezco)
- (V): values (qué información paso)
Cada token mezcla contexto de todos los demás tokens, por eso BERT es bidireccional (usa izquierda y derecha).
4) Multi-head attention
En lugar de una sola atención, hay (h) cabezas: [ \mathrm{MHA}(X)=\mathrm{Concat}(head_1,\dots,head_h)W^O ] Cada cabeza puede capturar relaciones diferentes (sintácticas, semánticas, dependencias largas, etc.).
5) Capa de salida para NER
En token classification, cada representación final (\mathbf{h}_i) pasa por una capa lineal:
[
\mathbf{z}_i = W\mathbf{h}_i + b, \quad p(y_i|x)=\mathrm{softmax}(\mathbf{z}_i)
]
Se predice una etiqueta BIO por token/subtoken (por ejemplo, B-Disease, I-Chemical, O).
6) Función de pérdida (fine-tuning)
Normalmente se usa entropía cruzada por token (ignorando padding): [ \mathcal{L}=-\sum_i \log p(y_i^{\star}|x) ]
7) Limitación importante de interpretabilidad
La atención es una señal útil de análisis, pero “atención ≠ explicación causal perfecta”. Aun así, es una herramienta pedagógica excelente para estudiar patrones internos del modelo.
1) Instalación de dependencias (si hace falta)
Si trabajas en local y te falta alguna librería, descomenta la siguiente celda.
# !pip install -q transformers datasets seqeval matplotlib seaborn pandas numpy torch
2) Imports y configuración
# Imports principales para el notebook
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
from seqeval.metrics import classification_report, f1_score, precision_score, recall_score
# Configuración visual para gráficas
sns.set_theme(style="whitegrid")
plt.rcParams["figure.figsize"] = (12, 5)
# Semilla para reproducibilidad básica
torch.manual_seed(42)
np.random.seed(42)
# Selección de dispositivo (CPU/GPU)
device = 0 if torch.cuda.is_available() else -1
print("Usando GPU" if device == 0 else "Usando CPU")
/home/nuberu/xuan/naux/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html from .autonotebook import tqdm as notebook_tqdm
Usando GPU
3) Cargar modelo BERT de NER biomédico
# Nombre del checkpoint en Hugging Face
MODEL_NAME = "d4data/biomedical-ner-all"
# Carga de tokenizer y modelo
# output_attentions=True permite extraer mapas de atención para análisis posterior
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForTokenClassification.from_pretrained(MODEL_NAME, output_attentions=True)
# Pipeline de NER con agregación por palabra (une subtokens)
ner_pipe = pipeline(
task="token-classification",
model=model,
tokenizer=tokenizer,
aggregation_strategy="simple",
device=device,
)
print("Modelo cargado correctamente:", MODEL_NAME)
print("Número de etiquetas:", model.config.num_labels)
print("Ejemplo de etiquetas:", list(model.config.id2label.items())[:8])
Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads. Loading weights: 100%|██████████| 102/102 [00:00<00:00, 15543.49it/s]
Modelo cargado correctamente: d4data/biomedical-ner-all Número de etiquetas: 84 Ejemplo de etiquetas: [(0, 'O'), (1, 'B-Activity'), (2, 'B-Administration'), (3, 'B-Age'), (4, 'B-Area'), (5, 'B-Biological_attribute'), (6, 'B-Biological_structure'), (7, 'B-Clinical_event')]
4) Exploración de tokenización en lenguaje científico
# Texto científico de ejemplo para entender cómo se trocea en subtokens
sample_text = "Interleukin-6 levels were significantly elevated in COVID-19 patients treated with dexamethasone and remdesivir."
# Tokenización con offsets para alinear tokens con texto original
enc = tokenizer(sample_text, return_offsets_mapping=True, truncation=True)
tokens = tokenizer.convert_ids_to_tokens(enc["input_ids"])
offsets = enc["offset_mapping"]
# Visualización tabular
tok_df = pd.DataFrame({"token": tokens, "offset": offsets})
print("Tokenización de muestra:")
display(tok_df.head(30))
# Chequeo simple
assert len(tokens) == len(offsets), "Tokens y offsets deben tener misma longitud"
Tokenización de muestra:
| token | offset | |
|---|---|---|
| 0 | [CLS] | (0, 0) |
| 1 | inter | (0, 5) |
| 2 | ##le | (5, 7) |
| 3 | ##uki | (7, 10) |
| 4 | ##n | (10, 11) |
| 5 | - | (11, 12) |
| 6 | 6 | (12, 13) |
| 7 | levels | (14, 20) |
| 8 | were | (21, 25) |
| 9 | significantly | (26, 39) |
| 10 | elevated | (40, 48) |
| 11 | in | (49, 51) |
| 12 | co | (52, 54) |
| 13 | ##vid | (54, 57) |
| 14 | - | (57, 58) |
| 15 | 19 | (58, 60) |
| 16 | patients | (61, 69) |
| 17 | treated | (70, 77) |
| 18 | with | (78, 82) |
| 19 | dex | (83, 86) |
| 20 | ##ame | (86, 89) |
| 21 | ##tha | (89, 92) |
| 22 | ##son | (92, 95) |
| 23 | ##e | (95, 96) |
| 24 | and | (97, 100) |
| 25 | re | (101, 103) |
| 26 | ##md | (103, 105) |
| 27 | ##es | (105, 107) |
| 28 | ##iv | (107, 109) |
| 29 | ##ir | (109, 111) |
Comentario didáctico
Observa que algunos términos técnicos se dividen en subpartes. Esto es normal y permite al modelo manejar vocabulario especializado sin necesitar un token único para cada palabra científica posible.
5) Inferencia NER en varios textos científicos
# Conjunto de textos científicos de ejemplo
scientific_texts = [
"BRCA1 mutations are associated with increased risk of breast cancer and ovarian cancer.",
"Patients with rheumatoid arthritis were treated with methotrexate and infliximab.",
"CRISPR-Cas9 editing reduced expression of TNF-alpha in human macrophages.",
"The trial evaluated efficacy of pembrolizumab in metastatic melanoma.",
]
all_results = []
for i, text in enumerate(scientific_texts, start=1):
preds = ner_pipe(text)
for p in preds:
all_results.append({
"text_id": i,
"text": text,
"entity": p["entity_group"],
"word": p["word"],
"score": round(float(p["score"]), 4),
"start": p["start"],
"end": p["end"],
})
results_df = pd.DataFrame(all_results)
results_df.head(20)
| text_id | text | entity | word | score | start | end | |
|---|---|---|---|---|---|---|---|
| 0 | 1 | BRCA1 mutations are associated with increased ... | Diagnostic_procedure | br | 0.9982 | 0 | 2 |
| 1 | 1 | BRCA1 mutations are associated with increased ... | Coreference | ##ca1 mutations | 0.7625 | 2 | 15 |
| 2 | 2 | Patients with rheumatoid arthritis were treate... | Disease_disorder | rheumatoid arthritis | 0.9106 | 14 | 34 |
| 3 | 2 | Patients with rheumatoid arthritis were treate... | Medication | met | 0.9999 | 53 | 56 |
| 4 | 2 | Patients with rheumatoid arthritis were treate... | Medication | ##hot | 0.9698 | 56 | 59 |
| 5 | 2 | Patients with rheumatoid arthritis were treate... | Medication | ##re | 0.9948 | 59 | 61 |
| 6 | 2 | Patients with rheumatoid arthritis were treate... | Medication | ##xate | 0.7508 | 61 | 65 |
| 7 | 2 | Patients with rheumatoid arthritis were treate... | Medication | in | 0.9998 | 70 | 72 |
| 8 | 2 | Patients with rheumatoid arthritis were treate... | Medication | ##fl | 0.9995 | 72 | 74 |
| 9 | 2 | Patients with rheumatoid arthritis were treate... | Medication | ##ix | 0.9994 | 74 | 76 |
| 10 | 2 | Patients with rheumatoid arthritis were treate... | Medication | ##ima | 0.9995 | 76 | 79 |
| 11 | 3 | CRISPR-Cas9 editing reduced expression of TNF-... | Diagnostic_procedure | crispr - cas9 editing | 0.9997 | 0 | 19 |
| 12 | 3 | CRISPR-Cas9 editing reduced expression of TNF-... | Lab_value | reduced expression | 0.9981 | 20 | 38 |
| 13 | 3 | CRISPR-Cas9 editing reduced expression of TNF-... | Diagnostic_procedure | tn | 0.9780 | 42 | 44 |
| 14 | 3 | CRISPR-Cas9 editing reduced expression of TNF-... | Lab_value | ##f - alpha | 0.7226 | 44 | 51 |
| 15 | 4 | The trial evaluated efficacy of pembrolizumab ... | Medication | pe | 0.9998 | 32 | 34 |
| 16 | 4 | The trial evaluated efficacy of pembrolizumab ... | Medication | ##mb | 0.9980 | 34 | 36 |
| 17 | 4 | The trial evaluated efficacy of pembrolizumab ... | Medication | ##rolizumab | 0.9681 | 36 | 45 |
# Resumen de entidades detectadas por tipo
if not results_df.empty:
summary = (
results_df.groupby("entity")
.agg(total=("entity", "count"), score_medio=("score", "mean"))
.sort_values("total", ascending=False)
)
display(summary)
else:
print("No se detectaron entidades en los textos de ejemplo.")
| total | score_medio | |
|---|---|---|
| entity | ||
| Medication | 11 | 0.970855 |
| Diagnostic_procedure | 3 | 0.991967 |
| Lab_value | 2 | 0.860350 |
| Coreference | 1 | 0.762500 |
| Disease_disorder | 1 | 0.910600 |
6) Visualización rápida de entidades en contexto
# Función para resaltar entidades con etiquetas y score
def render_entities(text, preds):
# Ordenamos por inicio para pintar de izquierda a derecha
preds = sorted(preds, key=lambda x: x["start"])
out = ""
cursor = 0
for p in preds:
start, end = p["start"], p["end"]
out += text[cursor:start]
chunk = text[start:end]
tag = f"[{chunk} | {p['entity_group']} | {p['score']:.2f}]"
out += tag
cursor = end
out += text[cursor:]
return out
for i, text in enumerate(scientific_texts, start=1):
preds = ner_pipe(text)
print(f"\nTexto {i}:")
print(render_entities(text, preds))
Texto 1: [BR | Diagnostic_procedure | 1.00][CA1 mutations | Coreference | 0.76] are associated with increased risk of breast cancer and ovarian cancer. Texto 2: Patients with [rheumatoid arthritis | Disease_disorder | 0.91] were treated with [met | Medication | 1.00][hot | Medication | 0.97][re | Medication | 0.99][xate | Medication | 0.75] and [in | Medication | 1.00][fl | Medication | 1.00][ix | Medication | 1.00][ima | Medication | 1.00]b. Texto 3: [CRISPR-Cas9 editing | Diagnostic_procedure | 1.00] [reduced expression | Lab_value | 1.00] of [TN | Diagnostic_procedure | 0.98][F-alpha | Lab_value | 0.72] in human macrophages. Texto 4: The trial evaluated efficacy of [pe | Medication | 1.00][mb | Medication | 1.00][rolizumab | Medication | 0.97] in metastatic melanoma.
7) Atención: extracción y mapas (layer/head)
Vamos a observar cómo se distribuye la atención de un token objetivo hacia el resto.
# Texto para estudiar atención
attention_text = "Dexamethasone reduces mortality in severe COVID-19 pneumonia patients."
# Preparar tensores de entrada y moverlos al dispositivo del modelo
inputs = tokenizer(attention_text, return_tensors="pt", truncation=True)
model_device = next(model.parameters()).device
inputs = {k: v.to(model_device) for k, v in inputs.items()}
# Inferencia sin gradientes para obtener attentions
with torch.no_grad():
outputs = model(**inputs)
# outputs.attentions: tupla de longitud num_layers
# Cada elemento: [batch_size, num_heads, seq_len, seq_len]
attentions = outputs.attentions
print("Número de capas:", len(attentions))
print("Shape capa 0:", tuple(attentions[0].shape))
# Tokens legibles
toks = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
print("Tokens:", toks)
Número de capas: 6 Shape capa 0: (1, 12, 18, 18) Tokens: ['[CLS]', 'dex', '##ame', '##tha', '##son', '##e', 'reduces', 'mortality', 'in', 'severe', 'co', '##vid', '-', '19', 'pneumonia', 'patients', '.', '[SEP]']
# Función para dibujar mapa de atención de una capa y cabeza concretas
def plot_attention_map(attentions, tokens, layer=0, head=0, max_tokens=20):
# Tomamos batch 0
att = attentions[layer][0, head].detach().cpu().numpy()
# Recorte para que la figura sea legible
n = min(max_tokens, len(tokens))
att = att[:n, :n]
toks_local = tokens[:n]
plt.figure(figsize=(9, 7))
sns.heatmap(att, xticklabels=toks_local, yticklabels=toks_local, cmap="magma")
plt.title(f"Atención - Capa {layer}, Cabeza {head}")
plt.xlabel("Token atendido (key)")
plt.ylabel("Token que atiende (query)")
plt.xticks(rotation=90)
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()
# Ejemplo: capa intermedia
plot_attention_map(attentions, toks, layer=5, head=3, max_tokens=18)
# Atención promedio por capa (promediando cabezas)
def plot_mean_attention_for_query(attentions, tokens, query_token_idx=1):
num_layers = len(attentions)
layer_vectors = []
for l in range(num_layers):
# [heads, seq, seq]
a = attentions[l][0].detach().cpu().numpy()
# Promedio entre cabezas
mean_a = a.mean(axis=0)
layer_vectors.append(mean_a[query_token_idx])
mat = np.vstack(layer_vectors)
plt.figure(figsize=(12, 5))
sns.heatmap(
mat,
cmap="viridis",
xticklabels=tokens,
yticklabels=[f"L{l}" for l in range(num_layers)]
)
plt.title(f"Evolución de atención por capa para query token idx={query_token_idx} ({tokens[query_token_idx]})")
plt.xlabel("Tokens destino")
plt.ylabel("Capas")
plt.xticks(rotation=90)
plt.tight_layout()
plt.show()
# Query token: palabra principal (ajusta índice si quieres)
plot_mean_attention_for_query(attentions, toks, query_token_idx=2)
Interpretación guiada
- En capas bajas suele aparecer más atención local (tokens vecinos, subpalabras relacionadas).
- En capas medias/altas emergen patrones más semánticos (entidades clínicas, relaciones fármaco-enfermedad).
- No todas las cabezas hacen lo mismo: algunas capturan delimitación de entidad, otras dependencias largas.
8) Mini-evaluación cuantitativa sobre un gold set pequeño (didáctico)
# Gold set manual (pequeño) para evaluación pedagógica
# Formato: por texto, lista de (entidad, span_text)
gold_data = [
{
"text": "BRCA1 mutations are associated with increased risk of breast cancer.",
"gold": [("GENE", "BRCA1"), ("DISEASE", "breast cancer")],
},
{
"text": "Metformin improved glycemic control in type 2 diabetes.",
"gold": [("CHEMICAL", "Metformin"), ("DISEASE", "type 2 diabetes")],
},
{
"text": "Inflammation markers such as IL-6 and CRP were elevated.",
"gold": [("GENE", "IL-6")],
},
]
# Convertimos predicciones a conjunto de spans textuales normalizados
# (evaluación simplificada por coincidencia exacta de texto, sin offsets)
def normalize_span(s):
return s.strip().lower()
y_true = []
y_pred = []
for item in gold_data:
text = item["text"]
gold_spans = [normalize_span(s) for _, s in item["gold"]]
pred_spans = [normalize_span(p["word"]) for p in ner_pipe(text)]
universe = sorted(set(gold_spans) | set(pred_spans))
true_seq = ["ENT" if u in gold_spans else "O" for u in universe]
pred_seq = ["ENT" if u in pred_spans else "O" for u in universe]
y_true.append(true_seq)
y_pred.append(pred_seq)
print("Precision:", round(precision_score(y_true, y_pred), 4))
print("Recall:", round(recall_score(y_true, y_pred), 4))
print("F1:", round(f1_score(y_true, y_pred), 4))
print("\nReporte:")
print(classification_report(y_true, y_pred))
You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset
Precision: 0.1111
Recall: 0.2
F1: 0.1429
Reporte:
precision recall f1-score support
NT 0.11 0.20 0.14 5
micro avg 0.11 0.20 0.14 5
macro avg 0.11 0.20 0.14 5
weighted avg 0.11 0.20 0.14 5
/home/nuberu/xuan/naux/.venv/lib/python3.10/site-packages/seqeval/metrics/sequence_labeling.py:171: UserWarning: ENT seems not to be NE tag.
warnings.warn('{} seems not to be NE tag.'.format(chunk))
Nota metodológica sobre la evaluación
Esta evaluación es intencionalmente simple para aprender el flujo completo. En un entorno real conviene:
- usar datasets estándar del dominio,
- evaluar a nivel de token BIO o span con offsets exactos,
- separar validación/test,
- analizar por tipo de entidad y por longitud de span,
- estudiar errores por ambigüedad terminológica.
9) Tests de robustez básicos (perturbaciones controladas)
# Probamos cómo cambia la salida ante pequeñas variaciones del texto
base_text = "Patients were treated with dexamethasone for severe pneumonia."
variants = [
base_text,
base_text.lower(),
base_text.replace("dexamethasone", "DEXAMETHASONE"),
base_text.replace("severe", "critical"),
]
rows = []
for v in variants:
preds = ner_pipe(v)
rows.append({
"texto": v,
"n_entidades": len(preds),
"entidades": [p["word"] for p in preds],
"score_medio": np.mean([p["score"] for p in preds]) if preds else np.nan,
})
robust_df = pd.DataFrame(rows)
display(robust_df)
# Test simple: al menos en el texto base debería detectar 1 entidad
assert robust_df.loc[0, "n_entidades"] >= 1, "El modelo no detectó entidades en el texto base"
| texto | n_entidades | entidades | score_medio | |
|---|---|---|---|---|
| 0 | Patients were treated with dexamethasone for s... | 3 | [dexame, ##thasone, severe] | 0.889896 |
| 1 | patients were treated with dexamethasone for s... | 3 | [dexame, ##thasone, severe] | 0.889896 |
| 2 | Patients were treated with DEXAMETHASONE for s... | 3 | [dexame, ##thasone, severe] | 0.889896 |
| 3 | Patients were treated with dexamethasone for c... | 2 | [dexame, ##thasone] | 0.844717 |
10) Conclusiones y siguientes pasos
Conclusiones
- BERT (encoder bidireccional) es muy eficaz para NER porque contextualiza cada token usando toda la secuencia.
- En textos científicos, la tokenización en subpalabras es clave para manejar terminología especializada.
- El análisis de atención ayuda a entender patrones internos (qué tokens influyen en otros), aunque no es una explicación causal definitiva.
- Una evaluación cuantitativa, aunque sea pequeña, complementa muy bien la inspección cualitativa.
Qué más podrías probar
- Comparar este modelo con SciBERT o BioBERT ajustados a datasets concretos.
- Hacer fine-tuning supervisado en un corpus específico de tu área (biomedicina, química, materiales, etc.).
- Estudiar errores por tipos de entidad y crear una matriz de confusión a nivel etiqueta.
- Analizar diferencias de atención entre capas iniciales y finales para distintas familias de entidades.
- Integrar explicabilidad adicional (por ejemplo, gradientes integrados o perturbación sistemática de tokens).