
I want to train an MTL model with NER and a seq2seq architecture to identify entity names and correct typos within them. The NER model is word-based, while the seq2seq model operates at the character level. The NER model performs well, but the seq2seq typo correction model performs poorly. Can anyone help me understand the issue with the model or suggest ways to improve its performance?

toy_data = [ { "sentence": "The superstore OBAD sold 1000 new deserts last month", "correction": "oba",
"entity_type": "store" }, { "sentence": "The green organization of OBAD hired new volunteers to help them for clean forests", "correction": "obada", "entity_type": "organization" }, { "sentence": "The tech giant applee inc unveiled its latest gadget", "correction": "apple", "entity_type": "company" }]

import math
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import Dataset, DataLoader

Building vocabularies for words and characters

def build_word_vocab(sentences):
    vocab = set()
    for s in sentences:
        for w in s.lower().split():
            vocab.add(w)
    vocab.add("<PAD>")
    vocab.add("<SOS>")
    vocab.add("<EOS>")
    vocab.add("<UNK>")
    return sorted(list(vocab))

def build_char_vocab(texts):
    vocab = set()
    for text in texts:
        vocab.update(list(text))
    vocab.add("<PAD>")
    vocab.add("<SOS>")
    vocab.add("<EOS>")
    vocab.add("<UNK>")  # for unknown characters
    return sorted(list(vocab))

Collect all sentences and correction targets.

all_sentences = [item["sentence"].lower() for item in toy_data]
all_corrections = [item["correction"] for item in toy_data]

word_vocab = build_word_vocab(all_sentences)
char_vocab = build_char_vocab(all_corrections)

word2idx = {w: idx for idx, w in enumerate(word_vocab)}
idx2word = {idx: w for w, idx in word2idx.items()}
char2idx = {ch: idx for idx, ch in enumerate(char_vocab)}
idx2char = {idx: ch for ch, idx in char2idx.items()}

word_vocab_size = len(word2idx)
char_vocab_size = len(char2idx)

Define special token indices.

WORD_PAD_IDX = word2idx["<PAD>"]
CHAR_PAD_IDX = char2idx["<PAD>"]
SOS_token = char2idx["<SOS>"]
EOS_token = char2idx["<EOS>"]

def generate_ner_labels(sentence):
    tokens = sentence.split()
    labels = []
    for token in tokens:
        if token == "obad":
            labels.append(1)  # B-COMPANY (we assume a one-token entity)
        else:
            labels.append(0)
    return tokens, labels

Find maximum sentence lengths (word count) and maximum correction length (char count).

max_word_len = max(len(s.split()) for s in all_sentences)
max_char_len = max(len(txt) for txt in all_corrections)

Functions to encode and pad word sequences and label sequences.

def encode_words(sentence, word2idx, max_len):
    tokens = sentence.split()
    ids = [word2idx.get(t, word2idx["<UNK>"]) for t in tokens]
    if len(ids) < max_len:
        ids = ids + [WORD_PAD_IDX] * (max_len - len(ids))
    else:
        ids = ids[:max_len]
    return ids

def encode_labels(labels, max_len, pad_idx=WORD_PAD_IDX):
    if len(labels) < max_len:
        labels = labels + [pad_idx] * (max_len - len(labels))  # pad with pad_idx instead of -100
    else:
        labels = labels[:max_len]
    return labels

def encode_chars(text, char2idx, max_len):
    ids = [char2idx[ch] for ch in text if ch in char2idx]
    if len(ids) < max_len:
        ids = ids + [CHAR_PAD_IDX] * (max_len - len(ids))
    else:
        ids = ids[:max_len]
    return ids

Create a dataset class that returns:

(1) a word-level tensor for NER,
(2) the word-level NER labels,
(3) a char-level tensor for the correction target.

(A quick shape check follows the DataLoader below.)

class MultiTaskDataset(Dataset):
    def __init__(self, data, word2idx, char2idx, max_word_len, max_char_len):
        self.data = data
        self.word2idx = word2idx
        self.char2idx = char2idx
        self.max_word_len = max_word_len
        self.max_char_len = max_char_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        sentence = item["sentence"].lower()
        correction = item["correction"]
        word_tokens, ner_labels = generate_ner_labels(sentence)
        word_ids = encode_words(sentence, self.word2idx, self.max_word_len)
        ner_ids = encode_labels(ner_labels, self.max_word_len)
        char_ids = encode_chars(correction, self.char2idx, self.max_char_len)
        return (
            torch.tensor(word_ids, dtype=torch.long),
            torch.tensor(ner_ids, dtype=torch.long),
            torch.tensor(char_ids, dtype=torch.long),
        )

dataset = MultiTaskDataset(toy_data, word2idx, char2idx, max_word_len, max_char_len)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
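As a quick sanity check that the three tensors described above line up, one batch can be inspected like this (purely for debugging; it only assumes the vocabularies and dataloader built above):

# Inspect one batch: word ids, NER labels, and the character-level target.
word_ids, ner_ids, char_ids = next(iter(dataloader))
print(word_ids.shape)  # (batch, max_word_len)
print(ner_ids.shape)   # (batch, max_word_len)
print(char_ids.shape)  # (batch, max_char_len)
# Decode the first correction target back to characters (skipping padding).
print("".join(idx2char[i.item()] for i in char_ids[0] if i.item() != CHAR_PAD_IDX))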

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        attn_probs = torch.softmax(attn_scores, dim=-1)
        return torch.matmul(attn_probs, V)

    def split_heads(self, x):
        batch_size, seq_length, _ = x.size()
        x = x.view(batch_size, seq_length, self.num_heads, self.d_k)
        return x.transpose(1, 2)

    def combine_heads(self, x):
        batch_size, num_heads, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)

    def forward(self, Q, K, V, mask=None):
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        return self.W_o(self.combine_heads(attn_output))

class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PositionWiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer("pe", pe)

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        attn = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn))
        ffn_out = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_out))
        return x

class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask, tgt_mask):
        self_attn = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(self_attn))
        cross_attn = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(cross_attn))
        ffn_out = self.ffn(x)
        x = self.norm3(x + self.dropout(ffn_out))
        return x

class MultiTaskTransformer(nn.Module):
    def __init__(self, word_vocab_size, char_vocab_size, d_model, num_heads, num_layers,
                 d_ff, max_word_len, max_char_len, dropout, word_pad_idx, char_pad_idx):
        super(MultiTaskTransformer, self).__init__()
        self.word_pad_idx = word_pad_idx
        self.char_pad_idx = char_pad_idx
        self.max_word_len = max_word_len
        self.max_char_len = max_char_len

        # Encoder: word embeddings + positional encoding.
        self.word_embedding = nn.Embedding(word_vocab_size, d_model, padding_idx=word_pad_idx)
        self.pos_enc = PositionalEncoding(d_model, max_word_len)
        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )

        # NER head: per-word classification into 3 classes (O, B-COMPANY, I-COMPANY).
        self.ner_classifier = nn.Linear(d_model, 3)

        # Decoder: character embeddings + positional encoding.
        self.char_embedding = nn.Embedding(char_vocab_size, d_model, padding_idx=char_pad_idx)
        self.pos_enc_dec = PositionalEncoding(d_model, max_char_len)
        self.decoder_layers = nn.ModuleList(
            [DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )
        self.fc_out = nn.Linear(d_model, char_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def generate_masks(self, word_ids, tgt_ids):
        # Encoder mask: (batch, 1, 1, word_seq_len)
        enc_mask = (word_ids != self.word_pad_idx).unsqueeze(1).unsqueeze(2)
        # Decoder mask: combine padding mask and subsequent (no-peek) mask.
        tgt_mask = (tgt_ids != self.char_pad_idx).unsqueeze(1).unsqueeze(3)
        seq_len = tgt_ids.size(1)
        nopeak_mask = torch.triu(
            torch.ones((1, seq_len, seq_len), device=tgt_ids.device), diagonal=1
        ).bool()
        tgt_mask = tgt_mask & ~nopeak_mask
        return enc_mask, tgt_mask

    def forward(self, word_ids, tgt_ids=None):
        # --- Encoder: word-level ---
        enc_emb = self.dropout(self.pos_enc(self.word_embedding(word_ids)))
        enc_mask = (word_ids != self.word_pad_idx).unsqueeze(1).unsqueeze(2)
        enc_output = enc_emb
        for layer in self.encoder_layers:
            enc_output = layer(enc_output, enc_mask)

        # --- NER head ---
        ner_logits = self.ner_classifier(enc_output)  # (batch, word_seq_len, 3)
        ner_probs = torch.softmax(ner_logits, dim=-1)
        company_prob = ner_probs[:, :, 1:3].sum(dim=-1, keepdim=True)  # (batch, word_seq_len, 1)
        company_sum = (enc_output * company_prob).sum(dim=1, keepdim=True)
        denom = company_prob.sum(dim=1, keepdim=True) + 1e-8
        company_summary = company_sum / denom
        enc_extended = torch.cat([enc_output, company_summary], dim=1)
        extra_mask = torch.ones((enc_mask.size(0), 1, 1, 1), device=enc_mask.device).bool()
        enc_mask_extended = torch.cat([enc_mask, extra_mask], dim=-1)

        # --- Decoder: character-level correction ---
        correction_logits = None
        if tgt_ids is not None:
            enc_mask_dec, tgt_mask = self.generate_masks(word_ids, tgt_ids)
            dec_emb = self.dropout(self.pos_enc_dec(self.char_embedding(tgt_ids)))
            dec_output = dec_emb
            for layer in self.decoder_layers:
                dec_output = layer(dec_output, enc_extended, enc_mask_extended, tgt_mask)
            correction_logits = self.fc_out(dec_output)
        return ner_logits, correction_logits

Set device

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Hyperparameters.

d_model = 128
num_heads = 8
num_layers = 4
d_ff = 512
dropout = 0.1

model = MultiTaskTransformer(
    word_vocab_size=word_vocab_size,
    char_vocab_size=char_vocab_size,
    d_model=d_model,
    num_heads=num_heads,
    num_layers=num_layers,
    d_ff=d_ff,
    max_word_len=max_word_len,
    max_char_len=max_char_len,
    dropout=dropout,
    word_pad_idx=WORD_PAD_IDX,
    char_pad_idx=CHAR_PAD_IDX,
).to(device)

Loss functions

criterion_ner = nn.CrossEntropyLoss(ignore_index=WORD_PAD_IDX)
criterion_corr = nn.CrossEntropyLoss(ignore_index=CHAR_PAD_IDX)

Dynamic loss balancing parameters

log_var_ner = torch.zeros(1, requires_grad=True, device=device)
log_var_corr = torch.zeros(1, requires_grad=True, device=device)
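For context, the intent is the usual uncertainty-weighting scheme: each task loss is scaled by the exponentiated negative log-variance, and the log-variance itself is added as a regularizer. Concretely, the training loop below computes

    total_loss = exp(-log_var_corr) * loss_corr + log_var_corr
               + exp(-log_var_ner)  * loss_ner  + log_var_ner

during the joint-training phase, and only the correction terms during the correction-only phase.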

Optimizer

optimizer = optim.Adam(list(model.parameters()) + [log_var_ner, log_var_corr], lr=1e-4)

Curriculum learning: train correction first, then NER

total_epochs = 600
curriculum_epochs = 200  # only the correction loss is used for the first curriculum_epochs epochs
max_grad_norm = 1.0

Training Loop

model.train()
for epoch in range(total_epochs):
    epoch_loss = 0.0
    for word_ids, ner_labels, tgt_ids in dataloader:
        word_ids = word_ids.to(device)
        ner_labels = ner_labels.to(device)
        tgt_ids = tgt_ids.to(device)

        optimizer.zero_grad()

        # Forward pass: the decoder sees tgt_ids[:, :-1];
        # the correction loss is computed against tgt_ids[:, 1:].
        ner_logits, corr_logits = model(word_ids, tgt_ids[:, :-1])

        # NER loss
        loss_ner = criterion_ner(ner_logits.view(-1, 3), ner_labels.view(-1))

        # Correction loss
        loss_corr = criterion_corr(
            corr_logits.view(-1, char_vocab_size),
            tgt_ids[:, 1:].contiguous().view(-1),
        )

        # Dynamic loss balancing (uncertainty weighting) with a curriculum:
        # correction-only first, then joint training.
        if epoch < curriculum_epochs:
            total_loss = torch.exp(-log_var_corr) * loss_corr + log_var_corr
        else:
            total_loss = (torch.exp(-log_var_corr) * loss_corr + log_var_corr +
                          torch.exp(-log_var_ner) * loss_ner + log_var_ner)

        total_loss.backward()
        clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()

        epoch_loss += total_loss.item()

    avg_loss = epoch_loss / len(dataloader)
    phase = "Correction-only" if epoch < curriculum_epochs else "Joint training"
    print(f"Epoch {epoch+1}/{total_epochs} [{phase}], Loss: {avg_loss:.4f} | "
          f"w_ner: {log_var_ner.item():.4f}, w_corr: {log_var_corr.item():.4f}")
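For completeness, here is a minimal greedy-decoding sketch for inspecting what the correction head produces (not part of the training code above; it assumes decoding starts from the <SOS> token defined earlier and stops at <EOS> or after max_char_len steps):

# Greedy decoding of the correction for one example (inspection only).
model.eval()
with torch.no_grad():
    word_ids, _, _ = dataset[0]
    word_ids = word_ids.unsqueeze(0).to(device)              # (1, max_word_len)
    generated = torch.tensor([[SOS_token]], device=device)   # start with <SOS>
    for _ in range(max_char_len):
        _, corr_logits = model(word_ids, generated)
        next_id = corr_logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_id], dim=1)
        if next_id.item() == EOS_token:
            break
    pred = "".join(idx2char[i] for i in generated[0, 1:].tolist()
                   if i not in (SOS_token, EOS_token, CHAR_PAD_IDX))
    print("predicted correction:", pred)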
