
Commit

added dpo
joey00072 committed Dec 29, 2024
1 parent f182d34 commit d71f3a4
Showing 1 changed file with 103 additions and 54 deletions.
157 changes: 103 additions & 54 deletions ohara/dpo.py
@@ -3,68 +3,117 @@
import torch.nn.functional as F
import torch.optim as optim

from torch import Tensor

## TODO WR
class Network(nn.Module):
def __init__(self, dim, hdim):
super().__init__()
self.up = nn.Linear(dim, hdim)
self.down = nn.Linear(hdim, dim)
#######################################################################################################
# https://arxiv.org/pdf/2305.18290
# ... eqn 7
# dpo_loss = - log( sigmoid( beta * ( log(pi_win/ref_win) - log(pi_lose/ref_lose) ) ) )
#
# remember the log property: log(x/y) = log(x) - log(y)
# let's start with the inner term,
#   log(pi_win/ref_win) - log(pi_lose/ref_lose)
# = (log(pi_win) - log(ref_win)) - (log(pi_lose) - log(ref_lose))
# = (log(pi_win) - log(pi_lose)) - (log(ref_win) - log(ref_lose))   # just regrouping the four terms
#
# so now we have
#   logits = pi_logratios - ref_logratios
# where:
#   pi_logratios  = log(pi_win)  - log(pi_lose)
#   ref_logratios = log(ref_win) - log(ref_lose)
#
# and the loss is
#   dpo_loss = - log(sigmoid(beta * logits))

def forward(self, x):
x = self.up(x)
x = F.silu(x)
return self.down(x)

def dpo_loss(
pi_logps: Tensor,
ref_logps: Tensor,
win_output_idxs: Tensor,
lose_output_idxs: Tensor,
beta: float,
) -> tuple[Tensor, Tensor]:
"""
paper: https://arxiv.org/pdf/2305.18290.pdf
"""

# gather the log-prob assigned to each label token: (batch, seq, vocab) -> (batch, seq)
def log_prob(logits, labels):
return torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
# extract the log-probs of the preferred (win) and rejected (lose) completions
pi_win_logps, pi_lose_logps = pi_logps[win_output_idxs], pi_logps[lose_output_idxs]
ref_win_logps, ref_lose_logps = ref_logps[win_output_idxs], ref_logps[lose_output_idxs]

# log(a/b) = log(a) - log(b)

pi_logratios = pi_win_logps - pi_lose_logps
ref_logratios = ref_win_logps - ref_lose_logps

logits = pi_logratios - ref_logratios

# dpo_loss = -log(sigmoid(beta * logits))
losses: Tensor = -F.logsigmoid(beta * logits)
rewards: Tensor = beta * (pi_logps - ref_logps).detach()
return losses, rewards
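
# --- usage sketch (illustrative, not part of the original file) ---
# Assumes each entry of pi_logps / ref_logps is the summed log-prob of one whole completion,
# shape (B,), and that the win/lose idxs pick which batch entries are the preferred / rejected
# completions of each pair; names and values below are made up.
example_pi_logps = torch.tensor([-4.2, -5.0, -3.1, -3.9])   # policy log p(y|x), one scalar per sequence
example_ref_logps = torch.tensor([-4.0, -4.8, -3.5, -3.6])  # frozen reference log p(y|x)
example_win_idxs = torch.tensor([0, 2])                     # preferred completions
example_lose_idxs = torch.tensor([1, 3])                    # rejected completions
example_losses, example_rewards = dpo_loss(example_pi_logps, example_ref_logps, example_win_idxs, example_lose_idxs, beta=0.1)
print(example_losses.mean(), example_rewards)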


dim = 10
seq = 3
batch = 1
def cdpo_loss(
pi_logps: Tensor,
ref_logps: Tensor,
win_output_idxs: Tensor,
lose_output_idxs: Tensor,
beta: float,
label_smoothing: float = 0.2,
) -> tuple[Tensor, Tensor]:
"""
paper: https://ericmitchell.ai/cdpo.pdf
"""

policy = Network(dim, dim)
ref_model = Network(dim, dim)
ref_model.load_state_dict(policy.state_dict())
ref_model = ref_model.eval()
# extract the log-probs of the preferred (win) and rejected (lose) completions
pi_win_logps, pi_lose_logps = pi_logps[win_output_idxs], pi_logps[lose_output_idxs]
ref_win_logps, ref_lose_logps = ref_logps[win_output_idxs], ref_logps[lose_output_idxs]

# log(a/b) = log(a) - log(b)

pi_logratios = pi_win_logps - pi_lose_logps
ref_logratios = ref_win_logps - ref_lose_logps

logits = pi_logratios - ref_logratios

# cDPO loss: label-smoothed DPO (label_smoothing ~ probability the preference label is flipped)
losses: Tensor = (
-F.logsigmoid(beta * logits) * (1 - label_smoothing)
- F.logsigmoid(-beta * logits) * label_smoothing
)
rewards: Tensor = beta * (pi_logps - ref_logps).detach()
return losses, rewards
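
# sanity-check sketch (illustrative, made-up toy values): with label_smoothing=0 the cDPO loss
# should reduce to the plain DPO loss above, since the -logsigmoid(-beta * logits) term drops out.
check_pi = torch.tensor([-1.2, -2.3, -0.7, -1.1])
check_ref = torch.tensor([-1.0, -2.0, -0.9, -1.0])
check_win, check_lose = torch.tensor([0, 2]), torch.tensor([1, 3])
plain_losses, _ = dpo_loss(check_pi, check_ref, check_win, check_lose, beta=0.1)
smooth_losses, _ = cdpo_loss(check_pi, check_ref, check_win, check_lose, beta=0.1, label_smoothing=0.0)
assert torch.allclose(plain_losses, smooth_losses)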

optimizer = optim.AdamW(policy.parameters(), lr=1e-2)
labels = torch.tensor([1, 2, 3]).unsqueeze(0)
inputs = torch.rand((batch, seq, dim))

for _idx in range(100):
logits = policy(inputs)
pi = log_prob(logits, labels).sum()
with torch.no_grad():
logits = ref_model(inputs)
ref = log_prob(logits, labels).sum()
def ipo_loss(
pi_logps: Tensor,
ref_logps: Tensor,
win_output_idxs: Tensor,
lose_output_idxs: Tensor,
beta: float,
) -> tuple[Tensor, Tensor]:
"""
paper: https://arxiv.org/pdf/2310.12036v2.pdf
"""

# print(pi, ref, pi - ref)
loss = F.sigmoid(pi - ref)
print(loss)
loss.backward()
optimizer.step()
optimizer.zero_grad()
# extract the log-probs of the preferred (win) and rejected (lose) completions
pi_win_logps, pi_lose_logps = pi_logps[win_output_idxs], pi_logps[lose_output_idxs]
ref_win_logps, ref_lose_logps = ref_logps[win_output_idxs], ref_logps[lose_output_idxs]

# log(a/b) = log(a) - log(b)

pi_logratios = pi_win_logps - pi_lose_logps
ref_logratios = ref_win_logps - ref_lose_logps

logits = pi_logratios - ref_logratios

# IPO loss: squared error pulling the log-ratio margin toward 1/(2 * beta)
losses: Tensor = (logits - 1/(2 * beta)) ** 2
rewards: Tensor = beta * (pi_logps - ref_logps).detach()
return losses, rewards
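
# quick-check sketch (illustrative, made-up values): IPO is a squared error that pulls the
# log-ratio margin toward 1/(2*beta) instead of pushing it ever larger like DPO. With beta=0.5
# the target margin is 1.0, and the pair below hits it exactly, so the loss is zero.
ipo_pi = torch.tensor([-0.5, -1.5])    # [win, lose] policy log-probs
ipo_ref = torch.tensor([-1.0, -1.0])   # [win, lose] reference log-probs
ipo_losses, _ = ipo_loss(ipo_pi, ipo_ref, torch.tensor([0]), torch.tensor([1]), beta=0.5)
assert torch.allclose(ipo_losses, torch.zeros_like(ipo_losses))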


class MoE(nn.Module):
def __init__(self, input_dim, output_dim, num_experts, hidden_dim):
super().__init__()
self.num_experts = num_experts
self.experts = nn.ModuleList(
[
nn.Sequential(
nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, output_dim)
)
for _ in range(num_experts)
]
)
self.gate = nn.Linear(input_dim, num_experts)

def forward(self, x):
gate_outputs = F.softmax(self.gate(x), dim=-1)
expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)
output = torch.sum(gate_outputs.unsqueeze(-1) * expert_outputs, dim=1)
return output
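
# minimal usage sketch for the MoE block (made-up sizes, not from the original file):
# each input is softly routed over all experts and their outputs are mixed by the gate;
# the gating broadcast as written expects a 2-D (batch, input_dim) input.
moe_demo = MoE(input_dim=10, output_dim=10, num_experts=4, hidden_dim=32)
moe_x = torch.rand(2, 10)
print(moe_demo(moe_x).shape)  # -> torch.Size([2, 10])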
