I'm trying to implement the MAML algorithm in the reinforcement learning domain, but I am not achieving fast adaptation on my validation tasks.
I suspect that something is wrong with my meta loss computation and with how I'm calculating the second-order gradient. Can someone spot an error in my implementation?
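For context, this is my mental model of the second-order update I am trying to reproduce. It is a minimal, supervised-style sketch rather than my actual code, and the helper names (mamlMetaStep, supportLossFn, queryLossFn, taskBatches) are made up for illustration:

import torch

def mamlMetaStep(model, metaOptimizer, taskBatches, innerLr):
    # taskBatches: list of (supportLossFn, queryLossFn) pairs, one per task; each loss
    # function takes a {name: parameter} dict and returns a scalar loss (hypothetical helpers)
    metaLoss = 0.0
    for supportLossFn, queryLossFn in taskBatches:
        parameters = {name: parameter for name, parameter in model.named_parameters()}
        # inner step: differentiate the support loss w.r.t. the current parameters,
        # keeping the graph (create_graph=True) so second-order terms survive
        supportLoss = supportLossFn(parameters)
        gradients = torch.autograd.grad(supportLoss, list(parameters.values()), create_graph=True)
        parameters = {name: parameter - innerLr * gradient
                      for (name, parameter), gradient in zip(parameters.items(), gradients)}
        # outer objective: evaluate the adapted parameters on held-out (query) data
        metaLoss = metaLoss + queryLossFn(parameters)
    metaLoss = metaLoss / len(taskBatches)
    metaOptimizer.zero_grad()
    metaLoss.backward()  # gradients reach model.parameters() through the inner update
    metaOptimizer.step()

With that in mind, here is my meta-learner: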
import random
import torch
from torch.distributions import Categorical
from tqdm import tqdm
# PPO below is my own agent class; its import is omitted here

class MAMLMetaLearner():
def __init__(self, envs, tasks=5, metaLr=1e-2, metaSteps=100, innerTimeSteps=500, innerLr=5e-2, innerRollouts=20):
self.envs = envs
self.env = None
self.metaLr = metaLr
self.innerTimeSteps = innerTimeSteps
self.innerLr = innerLr # inner learning rate is generally higher than the meta learning rate
self.metaSteps = metaSteps
self.tasks = tasks # number of tasks to train on per meta iteration
self.innerRollouts = innerRollouts # number of rollouts for inner loop training
self.metaAgent = PPO("MLP", self.envs[0], lrAnealling=False, progressBar=False) # meta agent
self.metaOptimizer = torch.optim.Adam(self.metaAgent.actorCritic.parameters(), lr=metaLr)
self.metaLoss = []
    # inner loop: adapt the agent to a single task and return the post-adaptation loss
def ComputeTaskLoss(self, env):
self.metaAgent.env = env # set the environment for inner loop learning
# adapt the task agent parameters using the task model
state = None
# in order to not break the computational graph I will create a dictionary of the inner loop parameters
adaptedParametersDictionary = {name: parameter for (name, parameter) in self.metaAgent.actorCritic.named_parameters()}
for step in range(self.innerRollouts):
state, states, actions, oldLogProbabilities, returns, advantages, _, _ = self.metaAgent.CollectTrajectories(totalTimesteps=self.innerTimeSteps, state=state, parameters = adaptedParametersDictionary, batchSize=self.innerTimeSteps) # collect trajectories using the task agent
if len(advantages) > 1:
advantages = (advantages - advantages.mean()) / torch.max(advantages.std(), torch.tensor(1e-9, device=self.metaAgent.device))
if self.metaAgent.discreteActionSpace:
actionLogits, values = self.metaAgent.ForwardPass(states, parameters=adaptedParametersDictionary)
newProbabilityDistribution = Categorical(logits=actionLogits) # using categorical is good because it provides functions to sample and calculate log probabilities
                newLogProbabilities = newProbabilityDistribution.log_prob(actions)
else:
mean, standardDeviation, values = self.metaAgent.ForwardPass(states, parameters=adaptedParametersDictionary)
newProbabilityDistribution = torch.distributions.Normal(mean, standardDeviation)
newLogProbabilities = (newProbabilityDistribution.log_prob(actions)).sum(dim=-1)
probabilityRatio = torch.exp(newLogProbabilities - oldLogProbabilities) # this is a measure of how much the policy has changed from the old policy
# compute loss
loss, _, _, _ = self.metaAgent.ComputeLoss(values, returns, newProbabilityDistribution, probabilityRatio, advantages)
self.metaAgent.optimizer.zero_grad() # zero the gradients
# compute gradients with respect to adapted parameters
gradients = torch.autograd.grad(loss, adaptedParametersDictionary.values(), create_graph=True)
# update the adapted parameters using the gradients
adaptedParametersDictionary = {name: parameter - self.innerLr * gradient for ((name, parameter), gradient) in zip(adaptedParametersDictionary.items(), gradients)}
# compute loss after adaptation
_, states, actions, newLogProbabilities, returns, advantages, _, _ = self.metaAgent.CollectTrajectories(totalTimesteps=self.innerTimeSteps, parameters=adaptedParametersDictionary, batchSize=self.innerTimeSteps) # collect trajectories using the updated task agent
if len(advantages) > 1:
advantages = (advantages - advantages.mean()) / torch.max(advantages.std(), torch.tensor(1e-9, device=self.metaAgent.device))
if self.metaAgent.actorCritic.discreteActionSpace:
actionLogits, _ = self.metaAgent.actorCritic(states)
actionDistribution = Categorical(logits=actionLogits) # using categorical is good because it provides functions to sample and calculate log probabilities
oldLogProbabilities = actionDistribution.log_prob(actions)
else:
mean, standardDeviation, _ = self.metaAgent.actorCritic(states)
oldProbabilityDistribution = torch.distributions.Normal(mean, standardDeviation)
oldLogProbabilities = (oldProbabilityDistribution.log_prob(actions)).sum(dim=-1)
# calculate newProbabilityDistribution
if self.metaAgent.actorCritic.discreteActionSpace:
actionLogits, values = self.metaAgent.ForwardPass(states, parameters=adaptedParametersDictionary)
newProbabilityDistribution = Categorical(logits=actionLogits)
else:
mean, standardDeviation, values = self.metaAgent.ForwardPass(states, parameters=adaptedParametersDictionary)
newProbabilityDistribution = torch.distributions.Normal(mean, standardDeviation)
probabilityRatio = torch.exp(newLogProbabilities - oldLogProbabilities) # this is a measure of how much the policy has changed from the original policy
adaptedLoss, _, _, _ = self.metaAgent.ComputeLoss(values, returns, newProbabilityDistribution, probabilityRatio, advantages)
return adaptedLoss
def MetaTrain(self):
print("Meta Training Started...")
count = 0
        smallestLoss = float("inf") # lowest meta loss seen so far
for metaIteration in tqdm(range(self.metaSteps), desc="Meta Training Progress", unit = "iterations"):
metaLoss = 0.0
            tasks = [random.choice(self.envs) for _ in range(self.tasks)] # randomly sample the tasks for this meta iteration from the environment list
for task in tasks:
taskLoss = self.ComputeTaskLoss(task) # inner loop learning
metaLoss = metaLoss + taskLoss
metaLoss = metaLoss / self.tasks
            # meta update: compute gradients of the averaged task loss with respect to the original meta parameters
            self.metaOptimizer.zero_grad() # zero the gradients
            metaGradients = torch.autograd.grad(metaLoss, self.metaAgent.actorCritic.parameters()) # second-order gradients flow through the inner-loop updates
            # copy the meta gradients onto the parameters so the optimizer can use them
            for parameter, gradient in zip(self.metaAgent.actorCritic.parameters(), metaGradients):
                parameter.grad = gradient
            torch.nn.utils.clip_grad_norm_(self.metaAgent.actorCritic.parameters(), 0.5) # clip the gradients (after they have been written, not before)
            self.metaOptimizer.step()
count += 1
print(f"Iteration {len(self.metaLoss)} , Loss: {metaLoss.item()}. {count} training iterations completed this session")
print("Meta Training Complete")
Please note that the ForwardPass function used is as follows:
def ForwardPass(self, state, parameters):
if parameters is None:
return self.actorCritic(state)
# if the parameters are not None, we are using the network for MAML and need to pass the adapted parameters to the forward function
else:
return func.functional_call(self.actorCritic, parameters, state)
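(func here is torch.func.) My understanding of functional_call is that it runs the module with the supplied parameter dictionary in place of its registered parameters, so gradients can flow back into whatever tensors that dictionary was built from. A toy check of that behaviour (standalone example, not from my code):

import torch
import torch.nn as nn
from torch.func import functional_call

net = nn.Linear(4, 2)
parameters = {name: parameter for name, parameter in net.named_parameters()}
x = torch.randn(3, 4)

out = functional_call(net, parameters, x)        # equivalent to net(x) here
adapted = {name: parameter - 0.1 * parameter for name, parameter in parameters.items()}
outAdapted = functional_call(net, adapted, x)    # forward pass through the transformed parameters
outAdapted.sum().backward()                      # gradients flow back into net.parameters()
print(net.weight.grad is not None)               # True: the graph through 'adapted' is intact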
The ComputeLoss function computes the PPO loss; it returns the total loss followed by its individual components (which I discard above).
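In case it is relevant, the objective inside ComputeLoss is essentially the standard clipped PPO surrogate plus a value loss and an entropy bonus, roughly the following (a simplified sketch; the function and coefficient names and defaults here are placeholders rather than my exact values):

import torch

def computePpoLoss(values, returns, distribution, probabilityRatio, advantages,
                   clipRange=0.2, valueCoefficient=0.5, entropyCoefficient=0.01):
    # clipped surrogate objective on the policy
    unclipped = probabilityRatio * advantages
    clipped = torch.clamp(probabilityRatio, 1.0 - clipRange, 1.0 + clipRange) * advantages
    policyLoss = -torch.min(unclipped, clipped).mean()
    # value-function regression and entropy bonus
    valueLoss = torch.nn.functional.mse_loss(values.view(-1), returns.view(-1))
    entropy = distribution.entropy().mean()
    loss = policyLoss + valueCoefficient * valueLoss - entropyCoefficient * entropy
    return loss, policyLoss, valueLoss, entropy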
I was using Meta-World's ML10 benchmark for my training and testing sets, with a task distribution of 50 tasks drawn from ML10.
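For anyone unfamiliar with ML10, the environment list can be built along these lines (a generic sketch of the Meta-World API rather than my exact setup code, so treat the calls as an assumption):

import random
import metaworld

ml10 = metaworld.ML10()  # 10 training task families, 5 held-out test families
envs = []
for _ in range(50):  # one environment per sampled task variation
    name = random.choice(list(ml10.train_classes))
    env = ml10.train_classes[name]()
    env.set_task(random.choice([t for t in ml10.train_tasks if t.env_name == name]))
    envs.append(env)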