
TLDR: given two tensors $t_1$ and $t_2$, both with shape $(c,h,w)$, how should the distance between them be measured?


More Info: I'm working on a project in which I'm trying to distinguish between an anomalous sample (specifically from MNIST) and a "regular" sample (specifically from CIFAR10). The solution I chose is to consider the feature maps that are given by ResNet and use kNN. More specifically:

  • I embed the entire CIFAR10_TRAIN data to obtain a dataset of activations with shape $(N,c,h,w)$, where $N$ is the size of CIFAR10_TRAIN.
  • I embed $2$ new test samples, $t_C$ from CIFAR10_TEST and $t_M$ from MNIST_TEST (both with shape $(c,h,w)$), in the same way as the training data.
  • (!) I find the k-nearest neighbours of $t_C$ and of $t_M$ among the embedded training data.
  • For each test sample, I calculate the mean distance to its $k$ neighbours.
  • Given some predefined threshold, I classify $t_C$ and $t_M$ as regular or anomalous, hoping that the distance for $t_M$ is higher, as it is an out-of-distribution (OOD) sample.

Notice that in (!) I need some distance measure, but this is not trivial as these are tensors, not vectors.


What I've Tried: a trivial solution is to flatten each tensor to shape $(c\cdot h\cdot w)$ and then use the basic $\ell_2$ distance, but the results turned out pretty bad (I could not distinguish regular from anomalous samples in this case). Hence: is there a better way of measuring this distance?
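For reference, here is a minimal sketch of the baseline pipeline described above: flatten, take $\ell_2$ distances, and average over the $k$ nearest training samples. The function name `knn_anomaly_score`, the random stand-in tensors, and the threshold value are illustrative assumptions, not part of the original setup.

```python
import torch

def knn_anomaly_score(train_feats, test_feats, k=5):
    """Mean L2 distance from each test sample to its k nearest training samples.

    train_feats: (N, c, h, w), test_feats: (M, c, h, w). Returns shape (M,).
    """
    train_flat = train_feats.flatten(start_dim=1)   # (N, c*h*w)
    test_flat = test_feats.flatten(start_dim=1)     # (M, c*h*w)
    dists = torch.cdist(test_flat, train_flat, p=2)  # (M, N) pairwise distances
    knn_dists, _ = torch.topk(dists, k, dim=1, largest=False)
    return knn_dists.mean(dim=1)

# Random stand-ins for the ResNet activations (assumption: real features go here).
torch.manual_seed(0)
train_feats = torch.rand(100, 8, 4, 4)
test_feats = torch.rand(2, 8, 4, 4)

scores = knn_anomaly_score(train_feats, test_feats, k=5)
threshold = 10.0  # hypothetical value; would be tuned on held-out data
is_anomalous = scores > threshold
```

Any other distance can be swapped in by replacing the `torch.cdist` call with a function that returns an `(M, N)` distance matrix.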

Hadar Sharvit

1 Answer


You could try an earth mover's distance (EMD) in 2D or 3D over the image. For example, you could follow this example, but apply the cumulative sum sequentially along each dimension. The idea would be something like the following (untested and written on my cell phone):

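import torch

def cumsum_3d(a):
    # Cumulative sum along w, then h, then c: a 3D analogue of a CDF.
    a = torch.cumsum(a, -1)
    a = torch.cumsum(a, -2)
    a = torch.cumsum(a, -3)
    return a

def norm_3d(a):
    # Rescale so that the last three dimensions sum to 1.
    return a / torch.sum(a, dim=(-1, -2, -3), keepdim=True)

def emd_3d(a, b):
    # Mean squared difference between the cumulative sums of the normalized inputs.
    a = norm_3d(a)
    b = norm_3d(b)
    return torch.mean(torch.square(cumsum_3d(a) - cumsum_3d(b)), dim=(-1, -2, -3))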

This should also work with batched data, since every operation acts only on the last three dimensions. Note that norm_3d rescales each input so that it sums to 1; skip that step if you want the distance to account for changes in intensity.
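As a rough sketch of the batched use, the same distance could be plugged into the kNN scoring from the question by broadcasting one test tensor against the whole training set. The functions are restated compactly so the block runs on its own; `knn_score`, the random stand-in tensors, and the assumption that activations are non-negative with a positive sum (e.g. post-ReLU) are mine.

```python
import torch

def cumsum_3d(a):
    # Cumulative sum along each of the last three dimensions in turn.
    for dim in (-1, -2, -3):
        a = torch.cumsum(a, dim)
    return a

def norm_3d(a):
    # Assumes non-negative inputs with a positive sum (e.g. post-ReLU activations).
    return a / torch.sum(a, dim=(-1, -2, -3), keepdim=True)

def emd_3d(a, b):
    a, b = norm_3d(a), norm_3d(b)
    return torch.mean(torch.square(cumsum_3d(a) - cumsum_3d(b)), dim=(-1, -2, -3))

def knn_score(train_feats, test_feat, k=5):
    # test_feat (c, h, w) is broadcast against train_feats (N, c, h, w).
    dists = emd_3d(test_feat.unsqueeze(0), train_feats)  # shape (N,)
    return torch.topk(dists, k, largest=False).values.mean()

# Random stand-ins for the embedded training set and one test sample.
torch.manual_seed(0)
train_feats = torch.rand(50, 8, 4, 4) + 0.01
test_feat = torch.rand(8, 4, 4) + 0.01

all_dists = emd_3d(test_feat.unsqueeze(0), train_feats)
score = knn_score(train_feats, test_feat, k=5)
```

The score would then be compared against the threshold exactly as with the flattened $\ell_2$ baseline.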

John St. John