
How does the gradient back-propagate from Q to the action $a_i$? #26

xihuai18 commented Aug 10, 2020

I wonder how the gradient back-propagates from Q to the action $a_i$.
Trace from Q:

MAAC/utils/critics.py

Lines 149 to 150 in 105d60e

```python
all_q = self.critics[a_i](critic_in)
int_acs = actions[a_i].max(dim=1, keepdim=True)[1]
```
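
For reference, a minimal sketch of the usual pattern here, assuming discrete actions and that the Q value of the taken action is then picked out by gather (shapes and names below are made up, not the repo's exact code):

```python
import torch

# Minimal sketch (not the repo's exact code): with discrete actions, the
# critic head outputs one Q value per possible action of agent i, and the
# taken action only selects a column via argmax/gather.
batch, n_actions = 4, 5
all_q = torch.randn(batch, n_actions, requires_grad=True)                  # stand-in for self.critics[a_i](critic_in)
actions_onehot = torch.eye(n_actions)[torch.randint(n_actions, (batch,))]  # one-hot actions of agent i

int_acs = actions_onehot.max(dim=1, keepdim=True)[1]  # argmax index, as in the snippet above
q_taken = all_q.gather(1, int_acs)                    # Q value of the sampled action

# argmax/gather indexing is not differentiable with respect to the one-hot
# action itself, which is exactly the path the question is about.
print(q_taken.shape)  # torch.Size([4, 1])
```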

Then trace critic_in:
critic_in = torch.cat((s_encodings[i], *other_all_values[i]), dim=1)

Since s_encoding doesn't contain input from $a_i$, I then trace other_all_values[i]:

MAAC/utils/critics.py

Lines 125 to 141 in 105d60e

```python
for curr_head_keys, curr_head_values, curr_head_selectors in zip(
        all_head_keys, all_head_values, all_head_selectors):
    # iterate over agents
    for i, a_i, selector in zip(range(len(agents)), agents, curr_head_selectors):
        keys = [k for j, k in enumerate(curr_head_keys) if j != a_i]
        values = [v for j, v in enumerate(curr_head_values) if j != a_i]
        # calculate attention across agents
        attend_logits = torch.matmul(selector.view(selector.shape[0], 1, -1),
                                     torch.stack(keys).permute(1, 2, 0))
        # scale dot-products by size of key (from Attention is All You Need)
        scaled_attend_logits = attend_logits / np.sqrt(keys[0].shape[1])
        attend_weights = F.softmax(scaled_attend_logits, dim=2)
        other_values = (torch.stack(values).permute(1, 2, 0) *
                        attend_weights).sum(dim=2)
        other_all_values[i].append(other_values)
        all_attend_logits[i].append(attend_logits)
        all_attend_probs[i].append(attend_weights)
```
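
For readability, the loop above is just scaled dot-product attention computed per head and per agent; a self-contained sketch with explicit shapes (sizes and names below are made up):

```python
import numpy as np
import torch
import torch.nn.functional as F

def attend_to_others(selector, other_keys, other_values):
    """Sketch of the per-head, per-agent attention in the loop above.

    selector:      (batch, attend_dim)         -- query from agent i's s_encoding
    other_keys:    list of (batch, attend_dim) -- one key per other agent
    other_values:  list of (batch, attend_dim) -- one value per other agent
    """
    keys = torch.stack(other_keys).permute(1, 2, 0)      # (batch, attend_dim, n_others)
    values = torch.stack(other_values).permute(1, 2, 0)  # (batch, attend_dim, n_others)

    # scaled dot-product attention (Attention Is All You Need)
    logits = torch.matmul(selector.view(selector.shape[0], 1, -1), keys)  # (batch, 1, n_others)
    logits = logits / np.sqrt(other_keys[0].shape[1])
    weights = F.softmax(logits, dim=2)                    # attention weights over the other agents

    # weighted sum of the other agents' values (x_i in the MAAC paper)
    return (values * weights).sum(dim=2)                  # (batch, attend_dim)

# toy usage with made-up sizes
batch, dim, n_others = 4, 8, 2
sel = torch.randn(batch, dim)
out = attend_to_others(sel,
                       [torch.randn(batch, dim) for _ in range(n_others)],
                       [torch.randn(batch, dim) for _ in range(n_others)])
print(out.shape)  # torch.Size([4, 8])
```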

keys and values don't contain agent i's action as input, and the selector uses only observations as input:

MAAC/utils/critics.py

Lines 118 to 119 in 105d60e

```python
all_head_selectors = [[sel_ext(enc) for i, enc in enumerate(s_encodings) if i in agents]
                      for sel_ext in self.selector_extractors]
```
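
And, as far as I can tell from critics.py, the two kinds of encodings feed the attention roughly like this (a compressed sketch with placeholder layers, not the exact modules in the repo):

```python
import torch
import torch.nn as nn

# Placeholder sketch: sa_encodings come from (obs, action) pairs, s_encodings
# from observations only. Layer names and sizes are made up.
obs_dim, act_dim, hidden = 10, 5, 32
sa_encoder = nn.Linear(obs_dim + act_dim, hidden)  # stand-in for critic_encoders[i]
s_encoder = nn.Linear(obs_dim, hidden)             # stand-in for state_encoders[i]
key_ext = nn.Linear(hidden, hidden, bias=False)    # stand-ins for the per-head extractors
val_ext = nn.Linear(hidden, hidden)
sel_ext = nn.Linear(hidden, hidden, bias=False)

obs, act = torch.randn(4, obs_dim), torch.randn(4, act_dim)
sa_enc = sa_encoder(torch.cat((obs, act), dim=1))  # carries the action -> keys and values
s_enc = s_encoder(obs)                             # observation only   -> selector (query)

key, value, selector = key_ext(sa_enc), val_ext(sa_enc), sel_ext(s_enc)
# Note: in the loop quoted above, agent i's own key/value is then excluded
# from its attention (the `if j != a_i` filters).
```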

So, is there a gradient from Q to the action $a_i$?

@DokinCui

keys and values contain agent i's action, since their input is "sa_encoding", but the selector uses only observations as input; I can't understand this.
I also can't understand the purpose of "s_encoding", because only "sa_encoding" is used in the paper, not "s_encoding".

zhl606 commented Aug 22, 2023

> keys and values contain agent i's action, since their input is "sa_encoding", but the selector uses only observations as input; I can't understand this. I also can't understand the purpose of "s_encoding", because only "sa_encoding" is used in the paper, not "s_encoding".

I have the same question; have you figured it out? What I also want to know: in the PPO algorithm, when estimating the advantage function, do we only need state information as input and not action information, so that we could use s_encoding without sa_encoding?
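
For what it's worth, this is the kind of state-only value head plus GAE I have in mind (a hypothetical sketch, not code from this repo or the MAAC paper):

```python
import torch
import torch.nn as nn

# Hypothetical sketch: for PPO, the advantage is usually estimated with a
# state-only value function V(s) plus GAE, so an observation-only encoding
# (like s_encodings) would suffice for that head; a state-action input
# (like sa_encodings) is what a Q-function needs instead.
class ValueHead(nn.Module):
    def __init__(self, obs_dim, hidden=64):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(obs_dim, hidden), nn.ReLU(),
                                 nn.Linear(hidden, 1))

    def forward(self, obs):  # no actions in the input
        return self.net(obs).squeeze(-1)

def gae(rewards, values, next_value, gamma=0.99, lam=0.95):
    """Generalized advantage estimation over a single rollout (1-D tensors,
    no terminal handling)."""
    advantages, gae_t = torch.zeros_like(rewards), 0.0
    values = torch.cat([values, next_value.view(1)])
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        gae_t = delta + gamma * lam * gae_t
        advantages[t] = gae_t
    return advantages
```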
