How can I guide my RL agent to solve tasks in the correct order?
I'm trying to train an agent with reinforcement learning, using an approach similar to MuZero. The goal is to solve four tasks, A, B, C and D. Each task X consists of two actions, X1 and X2. Initially, only action A1 is possible. Performing it enables A2 and B1; B1 enables B2 and C1; C1 enables C2 and D1; D1 enables D2. There are also some other actions that aren't listed here; each of them prevents the successful completion of at least one of the tasks. Crucially, performing any X2 action before D1 is done permanently disables D1. So the only winning policy is A1>B1>C1>D1>{A2, B2, C2, D2 in any order}.
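To make the dependency structure concrete, here is a minimal Python sketch of the environment as described. The class and method names (`TaskOrderEnv`, `legal_actions`, `step`, `all_solved`) are hypothetical, and the unlisted "other actions" are omitted, so it only models the A/B/C/D chain and the rule that any X2 before D1 disables D1.

```python
from dataclasses import dataclass, field

TASKS = "ABCD"


@dataclass
class TaskOrderEnv:
    """Hypothetical model of the four-task dependency chain (other actions omitted)."""
    done: set = field(default_factory=set)  # actions already performed
    d1_disabled: bool = False               # set permanently once any X2 precedes D1

    def legal_actions(self):
        legal = set()
        if "A1" not in self.done:
            legal.add("A1")
        for i, t in enumerate(TASKS):
            first, second = f"{t}1", f"{t}2"
            if first in self.done and second not in self.done:
                legal.add(second)  # X1 enables X2
            if i + 1 < len(TASKS):
                nxt = f"{TASKS[i + 1]}1"
                if first in self.done and nxt not in self.done:
                    legal.add(nxt)  # X1 also enables the next task's first action
        if self.d1_disabled:
            legal.discard("D1")
        return legal

    def step(self, action):
        if action not in self.legal_actions():
            raise ValueError(f"illegal action {action}")
        if action.endswith("2") and "D1" not in self.done:
            self.d1_disabled = True  # any X2 before D1 permanently disables D1
        self.done.add(action)

    def all_solved(self):
        return all(f"{t}2" in self.done for t in TASKS)


env = TaskOrderEnv()
for a in ["A1", "B1", "C1", "D1", "A2", "B2", "C2", "D2"]:
    env.step(a)
print(env.all_solved())  # True

env = TaskOrderEnv()
for a in ["A1", "A2", "B1", "C1"]:
    env.step(a)
print("D1" in env.legal_actions())  # False: A2 was played before D1
```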
To guide the training, I've defined small rewards for solving each task and a large one for solving all four. I've set the discount factor to 1.0, so there should be no penalty for postponing the X2 steps.
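The effect of the discount factor can be checked with a small calculation. The reward values below (1.0 per solved task, a bonus of 10.0 for solving all four) are made-up placeholders, and the sketch assumes a task counts as solved on its X2 step. With γ = 1.0 every winning order earns the same return, while a losing episode collects at most three task rewards.

```python
def episode_rewards(actions, task_reward=1.0, bonus=10.0):
    """Hypothetical reward schedule: task_reward for each completed task (its X2 step),
    plus a bonus on the step that completes the fourth task."""
    rewards, solved = [], 0
    for a in actions:
        r = 0.0
        if a.endswith("2"):
            solved += 1
            r += task_reward
            if solved == 4:
                r += bonus
        rewards.append(r)
    return rewards


def discounted_return(rewards, gamma):
    """Sum of gamma**t * r_t over the episode."""
    return sum(gamma ** t * r for t, r in enumerate(rewards))


winning = ["A1", "B1", "C1", "D1", "A2", "B2", "C2", "D2"]
reordered = ["A1", "B1", "C1", "D1", "D2", "C2", "B2", "A2"]
losing = ["A1", "A2", "B1", "B2", "C1", "C2"]  # D1 is disabled after A2

print(discounted_return(episode_rewards(winning), 1.0))    # 14.0
print(discounted_return(episode_rewards(reordered), 1.0))  # 14.0
print(discounted_return(episode_rewards(losing), 1.0))     # 3.0
```

With γ < 1 the winning return drops below 14.0, because every reward in it arrives only after the four X1 steps.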
Unfortunately, the agent never learns to do all of A1 through D1 first. It manages to do A1 and B1 first, sometimes even C1, but never all four. It then completes one of the tasks A to C and thus fails to solve task D.
I noticed that my policy network learns to assign a prior probability of ~1 to A1>A2, strongly biasing the MCTS to explore this path. I assume the reason is that random exploration often produces A1>A2>... and earns a reward, whereas A1>B1>C1>D1>A2 is far less likely. So the network is well trained at predicting the reward for A1>A2 but poorly trained for the longer game sequences. If it cannot predict the reward there, the MCTS is biased toward A1>A2. The policy network then learns to assign this action a high prior probability, which biases the MCTS further and creates a feedback loop.
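The exploration argument can be quantified for a uniformly random policy over legal actions. This sketch uses the simplified dependency rules (the helpers `legal_actions` and `prefix_probability` are hypothetical, and the other actions are omitted) and computes the exact probability of playing a given prefix first: A1>A2 has probability 1/2, while A1>B1>C1>D1 has probability 1/2 · 1/3 · 1/4 = 1/24. If the unlisted actions are legal in these states, both numbers would be smaller still.

```python
from fractions import Fraction

TASKS = "ABCD"


def legal_actions(done):
    """Legal actions in the simplified four-task model (other actions omitted).
    D1 disabling is ignored: we only evaluate prefixes that contain no X2 before D1."""
    legal = set()
    if "A1" not in done:
        legal.add("A1")
    for i, t in enumerate(TASKS):
        if f"{t}1" in done:
            if f"{t}2" not in done:
                legal.add(f"{t}2")
            if i + 1 < len(TASKS) and f"{TASKS[i + 1]}1" not in done:
                legal.add(f"{TASKS[i + 1]}1")
    return legal


def prefix_probability(prefix):
    """Probability that a uniformly random policy over legal actions plays `prefix` first."""
    done, p = set(), Fraction(1)
    for a in prefix:
        legal = legal_actions(done)
        if a not in legal:
            return Fraction(0)
        p *= Fraction(1, len(legal))
        done.add(a)
    return p


print(prefix_probability(["A1", "A2"]))              # 1/2
print(prefix_probability(["A1", "B1", "C1", "D1"]))  # 1/24
```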
How can this be prevented? Would it be enough to increase the temperature? If so, should I increase only the temperature used when sampling from the visit-count distribution computed by the MCTS, or should I also increase $c_1$ in the UCB formula? That should increase exploration within the MCTS and reduce the focus on the estimated rewards.
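For reference, the selection rule in the MuZero paper scores each child as $Q(s,a) + P(s,a)\,\frac{\sqrt{\sum_b N(s,b)}}{1+N(s,a)}\left(c_1 + \log\frac{\sum_b N(s,b) + c_2 + 1}{c_2}\right)$ with $c_1 = 1.25$ and $c_2 = 19652$, and the action actually played is sampled with probability proportional to $N(s,a)^{1/T}$. The sketch below implements both under simplifying assumptions: Q values are taken as already normalized (MuZero min-max normalizes them over the tree), the function names are hypothetical, and the statistics in the usage lines are made-up numbers, not output from a real search.

```python
import math
import random


def pucb_score(q, prior, n_parent, n_child, c1=1.25, c2=19652.0):
    """MuZero-style pUCT score; q is assumed to be normalized already."""
    exploration = prior * math.sqrt(n_parent) / (1 + n_child)
    exploration *= c1 + math.log((n_parent + c2 + 1) / c2)
    return q + exploration


def select_child(stats, c1=1.25):
    """stats maps action -> (q, prior, visit_count); returns the highest-scoring action."""
    n_parent = sum(n for _, _, n in stats.values())
    return max(stats, key=lambda a: pucb_score(stats[a][0], stats[a][1], n_parent, stats[a][2], c1=c1))


def visit_distribution(visits, temperature):
    """Policy target: probability proportional to N(a) ** (1 / temperature)."""
    weights = {a: n ** (1.0 / temperature) for a, n in visits.items()}
    total = sum(weights.values())
    return {a: w / total for a, w in weights.items()}


def sample_from_visits(visits, temperature, rng=random):
    """Sample the action to play from the temperature-adjusted visit distribution."""
    dist = visit_distribution(visits, temperature)
    actions = list(dist)
    return rng.choices(actions, weights=[dist[a] for a in actions], k=1)[0]


# Made-up statistics: A2 has the high prior, B1 the higher Q value.
stats = {"A2": (0.5, 0.95, 10), "B1": (0.8, 0.05, 1)}
print(select_child(stats, c1=1.25))  # B1
print(select_child(stats, c1=5.0))   # A2
print(visit_distribution({"A2": 30, "B1": 10}, temperature=1.0))  # {'A2': 0.75, 'B1': 0.25}
```

With these numbers, raising $c_1$ from 1.25 to 5.0 flips the selection from B1 (higher Q) to A2 (higher prior): the exploration bonus is proportional to $P(s,a)$, so a larger $c_1$ shifts weight away from Q and toward the prior. Raising the temperature $T$, by contrast, only flattens the distribution over the visit counts the search has already produced.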
I guess I could also remove the intermediate rewards, but then the reward signal would be extremely sparse.
I know that I could enforce the order of tasks through the environment, but I want the agent to learn this by itself.
