Architecture Overview
BERTAlignModel is a PyTorch Lightning module built on a pretrained transformer encoder (RoBERTa or BERT) with four task heads on top. Only one head — the 3-way classification head — is used at inference time to produce the AlignScore.
[Architecture diagram] Backbone encoder (12–24 layers, 768-dim hidden states) feeds two paths:
- all token vectors [N_tokens × 768] → MLM head (lm_head / cls.predictions) → [N_tokens × vocab_size] (training only)
- CLS token only [768-dim] → Linear(768→2) → ALIGNED / NOT-ALIGNED logits; Linear(768→3) → ALIGNED / CONTRADICT / NEUTRAL ⭐ used for score; Linear(768→1) → score ∈ [0,1]
The Backbone: RoBERTa
The model supports both RoBERTa and BERT backbones. For AlignScore, RoBERTa-base (125M) or RoBERTa-large (355M) is used. Two things are loaded from the pretrained checkpoint (a minimal constructor sketch follows this list):
- self.base_model: the transformer encoder itself, which produces the per-token embeddings and the pooled [CLS] representation
- self.mlm_head = RobertaForMaskedLM.from_pretrained(model).lm_head: the token-level prediction head, used only for synthetic data generation
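To make the wiring concrete, here is a minimal sketch of how the constructor might assemble these pieces. It assumes a roberta-base backbone with hidden size 768 and omits the training configuration the real class also sets up; the head names match the tables below.

```python
import torch.nn as nn
import pytorch_lightning as pl
from transformers import RobertaModel, RobertaForMaskedLM

class BERTAlignModel(pl.LightningModule):
    def __init__(self, model: str = 'roberta-base'):
        super().__init__()
        # Pretrained encoder: produces last_hidden_state and pooler_output
        self.base_model = RobertaModel.from_pretrained(model)
        # Pretrained token-level LM head, reused for synthetic data generation
        self.mlm_head = RobertaForMaskedLM.from_pretrained(model).lm_head
        # Dropout applied before the bin/tri heads
        self.dropout = nn.Dropout(p=0.1)
        # Three heads over the 768-dim [CLS] vector
        self.bin_layer = nn.Linear(768, 2)   # ALIGNED / NOT-ALIGNED
        self.tri_layer = nn.Linear(768, 3)   # ALIGNED / CONTRADICT / NEUTRAL
        self.reg_layer = nn.Linear(768, 1)   # continuous score in [0, 1]
```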
Two Key Outputs from base_model
| Output | Shape | What it is | Used by |
|---|---|---|---|
| last_hidden_state | [N_tokens × 768] | Contextual embedding for every token in the input | mlm_head only |
| pooler_output | [768] | Linear + tanh applied to the [CLS] token embedding only; represents the whole pair | bin, tri, reg heads |
The [CLS] token is trained to aggregate the meaning of the entire input sequence. Using it as a fixed-size 768-dim vector to represent the (chunk, sentence) pair is the standard BERT-family classification strategy — no additional pooling needed.
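The snippet below illustrates the two outputs on a toy pair. The model name and example strings are placeholders chosen only to show the shapes.

```python
import torch
from transformers import AutoTokenizer, RobertaModel

# Placeholder model name and strings, used only to illustrate the output shapes.
tokenizer = AutoTokenizer.from_pretrained('roberta-base')
encoder = RobertaModel.from_pretrained('roberta-base').eval()

chunk = "The report was published in 2023."
sentence = "The report came out last year."
inputs = tokenizer(chunk, sentence, return_tensors='pt')  # one joint sequence

with torch.no_grad():
    out = encoder(**inputs)

print(out.last_hidden_state.shape)  # torch.Size([1, N_tokens, 768]) -- one vector per token
print(out.pooler_output.shape)      # torch.Size([1, 768])           -- CLS-derived pair vector
```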
The Four Output Heads
| Head | Layer | Input | Output Shape | Purpose | Used at Inference? |
|---|---|---|---|---|---|
| mlm_head | lm_head (RoBERTa) | last_hidden_state | [N_tokens × vocab] | Predict masked tokens for synthetic data augmentation | No — training only |
| bin_layer | Linear(768 → 2) | pooler_output | [2] | Binary: ALIGNED / NOT-ALIGNED | Optional |
| tri_layer ⭐ | Linear(768 → 3) | pooler_output | [3] | 3-way: ALIGNED / CONTRADICT / NEUTRAL | Yes — primary signal |
| reg_layer | Linear(768 → 1) | pooler_output | [1] | Regression: continuous score in [0,1] | Optional |
The checkpoint was saved with all four heads. Even though only tri_layer is used at inference, all heads must be present in the model for the checkpoint to load without errors, which is why strict=False is passed to load_from_checkpoint().
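A minimal loading sketch, assuming the checkpoint lives at a local path (the filename here is hypothetical):

```python
# Assumption: BERTAlignModel is importable from the AlignScore package
# (e.g. the class sketched in the constructor example above).
ckpt_path = 'AlignScore-base.ckpt'   # hypothetical local checkpoint path

model = BERTAlignModel.load_from_checkpoint(
    ckpt_path,
    strict=False,   # tolerate key mismatches between the saved heads and the code
)
model.eval()        # inference mode: only tri_layer's output will be read
```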
The Forward Pass
On each call to model(batch), the following happens in sequence:
```python
# Step 1: encode the concatenated (chunk, sentence) pair with the backbone
base_model_output = self.base_model(
    input_ids=batch['input_ids'],
    attention_mask=batch['attention_mask'],
    token_type_ids=batch.get('token_type_ids', None)
)
# outputs: last_hidden_state [N_tokens × 768] and pooler_output [768]

# Step 2: MLM head over ALL token embeddings (training only)
prediction_scores = self.mlm_head(base_model_output.last_hidden_state)
# shape [N_tokens × vocab_size]

# Step 3: three classification/regression heads over the [CLS] vector
seq_relationship_score = self.bin_layer(self.dropout(base_model_output.pooler_output))
# shape [2]: ALIGNED / NOT-ALIGNED
tri_label_score = self.tri_layer(self.dropout(base_model_output.pooler_output))
# shape [3]: ALIGNED / CONTRADICT / NEUTRAL  ⭐ this is used for AlignScore
reg_label_score = self.reg_layer(base_model_output.pooler_output)
# shape [1]: continuous score
```
bin_layer and tri_layer apply dropout(0.1) to the pooler_output before the linear layer. The reg_layer does not — it gets the raw pooler_output.
ModelOutput Dataclass
The forward pass returns a ModelOutput dataclass bundling all head outputs:
| Field | Shape | Source | Used for |
|---|---|---|---|
| prediction_logits | [N_tokens × vocab] | mlm_head(last_hidden_state) | Synthetic data generation |
| seq_relationship_logits | [batch × 2] | bin_layer(pooler_output) | Binary alignment training |
| tri_label_logits ⭐ | [batch × 3] | tri_layer(pooler_output) | AlignScore inference |
| reg_label_logits | [batch × 1] | reg_layer(pooler_output) | Regression training |
| hidden_states | tuple of tensors | base_model internals | Optional analysis |
| attentions | tuple of tensors | base_model internals | Optional analysis |
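For reference, a minimal sketch of such a dataclass, assuming plain optional tensor fields; the exact definition in the codebase may differ.

```python
from dataclasses import dataclass
from typing import Optional, Tuple
import torch

@dataclass
class ModelOutput:
    prediction_logits: Optional[torch.FloatTensor] = None        # [N_tokens × vocab]
    seq_relationship_logits: Optional[torch.FloatTensor] = None  # [batch × 2]
    tri_label_logits: Optional[torch.FloatTensor] = None         # [batch × 3]  ⭐ AlignScore signal
    reg_label_logits: Optional[torch.FloatTensor] = None         # [batch × 1]
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
```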
In inference_core(), only one field is accessed:
```python
model_output = model(mini_batch)
model_output_tri = model_output.tri_label_logits  # shape [batch × 3]
```
From tri_label_logits to AlignScore
The final step converts raw logits to a single score per (chunk, sentence) pair:
```python
# Softmax over the 3 classes: shape [batch × 3], probabilities summing to 1.0 per row
model_output_tri = model_output_tri.softmax(dim=-1)
# Column 0 = P(ALIGNED): the AlignScore signal per pair
output_score = model_output_tri[:, 0]
```
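As a standalone illustration with made-up logits for two pairs:

```python
import torch

# Toy logits for two (chunk, sentence) pairs; the numbers are invented for illustration.
tri_logits = torch.tensor([[ 2.1, -0.3, 0.4],
                           [-1.0,  0.2, 1.5]])

probs = torch.softmax(tri_logits, dim=-1)  # each row sums to 1.0
align = probs[:, 0]                        # column 0 = P(ALIGNED)
print(align)                               # ≈ tensor([0.7853, 0.0606]) -- one score per pair
```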
Softmax Output Example
For the pair: "DeepInfer infers preconditions from DNNs." vs "DeepInfer is a trustworthy AI tool."
[CLS] chunk [SEP] sentence [SEP]
→ RoBERTa (12 layers)
→ pooler_output (CLS vector, 768-dim)
→ dropout → tri_layer Linear(768→3)
→ softmax → [P(ALIGNED), P(CONTRADICT), P(NEUTRAL)]
→ [:, 0] = P(ALIGNED) = single score per pair = AlignScore signal
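Putting the whole path together, here is a hedged end-to-end sketch for this pair. The tokenizer name, checkpoint path, and the assumption that the model accepts a tokenizer batch directly are illustrative, not taken from the codebase; the two strings are the example pair above.

```python
import torch
from transformers import AutoTokenizer

# Assumptions: tokenizer name, checkpoint path, and that BERTAlignModel
# (sketched earlier) is importable and accepts a tokenizer batch dict.
tokenizer = AutoTokenizer.from_pretrained('roberta-base')
model = BERTAlignModel.load_from_checkpoint('AlignScore-base.ckpt', strict=False).eval()

chunk = "DeepInfer infers preconditions from DNNs."
sentence = "DeepInfer is a trustworthy AI tool."

# Encode the pair as one joint sequence ([CLS] chunk [SEP] sentence [SEP] in the diagram's notation)
batch = tokenizer(chunk, sentence, return_tensors='pt', truncation=True)

with torch.no_grad():
    output = model(batch)                                    # ModelOutput
    probs = torch.softmax(output.tri_label_logits, dim=-1)   # [1 × 3]
    align_score = probs[:, 0].item()                         # P(ALIGNED) for this pair

print(f"AlignScore signal: {align_score:.3f}")
```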