diff --git a/ohara/dpo.py b/ohara/dpo.py
index 26ece66..2f1741d 100644
--- a/ohara/dpo.py
+++ b/ohara/dpo.py
@@ -3,68 +3,117 @@
 import torch.nn.functional as F
 import torch.optim as optim
+from torch import Tensor
 
-## TODO WR
-class Network(nn.Module):
-    def __init__(self, dim, hdim):
-        super().__init__()
-        self.up = nn.Linear(dim, dim)
-        self.down = nn.Linear(dim, dim)
+#######################################################################################################
+# https://arxiv.org/pdf/2305.18290
+# eqn. 7 from the paper:
+# dpo_loss = -log(sigmoid(beta * (log(pi_win/ref_win) - log(pi_lose/ref_lose))))
+#
+# remember the log property: log(x/y) = log(x) - log(y)
+# let's start with the inner term:
+#   log(pi_win/ref_win) - log(pi_lose/ref_lose)
+# = log( (pi_win/ref_win) / (pi_lose/ref_lose) )
+# = log( (pi_win/ref_win) * (ref_lose/pi_lose) )
+# = log( (pi_win/pi_lose) * (ref_lose/ref_win) )
+# = log( (pi_win/pi_lose) / (ref_win/ref_lose) )
+# = log(pi_win/pi_lose) - log(ref_win/ref_lose)
+# = (log(pi_win) - log(pi_lose)) - (log(ref_win) - log(ref_lose))
+#
+# so now we have
+#   logits = pi_logratios - ref_logratios
+# where:
+#   pi_logratios  = log(pi_win)  - log(pi_lose)
+#   ref_logratios = log(ref_win) - log(ref_lose)
+#
+# and the loss is
+#   dpo_loss = -log(sigmoid(beta * logits))
 
-    def forward(self, x):
-        x = self.up(x)
-        x = F.silu(x)
-        return self.down(x)
+def dpo_loss(
+    pi_logps: Tensor,
+    ref_logps: Tensor,
+    win_output_idxs: Tensor,
+    lose_output_idxs: Tensor,
+    beta: float,
+) -> tuple[Tensor, Tensor]:
+    """
+    paper: https://arxiv.org/pdf/2305.18290.pdf
+    """
 
-def log_prob(logits, labels):
-    return torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
+    # pick out the log probabilities of the chosen (win) and rejected (lose) outputs
+    pi_win_logps, pi_lose_logps = pi_logps[win_output_idxs], pi_logps[lose_output_idxs]
+    ref_win_logps, ref_lose_logps = ref_logps[win_output_idxs], ref_logps[lose_output_idxs]
+
+    # log(a/b) = log(a) - log(b)
+    pi_logratios = pi_win_logps - pi_lose_logps
+    ref_logratios = ref_win_logps - ref_lose_logps
+
+    logits = pi_logratios - ref_logratios
+
+    # dpo_loss = -log(sigmoid(beta * logits))
+    losses: Tensor = -F.logsigmoid(beta * logits)
+    rewards: Tensor = beta * (pi_logps - ref_logps).detach()
+    return losses, rewards
 
-dim = 10
-seq = 3
-batch = 1
+def cdpo_loss(
+    pi_logps: Tensor,
+    ref_logps: Tensor,
+    win_output_idxs: Tensor,
+    lose_output_idxs: Tensor,
+    beta: float,
+    label_smoothing: float = 0.2,
+) -> tuple[Tensor, Tensor]:
+    """
+    paper: https://ericmitchell.ai/cdpo.pdf
+    """
 
-policy = Network(dim, dim)
-ref_model = Network(dim, dim)
-ref_model.load_state_dict(policy.state_dict())
-ref_model = ref_model.eval()
+    # pick out the log probabilities of the chosen (win) and rejected (lose) outputs
+    pi_win_logps, pi_lose_logps = pi_logps[win_output_idxs], pi_logps[lose_output_idxs]
+    ref_win_logps, ref_lose_logps = ref_logps[win_output_idxs], ref_logps[lose_output_idxs]
+
+    # log(a/b) = log(a) - log(b)
+    pi_logratios = pi_win_logps - pi_lose_logps
+    ref_logratios = ref_win_logps - ref_lose_logps
+
+    logits = pi_logratios - ref_logratios
+
+    # conservative DPO: the DPO loss with smoothed preference labels
+    # cdpo_loss = -(1 - eps) * log(sigmoid(beta * logits)) - eps * log(sigmoid(-beta * logits))
+    losses: Tensor = (
+        -F.logsigmoid(beta * logits) * (1 - label_smoothing)
+        - F.logsigmoid(-beta * logits) * label_smoothing
+    )
+    rewards: Tensor = beta * (pi_logps - ref_logps).detach()
+    return losses, rewards
 
-optimizer = optim.AdamW(policy.parameters(), lr=1e-2)
-labels = torch.tensor([1, 2, 3]).unsqueeze(0)
-inputs = torch.rand((batch, seq, dim))
-for _idx in range(100):
-    logits = policy(inputs)
-    pi = log_prob(logits, labels).sum()
-    with torch.no_grad():
-        logits = ref_model(inputs)
-        ref = log_prob(logits, labels).sum()
+def ipo_loss(
+    pi_logps: Tensor,
+    ref_logps: Tensor,
+    win_output_idxs: Tensor,
+    lose_output_idxs: Tensor,
+    beta: float,
+) -> tuple[Tensor, Tensor]:
+    """
+    paper: https://arxiv.org/pdf/2310.12036v2.pdf
+    """
 
-    # print(pi, ref, pi - ref)
-    loss = F.sigmoid(pi - ref)
-    print(loss)
-    loss.backward()
-    optimizer.step()
-    optimizer.zero_grad()
+    # pick out the log probabilities of the chosen (win) and rejected (lose) outputs
+    pi_win_logps, pi_lose_logps = pi_logps[win_output_idxs], pi_logps[lose_output_idxs]
+    ref_win_logps, ref_lose_logps = ref_logps[win_output_idxs], ref_logps[lose_output_idxs]
+
+    # log(a/b) = log(a) - log(b)
+    pi_logratios = pi_win_logps - pi_lose_logps
+    ref_logratios = ref_win_logps - ref_lose_logps
+
+    logits = pi_logratios - ref_logratios
+
+    # ipo_loss = (logits - 1 / (2 * beta)) ** 2
+    losses: Tensor = (logits - 1 / (2 * beta)) ** 2
+    rewards: Tensor = beta * (pi_logps - ref_logps).detach()
+    return losses, rewards
-
-class MoE(nn.Module):
-    def __init__(self, input_dim, output_dim, num_experts, hidden_dim):
-        super(MoE, self).__init__()
-        self.num_experts = num_experts
-        self.experts = nn.ModuleList(
-            [
-                nn.Sequential(
-                    nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, output_dim)
-                )
-                for _ in range(num_experts)
-            ]
-        )
-        self.gate = nn.Linear(input_dim, num_experts)
-
-    def forward(self, x):
-        gate_outputs = F.softmax(self.gate(x), dim=-1)
-        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)
-        output = torch.sum(gate_outputs.unsqueeze(-1) * expert_outputs, dim=1)
-        return output
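
Usage sketch (illustrative only, not part of the diff): a minimal example of how the new dpo_loss might be called, assuming the module is importable as ohara.dpo after this change. The toy tensors, the (win, lose) index pairing, and beta=0.1 are assumptions for illustration; in practice pi_logps and ref_logps would be the per-completion log probabilities (summed over response tokens) from the policy and the frozen reference model.

    import torch
    from ohara.dpo import dpo_loss  # module path introduced by this diff

    # six candidate completions, paired as (win, lose): (0, 1), (2, 3), (4, 5) -- toy data
    pi_logps = torch.randn(6, requires_grad=True)  # policy log p(y|x), one entry per completion
    ref_logps = torch.randn(6)                     # frozen reference log p(y|x) per completion
    win_output_idxs = torch.tensor([0, 2, 4])      # indices of the preferred completions
    lose_output_idxs = torch.tensor([1, 3, 5])     # indices of the rejected completions

    losses, rewards = dpo_loss(pi_logps, ref_logps, win_output_idxs, lose_output_idxs, beta=0.1)
    losses.mean().backward()                       # scalar loss that would drive the policy update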