clip_model.py
import clip
import torch
import torch.nn as nn


class CLIPEncoder(nn.Module):
    """Frozen CLIP ViT-B/16 visual backbone that exposes per-patch grid features."""

    def __init__(self, device):
        super().__init__()
        self.encoder, _ = clip.load("ViT-B/16", device=device)  # clip.load returns the model already in eval mode

    def forward(self, x):
        """
        Expects a tensor of shape (batch_size, 3, 224, 224).
        Returns grid features of shape (batch_size, 196, 768).
        """
        with torch.no_grad():
            x = x.type(self.encoder.visual.conv1.weight.dtype)
            x = self.encoder.visual.conv1(x)  # shape = [*, width, grid, grid]
            x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
            x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
            # Prepend the learned class token to every sequence in the batch.
            class_token = self.encoder.visual.class_embedding.to(x.dtype) + torch.zeros(
                x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
            )
            x = torch.cat([class_token, x], dim=1)  # shape = [*, grid ** 2 + 1, width]
            x = x + self.encoder.visual.positional_embedding.to(x.dtype)
            x = self.encoder.visual.ln_pre(x)
            x = x.permute(1, 0, 2)  # NLD -> LND
            x = self.encoder.visual.transformer(x)
            grid_feats = x.permute(1, 0, 2)  # LND -> NLD, shape = (N, 197, 768)
            grid_feats = self.encoder.visual.ln_post(grid_feats[:, 1:])  # drop the class token, keep the 196 patch tokens
        return grid_feats.float()
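

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original file: feed a dummy
    # batch through the encoder to check output shapes. Assumes the OpenAI
    # `clip` package is installed; falls back to CPU when CUDA is unavailable.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    encoder = CLIPEncoder(device)
    images = torch.randn(2, 3, 224, 224, device=device)  # stand-in for CLIP-preprocessed images
    feats = encoder(images)
    print(feats.shape)  # expected: torch.Size([2, 196, 768])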