I'm training a UNet to predict caustics, as a kind of toy problem.
An input image to the UNet is something like this:
It is trained on "ground truth" outputs like this:
It is currently predicting results like this:
I figure I've begun overfitting, since the training loss consistently drops below the validation loss (I'm just using standard MSE over the pixels). So, my main question is: what part of the UNet architecture is producing this weird splotchiness? There almost seems to be a visible grid in the predicted caustics...
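For reference, that loss is computed along these lines (variable names here are placeholders for my actual tensors):

    criterion = nn.MSELoss()                  # plain per-pixel MSE
    loss = criterion(model(inputs), targets)  # inputs/targets are placeholder names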
It seems like the splotchiness is similar to that seen in the images in this question.
My UNet is very standard. I'm trying to work out which part of it would produce this grid-like splotchiness, and how I can tackle it.
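From what I've read, a common suspect for grid/checkerboard artifacts is the ConvTranspose2d upsampling in the decoder. In case it helps frame the question, this is the kind of replacement I've been considering: a nearest-neighbor upsample followed by a 3x3 conv, as a drop-in for the upconv_block in my code below. This is an untested sketch, not something I've verified fixes the artifacts:

    def upsample_conv_block(self, in_channels, out_channels):
        # Hypothetical replacement for upconv_block: resize-then-convolve,
        # often suggested as a way to avoid checkerboard artifacts from
        # transposed convolutions.
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )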
For reference, my code looks like this:
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, pool_size=2, pool_type=F.max_pool2d, learn_residual=False):
        super(UNet, self).__init__()
        self.pool_size = pool_size
        self.pool_type = pool_type
        self.learn_residual = learn_residual
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Encoder
        self.enc1 = self.conv_block(in_channels, 64)
        self.enc2 = self.conv_block(64, 128)
        self.enc3 = self.conv_block(128, 256)
        self.enc4 = self.conv_block(256, 512)

        # Bottleneck
        self.bottleneck = self.conv_block(512, 1024)

        # Decoder
        self.upconv4 = self.upconv_block(1024, 512)
        self.dec4 = self.conv_block(1024, 512)  # 512 (from upconv) + 512 (from enc4)
        self.upconv3 = self.upconv_block(512, 256)
        self.dec3 = self.conv_block(512, 256)   # 256 (from upconv) + 256 (from enc3)
        self.upconv2 = self.upconv_block(256, 128)
        self.dec2 = self.conv_block(256, 128)   # 128 (from upconv) + 128 (from enc2)
        self.upconv1 = self.upconv_block(128, 64)
        self.dec1 = self.conv_block(128, 64)    # 64 (from upconv) + 64 (from enc1)

        # Final output
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1, stride=1)  # no ReLU on final output

    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
            # Note: pooling happens in forward(), not inside conv_block
        )

    def upconv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        # Encoder
        enc1 = self.enc1(x)                                      # [batch, 64, H, W]
        enc2 = self.enc2(self.pool_type(enc1, self.pool_size))   # [batch, 128, H/2, W/2]
        enc3 = self.enc3(self.pool_type(enc2, self.pool_size))   # [batch, 256, H/4, W/4]
        enc4 = self.enc4(self.pool_type(enc3, self.pool_size))   # [batch, 512, H/8, W/8]

        # Bottleneck
        bottleneck = self.bottleneck(self.pool_type(enc4, self.pool_size))  # [batch, 1024, H/16, W/16]

        # Decoder
        up4 = self.upconv4(bottleneck)             # [batch, 512, H/8, W/8]
        concat4 = torch.cat([up4, enc4], dim=1)    # [batch, 1024, H/8, W/8]
        dec4 = self.dec4(concat4)                  # [batch, 512, H/8, W/8]

        up3 = self.upconv3(dec4)                   # [batch, 256, H/4, W/4]
        concat3 = torch.cat([up3, enc3], dim=1)    # [batch, 512, H/4, W/4]
        dec3 = self.dec3(concat3)                  # [batch, 256, H/4, W/4]

        up2 = self.upconv2(dec3)                   # [batch, 128, H/2, W/2]
        concat2 = torch.cat([up2, enc2], dim=1)    # [batch, 256, H/2, W/2]
        dec2 = self.dec2(concat2)                  # [batch, 128, H/2, W/2]

        up1 = self.upconv1(dec2)                   # [batch, 64, H, W]
        concat1 = torch.cat([up1, enc1], dim=1)    # [batch, 128, H, W]
        dec1 = self.dec1(concat1)                  # [batch, 64, H, W]

        # Final output
        out = self.final_conv(dec1)                # [batch, out_channels, H, W]
        if self.learn_residual:
            out = out + x[:, :self.out_channels, :, :]
        return out
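And this is roughly how I'm exercising the model, as a shape sanity check. The 256x256 size is just an example; as far as I can tell, H and W need to be divisible by 16 so the four pooling/upsampling stages produce matching shapes for the skip concatenations:

    if __name__ == "__main__":
        model = UNet(in_channels=3, out_channels=3)
        # Dummy batch: H and W divisible by 2**4 = 16 so the skip
        # connections line up at every decoder stage.
        x = torch.randn(1, 3, 256, 256)
        with torch.no_grad():
            y = model(x)
        print(y.shape)  # torch.Size([1, 3, 256, 256])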