Let's say I'm training a transformer model on a seq-to-seq task where there are multiple correct answers. For example, the following outputs would all be considered correct:
source: A B C -> target: C B D
source: A B C -> target: C D E B E
...
source: A B C -> target: D E C B E
The way I'm currently handling this is to augment the dataset: I duplicate the data and randomly assign one of the correct targets to each copy of each input (first sketch below). This works fairly well and has been tried in previous papers. However, I'm wondering whether there is a way to modify the loss function so that it measures the minimum cross-entropy over all possible answers. I could loop over the candidates and take the min, but with batching the complexity would quickly scale to be an issue (second sketch below). If there are any other ideas I could use, or ways to implement taking the min in a memory/time-efficient way, I would appreciate the help!
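For concreteness, here's roughly what the augmentation looks like (a minimal sketch; the `augment` function, the `(source, candidate_targets)` data layout, and `num_copies` are just illustrative, not my actual pipeline):

```python
import random

def augment(dataset, num_copies):
    """dataset: list of (source, candidate_targets) pairs, where
    candidate_targets is the list of all correct outputs for that source.
    Each input is duplicated num_copies times, each copy paired with a
    randomly chosen correct target."""
    augmented = []
    for source, candidates in dataset:
        for _ in range(num_copies):
            augmented.append((source, random.choice(candidates)))
    return augmented
```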
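And here's what I mean by taking the min, written as a vectorized PyTorch sketch rather than a Python loop. It assumes the candidate targets have already been padded to a common length with `pad_id` and that the decoder has been run once per candidate (e.g. by repeating the batch K times), which is exactly where the memory/time cost blows up:

```python
import torch
import torch.nn.functional as F

def min_ce_loss(logits, targets, pad_id):
    """
    logits:  (batch, K, seq_len, vocab) - decoder logits for each of the
             K candidate targets per source.
    targets: (batch, K, seq_len) - the K candidate target sequences,
             padded with pad_id.
    Returns the batch mean of the minimum per-sequence cross-entropy
    across the K candidates.
    """
    b, k, t, v = logits.shape
    # Token-level CE with no reduction; padded positions contribute 0.
    ce = F.cross_entropy(
        logits.reshape(b * k * t, v),
        targets.reshape(b * k * t),
        ignore_index=pad_id,
        reduction="none",
    ).view(b, k, t)
    # Average over real (non-pad) tokens so candidates of different
    # lengths are comparable, then take the min over the K candidates.
    mask = (targets != pad_id).float()
    per_seq = (ce * mask).sum(-1) / mask.sum(-1).clamp(min=1)
    return per_seq.min(dim=1).values.mean()
```

Even vectorized like this, memory grows linearly with the number of candidates per source, since every candidate needs its own decoder pass, which is the scaling problem I'd like to avoid.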