Architecture Deep Dive · AlignScore

BERTAlignModel: How It Works

A detailed breakdown of the RoBERTa backbone, pooling strategy, four output heads, and how P(ALIGNED) becomes the AlignScore signal

Architecture Overview

BERTAlignModel is a PyTorch Lightning module built on a pretrained transformer encoder (RoBERTa or BERT) with four task heads on top. Only one head — the 3-way classification head — is used at inference time to produce the AlignScore.

Input:  [CLS] chunk_text [SEP] summary_sentence [SEP]
            ↓
RoBERTa / BERT transformer (12–24 layers, 768-dim hidden states for the base model), two outputs extracted:

last_hidden_state  [N_tokens × 768]  (all token vectors)
    → mlm_head (lm_head / cls.predictions)  →  prediction_logits [N_tokens × vocab_size]   (training only)

pooler_output  [768]  ([CLS] token only)
    → dropout(0.1) → bin_layer  Linear(768→2)  →  seq_relationship_logits   (ALIGNED / NOT-ALIGNED)
    → dropout(0.1) → tri_layer  Linear(768→3)  →  tri_label_logits   (ALIGNED / CONTRADICT / NEUTRAL)   ⭐ used for score
    →               reg_layer   Linear(768→1)  →  reg_label_logits   (score ∈ [0,1])

softmax(tri_label_logits)[:, 0] = P(ALIGNED) → AlignScore
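
Put as a minimal PyTorch sketch, the constructor wires up exactly these pieces. Layer names follow the diagram and the tables below; the real class is a PyTorch Lightning module whose training logic is omitted here:

import torch.nn as nn
from transformers import RobertaModel, RobertaForMaskedLM

class BERTAlignModel(nn.Module):                     # simplified; the real class subclasses pl.LightningModule
    def __init__(self, model='roberta-base'):
        super().__init__()
        self.base_model = RobertaModel.from_pretrained(model)
        self.mlm_head   = RobertaForMaskedLM.from_pretrained(model).lm_head
        hidden = self.base_model.config.hidden_size  # 768 for roberta-base

        self.dropout   = nn.Dropout(p=0.1)
        self.bin_layer = nn.Linear(hidden, 2)        # ALIGNED / NOT-ALIGNED
        self.tri_layer = nn.Linear(hidden, 3)        # ALIGNED / CONTRADICT / NEUTRAL (AlignScore signal)
        self.reg_layer = nn.Linear(hidden, 1)        # continuous score in [0,1]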

The Backbone: RoBERTa

The model supports both RoBERTa and BERT backbones. For AlignScore, RoBERTa-base (125M) or RoBERTa-large (355M) is used. Two things are loaded from the pretrained checkpoint:

model/__init__.py — backbone loading
from transformers import RobertaModel, RobertaForMaskedLM

self.base_model = RobertaModel.from_pretrained(model)                # transformer blocks; outputs last_hidden_state + pooler_output
self.mlm_head   = RobertaForMaskedLM.from_pretrained(model).lm_head  # token-level prediction head, used only for synthetic data generation
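
When a BERT checkpoint is passed instead, the equivalent objects come from the BERT classes. A plausible sketch of that branch (the cls.predictions path matches the diagram above; the exact condition is an assumption):

from transformers import BertModel, BertForPreTraining, RobertaModel, RobertaForMaskedLM

if 'roberta' in model:
    self.base_model = RobertaModel.from_pretrained(model)
    self.mlm_head   = RobertaForMaskedLM.from_pretrained(model).lm_head
else:
    self.base_model = BertModel.from_pretrained(model)
    self.mlm_head   = BertForPreTraining.from_pretrained(model).cls.predictions  # BERT's token-level MLM head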

Two Key Outputs from base_model

| Output | Shape | What it is | Used by |
| --- | --- | --- | --- |
| last_hidden_state | [N_tokens × 768] | Contextual embedding for every token in the input | mlm_head only |
| pooler_output | [768] | Linear + tanh applied to the [CLS] token embedding only; represents the whole pair | bin, tri, reg heads |
Why pooler_output for scoring?
The [CLS] token is trained to aggregate the meaning of the entire input sequence. Using it as a fixed-size 768-dim vector to represent the (chunk, sentence) pair is the standard BERT-family classification strategy — no additional pooling needed.
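
To make the "Linear + tanh on [CLS]" description concrete, pooler_output is computed roughly as follows; pooler_dense below stands in for the backbone's own internal pooler layer, and the inputs are dummies:

import torch
import torch.nn as nn

pooler_dense      = nn.Linear(768, 768)        # stand-in for the backbone's pooler.dense layer
last_hidden_state = torch.randn(1, 20, 768)    # dummy [batch × N_tokens × 768] encoder output

cls_vector    = last_hidden_state[:, 0]               # first position = [CLS] (<s> for RoBERTa), [batch × 768]
pooler_output = torch.tanh(pooler_dense(cls_vector))  # Linear + tanh on the CLS vector, [batch × 768]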

The Four Output Heads

| Head | Layer | Input | Output Shape | Purpose | Used at Inference? |
| --- | --- | --- | --- | --- | --- |
| mlm_head | lm_head (RoBERTa) | last_hidden_state | [N_tokens × vocab] | Predict masked tokens for synthetic data augmentation | No (training only) |
| bin_layer | Linear(768 → 2) | pooler_output | [2] | Binary: ALIGNED / NOT-ALIGNED | Optional |
| tri_layer ⭐ | Linear(768 → 3) | pooler_output | [3] | 3-way: ALIGNED / CONTRADICT / NEUTRAL | Yes (primary signal) |
| reg_layer | Linear(768 → 1) | pooler_output | [1] | Regression: continuous score in [0,1] | Optional |
Why keep all 4 heads loaded?
The checkpoint was saved with all 4 heads. Even though only tri_layer is used at inference, all heads must be present for the checkpoint to load without errors; strict=False is passed to load_from_checkpoint() to tolerate any remaining key mismatches.
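
In code, loading for inference might look like this; the checkpoint path is a placeholder and the import path depends on how the package is laid out:

# BERTAlignModel is the Lightning module described above
model = BERTAlignModel.load_from_checkpoint(
    'AlignScore-base.ckpt',   # placeholder checkpoint path
    strict=False,             # tolerate key mismatches between checkpoint and module
)
model.eval()                  # disable dropout for inference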

The Forward Pass

On each call to model(batch), the following happens in sequence:

forward(self, batch)
# Step 1: Run the RoBERTa transformer on the tokenized (chunk, sentence) pair
base_model_output = self.base_model(
    input_ids=batch['input_ids'],
    attention_mask=batch['attention_mask'],
    token_type_ids=batch.get('token_type_ids', None)
)
# outputs: last_hidden_state [N_tokens × 768] and pooler_output [768]

# Step 2: MLM head over ALL token embeddings (training only)
prediction_scores = self.mlm_head(base_model_output.last_hidden_state)
# shape [N_tokens × vocab_size]

# Step 3: Classification/regression heads over the CLS vector
seq_relationship_score = self.bin_layer(self.dropout(base_model_output.pooler_output))
# shape [2]: ALIGNED / NOT-ALIGNED

tri_label_score = self.tri_layer(self.dropout(base_model_output.pooler_output))
# shape [3]: ALIGNED / CONTRADICT / NEUTRAL   ⭐ this is used for AlignScore

reg_label_score = self.reg_layer(base_model_output.pooler_output)
# shape [1]: continuous score (no dropout here)
Note on dropout: bin_layer and tri_layer apply dropout(0.1) to the pooler_output before the linear layer. The reg_layer does not — it gets the raw pooler_output.

ModelOutput Dataclass

The forward pass returns a ModelOutput dataclass bundling all head outputs:

| Field | Shape | Source | Used for |
| --- | --- | --- | --- |
| prediction_logits | [N_tokens × vocab] | mlm_head(last_hidden_state) | Synthetic data generation |
| seq_relationship_logits | [batch × 2] | bin_layer(pooler_output) | Binary alignment training |
| tri_label_logits | [batch × 3] | tri_layer(pooler_output) | AlignScore inference |
| reg_label_logits | [batch × 1] | reg_layer(pooler_output) | Regression training |
| hidden_states | tuple of tensors | base_model internals | Optional analysis |
| attentions | tuple of tensors | base_model internals | Optional analysis |
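
A sketch of what that dataclass plausibly looks like; field names match the table, but the exact types in the source may differ:

from dataclasses import dataclass
from typing import Optional, Tuple
import torch

@dataclass
class ModelOutput:
    prediction_logits: torch.Tensor                             # [N_tokens × vocab], from mlm_head(last_hidden_state)
    seq_relationship_logits: torch.Tensor                       # [batch × 2], from bin_layer(pooler_output)
    tri_label_logits: torch.Tensor                              # [batch × 3], from tri_layer(pooler_output)
    reg_label_logits: torch.Tensor                              # [batch × 1], from reg_layer(pooler_output)
    hidden_states: Optional[Tuple[torch.Tensor, ...]] = None    # base_model internals
    attentions: Optional[Tuple[torch.Tensor, ...]] = None       # base_model internals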

In inference_core(), only one field is accessed:

model_output = model(mini_batch)
model_output_tri = model_output.tri_label_logits   # shape [batch x 3]

From tri_label_logits to AlignScore

The final step converts raw logits to a single score per (chunk, sentence) pair:

inference_core() — scoring
model_output_tri = softmax(tri_label_logits)   # shape [batch × 3]: probabilities summing to 1.0 per row
output_score = model_output_tri[:, 0]          # column 0 = P(ALIGNED): the AlignScore signal per pair

Softmax Output Example

For the pair: "DeepInfer infers preconditions from DNNs." vs "DeepInfer is a trustworthy AI tool."

raw logits      [ 2.1,  -0.3,   0.5 ]
    → softmax →
probabilities   [0.774, 0.070, 0.156]   =   [P(ALIGNED) ▲, P(CONTRADICT), P(NEUTRAL)]
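
The arithmetic checks out with a couple of lines of PyTorch:

import torch

logits = torch.tensor([[2.1, -0.3, 0.5]])
probs  = torch.softmax(logits, dim=-1)   # ≈ [[0.774, 0.070, 0.156]]
score  = probs[:, 0]                     # P(ALIGNED) ≈ 0.774, the AlignScore signal for this pair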
Full picture in one line:

[CLS] chunk [SEP] sentence [SEP]
  → RoBERTa (12 layers)
  → pooler_output (CLS vector, 768-dim)
  → dropout → tri_layer Linear(768→3)
  → softmax → [P(ALIGNED), P(CONTRADICT), P(NEUTRAL)]
  → [:, 0] = P(ALIGNED) = single score per pair = AlignScore signal
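
Putting it together, a hedged end-to-end sketch of scoring one (chunk, sentence) pair; the tokenizer name and truncation settings are assumptions, and model is a loaded BERTAlignModel:

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('roberta-base')

chunk    = "DeepInfer infers preconditions from DNNs."
sentence = "DeepInfer is a trustworthy AI tool."

# RoBERTa encodes the pair as <s> chunk </s></s> sentence </s>, its version of [CLS] ... [SEP] ... [SEP]
batch = tokenizer(chunk, sentence, return_tensors='pt', truncation=True, max_length=512)

with torch.no_grad():
    output      = model(batch)                                    # ModelOutput with all four heads
    probs       = torch.softmax(output.tri_label_logits, dim=-1)  # [1 × 3]
    align_score = probs[:, 0].item()                              # P(ALIGNED) for this pair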