
I'm training a UNet to try and predict caustics, kind of as a toy problem.

An input image to the UNet is something like this:

[image: example network input]

It is trained on "ground truth" outputs like this:

[image: example ground-truth caustic]

It is currently predicting results like this:

[image: current network prediction]

I figure I've begun overfitting because the training loss is consistently dropping below the validation loss (I'm just using standard MSE over the pixels). So, my main question is: what part of the UNet architecture is producing this weird splotchiness? It almost seems like there's some visible grid in the predicted caustic...
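By "standard MSE over the pixels" I just mean something like the following (a simplified sketch; `model`, `optimizer`, `inputs`, and `targets` are placeholders for my actual training loop, not the exact code I run):

import torch.nn as nn

criterion = nn.MSELoss()  # plain per-pixel mean squared error

def train_step(model, optimizer, inputs, targets):
    # One optimisation step: predict the caustic image and regress it
    # against the ground-truth render with MSE averaged over all pixels.
    optimizer.zero_grad()
    preds = model(inputs)               # [batch, 3, H, W]
    loss = criterion(preds, targets)    # scalar MSE over pixels and channels
    loss.backward()
    optimizer.step()
    return loss.item()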

The splotchiness looks similar to that seen in the images in this question.

[image: example of the splotchiness]

Here's another example: [image]

My UNet is very standard. I'm trying to work out which part of it could be producing this grid-like splotchiness, and then how I can tackle that.


For reference, my code looks like this:

import torch
import torch.nn as nn
import torch.nn.functional as F

class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, pool_size=2,
                 pool_type=F.max_pool2d, learn_residual=False):
        super(UNet, self).__init__()

        self.pool_size = pool_size
        self.pool_type = pool_type
        self.learn_residual = learn_residual
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Encoder
        self.enc1 = self.conv_block(in_channels, 64)
        self.enc2 = self.conv_block(64, 128)
        self.enc3 = self.conv_block(128, 256)
        self.enc4 = self.conv_block(256, 512)

        # Bottleneck
        self.bottleneck = self.conv_block(512, 1024)

        # Decoder
        self.upconv4 = self.upconv_block(1024, 512)
        self.dec4 = self.conv_block(1024, 512)  # 512 (from upconv) + 512 (from enc4)

        self.upconv3 = self.upconv_block(512, 256)
        self.dec3 = self.conv_block(512, 256)   # 256 (from upconv) + 256 (from enc3)

        self.upconv2 = self.upconv_block(256, 128)
        self.dec2 = self.conv_block(256, 128)   # 128 (from upconv) + 128 (from enc2)

        self.upconv1 = self.upconv_block(128, 64)
        self.dec1 = self.conv_block(128, 64)    # 64 (from upconv) + 64 (from enc1)

        # Final output (no ReLU activation on the final layer)
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1, stride=1)

    def conv_block(self, in_channels, out_channels):
        # Two 3x3 convs with ReLU; pooling is applied separately in forward()
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def upconv_block(self, in_channels, out_channels):
        # 2x upsampling via transposed convolution
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        # Encoder
        enc1 = self.enc1(x)                                     # [batch, 64, H, W]
        enc2 = self.enc2(self.pool_type(enc1, self.pool_size))  # [batch, 128, H/2, W/2]
        enc3 = self.enc3(self.pool_type(enc2, self.pool_size))  # [batch, 256, H/4, W/4]
        enc4 = self.enc4(self.pool_type(enc3, self.pool_size))  # [batch, 512, H/8, W/8]

        # Bottleneck
        bottleneck = self.bottleneck(self.pool_type(enc4, self.pool_size))  # [batch, 1024, H/16, W/16]

        # Decoder
        up4 = self.upconv4(bottleneck)                           # [batch, 512, H/8, W/8]
        concat4 = torch.cat([up4, enc4], dim=1)                  # [batch, 1024, H/8, W/8]
        dec4 = self.dec4(concat4)                                # [batch, 512, H/8, W/8]

        up3 = self.upconv3(dec4)                                 # [batch, 256, H/4, W/4]
        concat3 = torch.cat([up3, enc3], dim=1)                  # [batch, 512, H/4, W/4]
        dec3 = self.dec3(concat3)                                # [batch, 256, H/4, W/4]

        up2 = self.upconv2(dec3)                                 # [batch, 128, H/2, W/2]
        concat2 = torch.cat([up2, enc2], dim=1)                  # [batch, 256, H/2, W/2]
        dec2 = self.dec2(concat2)                                # [batch, 128, H/2, W/2]

        up1 = self.upconv1(dec2)                                 # [batch, 64, H, W]
        concat1 = torch.cat([up1, enc1], dim=1)                  # [batch, 128, H, W]
        dec1 = self.dec1(concat1)                                # [batch, 64, H, W]

        # Final output
        out = self.final_conv(dec1)                              # [batch, out_channels, H, W]

        if self.learn_residual:
            out = out + x[:, :self.out_channels, :, :]

        return out
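The one component I keep coming back to is the ConvTranspose2d in upconv_block, since transposed convolutions are often blamed for checkerboard-style artifacts. If that is the culprit here, would the usual fix be to replace it with interpolation followed by a normal convolution? This is roughly what I'd try (just a sketch, not something I've verified helps):

    def upconv_block(self, in_channels, out_channels):
        # Possible alternative to ConvTranspose2d(kernel_size=2, stride=2):
        # nearest-neighbour upsampling followed by a 3x3 convolution, which
        # keeps the same output shape (spatial size x2, channels halved).
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

Is that the right lever to pull here, or is the grid pattern more likely coming from somewhere else in the architecture?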
