
I am trying to implement MADDPG from scratch. I finished the code, but after 3000 episodes I still don't see any improvement in the agents' behaviour. Could someone please help me? Thank you in advance! Here is the code of my train function.

s : the observation
a : the action taken
ps : the next observation
r : the reward
d : whether each agent is done (True or False)
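
In case it helps, one sampled mini-batch has roughly the following shapes (OBS_SHAPE is just a placeholder name here):

s, ps : (MINI_BATCH, n_agents, OBS_SHAPE)
a : (MINI_BATCH, n_agents, ACTION_SHAPE)
r, d : (MINI_BATCH, n_agents)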

The train function follows the MADDPG algorithm from the original paper.
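For reference, the target and the centralized critic loss from that paper (Lowe et al., 2017) are

$$ y^j = r_i^j + \gamma \, Q_i^{\mu'}\big(x'^j, a_1', \dots, a_N'\big)\Big|_{a_k' = \mu_k'(o_k^j)}, \qquad \mathcal{L}(\theta_i) = \frac{1}{S}\sum_j \Big(y^j - Q_i^{\mu}\big(x^j, a_1^j, \dots, a_N^j\big)\Big)^2 $$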

s, a, r, ps, d = buffer.sample(MINI_BATCH)
s = np.asarray(s)
a = np.asarray(a)
ps = np.asarray(ps)
# we store the observations of the mini-batch in a dict; this makes it easier to compute the target actions,
# since each target actor only depends on its own agent's observation and not on the observations of the other agents
ps_by_agent = defaultdict(list)
r_by_agent = defaultdict(list)
s_by_agent = defaultdict(list)
a_by_agent = defaultdict(list)
d_by_agent = defaultdict(list)

for i in range(MINI_BATCH):
    for j in range(len(agents)):
        ps_by_agent[str(j)].append(ps[i][j])
        r_by_agent[str(j)].append(r[i][j])
        s_by_agent[str(j)].append(s[i][j])
        a_by_agent[str(j)].append(a[i][j])
        d_by_agent[str(j)].append(d[i][j])
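
# note: since s, a and ps are already numpy arrays of shape (MINI_BATCH, n_agents, ...),
# the same per-agent grouping could also be written without the Python loop, e.g.
#     ps_by_agent[str(j)] = list(ps[:, j])    # and likewise for s and a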

if DEBUG:
    print('ps_by_agent[0] shape:', np.asarray(ps_by_agent['0']).shape)
    print('r_by_agent[0] shape:', np.asarray(r_by_agent['0']).shape)
    print('s_by_agent[0] shape:', np.asarray(s_by_agent['0']).shape)
    print('a_by_agent[0] shape:', np.asarray(a_by_agent['0']).shape)
    print('d_by_agent[0] shape:', np.asarray(d_by_agent['0']).shape)

for i in range(len(agents)):
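    # build the joint target action: each target actor acts only on its own agent's next observation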
    target_action = []
    for k, agent in enumerate(agents):
        ps_ = torch.tensor(np.asarray(ps_by_agent[str(k)]),dtype=torch.float).to(device)
        target_action.append(agent.target_actor.forward(ps_))

    #ps_ = torch.tensor(ps_by_agent[str(i)]).to(device)
    ta_ = np.asarray([t.detach().cpu().numpy() for t in target_action])

    if DEBUG:
        print(ta_.shape)
        #print(ps_.shape)
        print(torch.tensor(ps).shape)

    tc = agents[i].target_critic.forward(
        torch.tensor(ps,dtype=torch.float).to(device),
        torch.reshape(torch.tensor(ta_,dtype=torch.float).to(device),(ps.shape[0],ps.shape[1],ACTION_SHAPE))
        )
    r_i = torch.tensor(r_by_agent[str(i)]).unsqueeze(1).unsqueeze(2).to(device)

    if DEBUG :
        print("r: ",r_i.shape)
        print("tc : ",tc.shape)

    # compute the target y_j; the tensors are reshaped so their dimensions match and they can be added
    y_j = r_i + gamma * tc

    if DEBUG:
        print(y_j.shape)
        print("state: ",torch.tensor(np.asarray(s)).shape)
        print("action : ",torch.tensor(np.asarray(a)).shape)


    q_i = agents[i].critic.forward(torch.tensor(s,dtype=torch.float).to(device),torch.tensor(a,dtype=torch.float).to(device))

    # compute the gradient for the critic
    critic_loss = F.mse_loss(y_j.to(device), q_i.to(device))
    agents[i].critic.optimizer.zero_grad()
    critic_loss.backward()
    agents[i].critic.optimizer.step()

    # compute the gradient for the actor
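    # for reference, the actor objective in the MADDPG paper is
    #   J(theta_i) = E[ Q_i(x, a_1, ..., a_N) ] with a_i = mu_i(o_i),
    # i.e. the centralized critic is maximized with respect to agent i's own policy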
    action = []
    for k, agent in enumerate(agents):
        s_ = torch.tensor(np.asarray(s_by_agent[str(k)]),dtype=torch.float).to(device)
        action.append(agent.actor.forward(s_))

    a_ = np.asarray([t.detach().cpu().numpy() for t in action])
    c = agents[i].critic.forward(
        torch.tensor(s,dtype=torch.float).to(device),
        torch.reshape(torch.tensor(a_,dtype=torch.float).to(device),(s.shape[0],s.shape[1],ACTION_SHAPE))
        )
    actor_loss = -torch.mean(c)
    agents[i].actor.optimizer.zero_grad()
    actor_loss.backward(retain_graph=True)
    agents[i].actor.optimizer.step()

# soft update of each agent's target actor and target critic
for i in range(len(agents)):
    for target_param, source_param in zip(agents[i].target_actor.parameters(), agents[i].actor.parameters()):
        target_param.data.copy_(TAU * source_param.data + (1.0 - TAU) * target_param.data)
    for target_param, source_param in zip(agents[i].target_critic.parameters(), agents[i].critic.parameters()):
        target_param.data.copy_(TAU * source_param.data + (1.0 - TAU) * target_param.data)

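The last loop is the soft (Polyak) update of the target networks, i.e. $\theta_i' \leftarrow \tau\,\theta_i + (1-\tau)\,\theta_i'$ for each agent's target actor and target critic.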