I'm working on a model that combines a CNN with an LSTM to process sequences of spectrograms and make per-time-step predictions. The CNN alone performs well on the task, but after adding an LSTM for temporal modeling, the model's performance doesn't improve as expected. Here's the architecture of the combined model:
class CNNRNN(nn.Module):
def __init__(self, cnn_input_shape=(129, 35), num_outputs=1):
"""
CNN-RNN model for per-time-step predictions using previous hidden states.
Args:
input_shape (tuple): Shape of the 2D input data (height, width).
num_outputs (int): Number of output predictions (e.g., 1 for regression).
"""
super(CNNRNN, self).__init__()
# CNN for 2D data (e.g., spectrograms)
self.cnn_num_outputs = 1
self.cnn_feature_extractor = CNN(
input_shape=cnn_input_shape, num_outputs=self.cnn_num_outputs
)
# LSTM layer
self.hidden_size = 64
self.lstm = nn.LSTM(
input_size=self.cnn_num_outputs, hidden_size=self.hidden_size, batch_first=True
)
# Fully connected layer for predictions
self.num_outputs = 1
self.fc = nn.Linear(self.hidden_size, self.num_outputs)
def forward(self, x_padded, lengths):
"""
Forward pass for sequences.
Args:
x_padded (torch.Tensor): Padded input tensor of shape (batch_size, max_seq_len, height, width).
lengths (torch.Tensor): Actual lengths of each sequence in the batch.
Returns:
torch.Tensor: Predictions for each time step.
"""
batch_size, max_seq_len, height, width = x_padded.size()
# Reshape to process all time steps through CNN
x_reshaped = x_padded.view(
batch_size * max_seq_len, height, width
) # (batch_size, max_seq_len, height, width) -> (batch_size * max_seq_len, height, width)
# Extract CNN features
cnn_features = self.cnn_feature_extractor(
x_reshaped
) # (batch_size * max_seq_len, height, width) -> (batch_size * max_seq_len, cnn_num_outputs)
# Reshape back to (batch_size, max_seq_len, cnn_num_outputs)
cnn_features = cnn_features.view(
batch_size, max_seq_len, -1
) # (batch_size * max_seq_len, cnn_num_outputs) -> (batch_size, max_seq_len, cnn_num_outputs)
# Pack sequences
packed_input = pack_padded_sequence(
cnn_features, lengths.cpu(), batch_first=True, enforce_sorted=False
) # (batch_size, max_seq_len, cnn_num_outputs)
# Pass through LSTM
packed_output, _ = self.lstm(packed_input)
# Unpack sequences
lstm_output, _ = pad_packed_sequence(
packed_output, batch_first=True
) # Shape: (batch_size, max_seq_len, hidden_size)
# Reshape to apply fully connected layer
lstm_output = lstm_output.view(
-1, self.hidden_size
) # (batch_size, max_seq_len, hidden_size) -> (batch_size * max_seq_len, hidden_size)
# Apply the fully connected layer to each time step
predictions = self.fc(
lstm_output
) # (batch_size * max_seq_len, hidden_size) -> (batch_size * max_seq_len, num_outputs)
# Reshape to (batch_size, max_seq_len, num_outputs)
predictions = predictions.view(
batch_size, max_seq_len, self.num_outputs
) # (batch_size * max_seq_len, num_outputs) -> (batch_size, max_seq_len, num_outputs)
# If num_outputs == 1, squeeze last dimension
if predictions.size(-1) == 1:
predictions = predictions.squeeze(-1) # (batch_size, max_seq_len)
return predictions
The CNN component looks like this:
class CNN(nn.Module):
def __init__(self, input_shape=(129, 35), num_outputs=1):
"""
Build a CNN for regression, tailored to 2D data (e.g., spectrograms),
with a reshape operation to add a channel dimension.
Args:
input_shape (tuple): The shape of the input data, defaulting to (129, 35).
num_outputs (int): The number of output values for regression.
"""
super(CNN, self).__init__()
self.input_shape = input_shape
self.num_outputs = num_outputs
self.model = nn.Sequential(
nn.Conv2d(1, 8, kernel_size=(3, 3), padding="same"),
nn.ReLU(),
nn.MaxPool2d(kernel_size=(2, 2)),
nn.Conv2d(8, 16, kernel_size=(3, 3), padding="same"),
nn.ReLU(),
nn.MaxPool2d(kernel_size=(2, 2)),
nn.Conv2d(16, 32, kernel_size=(3, 3), padding="same"),
nn.ReLU(),
nn.MaxPool2d(kernel_size=(2, 2)),
nn.Conv2d(32, 64, kernel_size=(3, 3), padding="same"),
nn.ReLU(),
nn.MaxPool2d(kernel_size=(2, 2)),
nn.Flatten(),
nn.Linear(
64 * ((input_shape[0] // 16) * (input_shape[1] // 16)), num_outputs
),
)
def forward(self, x):
# Process input through the CNN layers
# Add a channel dimension to the input: (batch_size, 129, 35) -> (batch_size, 1, 129, 35)
x = x.unsqueeze(1)
x = self.model(x)
# Remove the last dimension: (batch_size, num_outputs) -> (batch_size)
x = x.squeeze(-1)
return x
I'm training the model on variable-length sequences using the following setup in my training loop, with truncated backpropagation through time (TBPTT):
for epoch in range(epochs):
model.train()
epoch_train_loss = 0.0 # Accumulates the total loss over the epoch
epoch_train_mae = 0.0 # Accumulates the MAE over the epoch
# Train the model for the current epoch
t_train = tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{epochs}")
for inputs, lengths, labels in t_train:
# inputs: (batch_size, max_seq_len, height, width)
# labels: (batch_size, max_seq_len)
# lengths: (batch_size,)
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
# Initialize batch loss and MAE for this batch
batch_train_loss = 0.0
batch_train_mae = 0.0
# Truncate sequences and perform forward and backward passes in chunks
for t in range(0, inputs.size(1), tbptt_steps):
# Extract a chunk of inputs and labels for the current TBPTT step
input_chunk = inputs[:, t : t + tbptt_steps, :, :].contiguous()
label_chunk = labels[:, t : t + tbptt_steps].contiguous()
# Adjust lengths for the current chunk
adjusted_lengths = lengths - t
adjusted_lengths = adjusted_lengths.clamp(min=0, max=tbptt_steps)
# Filter out sequences with zero adjusted lengths
non_zero_indices = (adjusted_lengths > 0).nonzero(as_tuple=True)[0]
if len(non_zero_indices) == 0:
continue # Skip this chunk if all sequences have ended
input_chunk = input_chunk[non_zero_indices, :, :, :]
label_chunk = label_chunk[non_zero_indices, :]
adjusted_lengths = adjusted_lengths[non_zero_indices]
# Pass the input chunk and adjusted lengths to the model
outputs = model(input_chunk, adjusted_lengths)
# Mask the outputs and labels
max_seq_len = outputs.size(1)
mask = torch.arange(max_seq_len).expand(len(adjusted_lengths), max_seq_len).to(
device
) < adjusted_lengths.unsqueeze(1)
outputs_masked = outputs[mask]
labels_masked = label_chunk[mask]
# Calculate chunk loss and MAE for this batch
loss = criterion(outputs_masked, labels_masked)
mae_loss = mae_criterion(outputs_masked, labels_masked)
# Accumulate batch losses over chunks
batch_train_loss += loss.item()
batch_train_mae += mae_loss.item()
# Calculate and log the gradient norms (min and max)
max_grad_norm = 0
min_grad_norm = float("inf")
for name, param in model.named_parameters():
if param.grad is not None:
grad_norm = param.grad.norm().item()
max_grad_norm = max(max_grad_norm, grad_norm)
min_grad_norm = min(min_grad_norm, grad_norm)
# Log gradient norms
logger.info(
f"Batch {t}: Max gradient norm = {max_grad_norm:.6f}, Min gradient norm = {min_grad_norm:.6f}"
)
# Backpropagation
loss.backward()
nn.utils.clip_grad_norm_(
model.parameters(), max_norm=10.0, norm_type=2
) # Clip gradients
# Optimizer step after processing all chunks in the batch
optimizer.step()
# Accumulate epoch-level losses from the batch-level losses
epoch_train_loss += batch_train_loss
epoch_train_mae += batch_train_mae
# Update the progress bar with the total batch-wise losses for this batch
t_train.set_postfix(batch_train_loss=batch_train_loss, batch_train_mae=batch_train_mae)
# Compute the average loss and MAE for the training set over the epoch
avg_epoch_train_loss = epoch_train_loss / len(train_loader)
avg_epoch_train_mae = epoch_train_mae / len(train_loader)
# Log the epoch results for training
logger.info(
f"Epoch {epoch+1}/{epochs} - "
f"Train Loss: {avg_epoch_train_loss:.4f}, Train MAE: {avg_epoch_train_mae:.4f}"
)
What I’ve tried:
- Checking the CNN performance: The CNN seems to perform well in isolation.
- Gradient clipping: I’ve added gradient clipping with max_norm=10.0 to prevent exploding gradients during training.
- Batch processing with variable sequence lengths: I’m using a custom collate function to handle variable-length sequences, ensuring correct padding and sequence packing for the LSTM.
Despite these efforts, the overall model doesn't seem to learn. The training and validation losses don't improve at all, which makes me think the issue is in the way I am training the model.
Am I implementing TBPTT correctly? Any ideas on where the problem could lie?
For completeness I have attached my Custom Dataset function and collatefn implementations to deal with variable length sequences:
class CustomDataset(Dataset):
def __init__(self, X, y):
"""
Custom dataset for variable-length sequences.
Args:
X (list of torch.Tensor): List of input tensors of varying lengths.
y (list of torch.Tensor): List of target tensors of varying lengths.
"""
self.X = X
self.y = y
def len(self):
return len(self.X)
def getitem(self, idx):
return self.X[idx], self.y[idx]
def collate_fn(batch):
"""
Custom collate function to pad sequences and labels to the same length in a batch.
Args:
batch (list of tuples): List of (input, target) tuples from the dataset.
Returns:
tuple: Padded input sequences, sequence lengths, padded targets.
"""
Sort batch by sequence length (descending)
batch = sorted(batch, key=lambda x: x[0].size(0), reverse=True)
sequences, labels = zip(*batch)
lengths = torch.tensor([seq.size(0) for seq in sequences])
Pad sequences (batch_first=True for easier handling)
sequences_padded = pad_sequence(sequences, batch_first=True, padding_value=0.0)
labels_padded = pad_sequence(labels, batch_first=True, padding_value=0.0)
return sequences_padded, lengths, labels_padded