
I am looking for a way to re-identify (classify/recognize) x real-life objects (x < 50) with a camera. Each object should be presented to the system only once for learning, and there is always exactly one of these objects in the query image. It should be possible to add new objects to the list of "known" objects. The objects are not necessarily part of ImageNet, nor do I have a training dataset with various instances of these objects.

Example:

In the beginning I have no "known" objects. Now I present a smartphone, a teddy bear and a pair of scissors to the system. It should learn to re-identify these three objects when they are presented again in the future. They will be the exact same objects, i.e. not a different phone, but definitely seen from a different viewing angle, under different lighting, etc.

My understanding is that I would have to place each object in an embedding space and do a simple nearest neighbor lookup in that space for the queries. Maybe just use a pretrained ResNet, cut off the classification layer and simply use the resulting output vector for each object? I am not sure what the best way would be.
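The embedding-plus-nearest-neighbor idea can be sketched independently of whichever network produces the vectors. Below is a minimal sketch, assuming each known object is represented by a single feature vector. The class name `Gallery`, its methods and the toy 3-dimensional vectors are illustrative assumptions, not part of any library; in practice the vectors would come from a CNN.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


class Gallery:
    """Hypothetical store holding one reference embedding per known object."""

    def __init__(self):
        self.labels = []
        self.vectors = []

    def add(self, label, vector):
        # Enrolling a new object only appends its embedding; nothing is retrained.
        self.labels.append(label)
        self.vectors.append(list(vector))

    def query(self, vector):
        # Linear scan over all references; cheap for fewer than 50 objects.
        if not self.vectors:
            raise ValueError("gallery is empty")
        scores = [cosine_similarity(vector, ref) for ref in self.vectors]
        best = max(range(len(scores)), key=scores.__getitem__)
        return self.labels[best], scores[best]


# Made-up embeddings standing in for real CNN feature vectors.
gallery = Gallery()
gallery.add("smartphone", [0.9, 0.1, 0.0])
gallery.add("teddy bear", [0.1, 0.9, 0.2])
gallery.add("scissors", [0.0, 0.2, 0.9])

label, score = gallery.query([0.8, 0.2, 0.1])
print(label, round(score, 3))
```

Cosine similarity is used here as one possible choice; Euclidean distance would work with the same structure. Whether a single reference vector per object is enough to cope with large viewpoint changes is something only a test on real images can show.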

Any advice or a pointer in the right direction would be highly appreciated.

sonovice

1 Answer


I have put my initial idea to the test: I used a small pretrained CNN (MobileNet) to compute features for the reference images and stored the feature vectors in a "database". Query images go through the exact same network, and the resulting feature vector is used for nearest neighbor retrieval in the DB.

    from glob import glob

    import torch
    from PIL import Image
    from numpy.linalg import norm
    from torchvision import transforms
    from torchvision.models import mobilenet_v2

    # Pretrained MobileNetV2 in inference mode
    model = mobilenet_v2(pretrained=True)
    model.eval()

    # Standard ImageNet preprocessing
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Generate DB: maps each reference image path to its output vector.
    # Note: model(image) returns the 1000 ImageNet class scores, which serve
    # as the feature vector here.
    db = {}
    db_paths = glob('db/*.jpg')
    for path in db_paths:
        # convert('RGB') guards against grayscale or RGBA input files
        image = preprocess(Image.open(path).convert('RGB')).unsqueeze(0)
        with torch.no_grad():
            output = model(image)
        db[path] = output

    # Query
    image = preprocess(Image.open('queries/box.jpg').convert('RGB')).unsqueeze(0)
    with torch.no_grad():
        query = model(image)

    # Nearest Neighbor (poor man's version): linear scan with Euclidean distance
    min_distance = float('inf')
    candidate = None
    for path, features in db.items():
        distance = norm(features.numpy() - query.numpy())
        if distance < min_distance:
            min_distance = distance
            candidate = path

    print(candidate, min_distance)

At least with my 5 test reference images and several query images, it worked without a single failed "classification". However, I am not sure whether it will stand up to a larger test...
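One way to find out how this holds up on a larger test would be to measure top-1 accuracy over a set of labelled query vectors. The following is only a minimal sketch, assuming the features have already been extracted; the function names `nearest_label` and `top1_accuracy` and the 2-dimensional toy data are hypothetical and stand in for real feature vectors.

```python
import math


def nearest_label(query, references):
    """Return the label whose reference vector is closest (Euclidean) to query."""
    return min(references, key=lambda label: math.dist(query, references[label]))


def top1_accuracy(labelled_queries, references):
    """Fraction of (true_label, vector) pairs whose nearest reference matches."""
    hits = sum(
        1
        for true_label, vector in labelled_queries
        if nearest_label(vector, references) == true_label
    )
    return hits / len(labelled_queries)


# Made-up 2-D vectors standing in for real feature vectors.
references = {"box": [1.0, 0.0], "cup": [0.0, 1.0]}
labelled_queries = [
    ("box", [0.9, 0.2]),
    ("cup", [0.1, 0.8]),
    ("cup", [0.7, 0.4]),  # deliberately closer to "box", so it counts as a miss
]

accuracy = top1_accuracy(labelled_queries, references)
print(f"top-1 accuracy: {accuracy:.2f}")
```

With real data, the query set would ideally contain several images per object taken from different angles and under different lighting, since those are the conditions the system has to cope with.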

sonovice