Skip to content

Commit

Permalink
back
Browse files Browse the repository at this point in the history
  • Loading branch information
filippo-merlo committed Feb 6, 2024
1 parent 95c6e96 commit c8fcc8c
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 29 deletions.
2 changes: 1 addition & 1 deletion config.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@
# train parameters
resize = 224
lr = 1e-3
epochs = 1#5
epochs = 5

sim_batch = 128
gen_batch = 128
Expand Down
56 changes: 28 additions & 28 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,35 +35,35 @@ def my_train_clip_encoder(dt, memory, attr, lesson):
loss = 10
ct = 0
centroid_sim = torch.rand(1, latent_dim).to(device)
#while loss > 0.008:
##
ct += 1
#if ct > 5:
# break
for i in range(1):#200):
# Get Inputs: sim_batch, (sim_batch, 4, 128, 128)
base_name_sim, images_sim = dt.get_better_similar(attr, lesson)
images_sim = images_sim.to(device)

# run similar model
z_sim = model(clip_model, images_sim)
centroid_sim = centroid_sim.detach()
centroid_sim, loss_sim = get_sim_loss(torch.vstack((z_sim, centroid_sim)))

# Run Difference
base_name_dif, images_dif = dt.get_better_similar_not(attr, lesson)
images_dif = images_dif.to(device)

# run difference model
z_dif = model(clip_model, images_dif)
loss_dif = get_sim_not_loss(centroid_sim, z_dif)

# compute loss
loss = (loss_sim)**2 + (loss_dif-1)**2
optimizer.zero_grad()
loss.backward()
optimizer.step()
while loss > 0.008:
##
ct += 1
if ct > 5:
break
for i in range(200):
# Get Inputs: sim_batch, (sim_batch, 4, 128, 128)
base_name_sim, images_sim = dt.get_better_similar(attr, lesson)
images_sim = images_sim.to(device)

# run similar model
z_sim = model(clip_model, images_sim)
centroid_sim = centroid_sim.detach()
centroid_sim, loss_sim = get_sim_loss(torch.vstack((z_sim, centroid_sim)))

# Run Difference
base_name_dif, images_dif = dt.get_better_similar_not(attr, lesson)
images_dif = images_dif.to(device)

# run difference model
z_dif = model(clip_model, images_dif)
loss_dif = get_sim_not_loss(centroid_sim, z_dif)

# compute loss
loss = (loss_sim)**2 + (loss_dif-1)**2
optimizer.zero_grad()
loss.backward()
optimizer.step()
##
print('[', ct, ']', loss.detach().item(), loss_sim.detach().item(),
loss_dif.detach().item())

Expand Down

0 comments on commit c8fcc8c

Please sign in to comment.