Merge pull request #14 from KR-HappyFace/dev
Merge Dev
snoop2head authored Dec 26, 2021
2 parents e528c4e + f14ca5a commit 0a93f4e
Showing 13 changed files with 1,227 additions and 108 deletions.
144 changes: 143 additions & 1 deletion .gitignore
@@ -1 +1,143 @@
models

# dall-e generation outputs
outputs/
*.pt
taming/
wandb/
dalle-ds-cp/
*.out
/.github
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
*.ipynb
# C extensions
*.so
*.yaml
VQGAN_blue_e7
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# Visual Studio Code
.vscode

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
1 change: 1 addition & 0 deletions clip/README.md
@@ -0,0 +1 @@
### CLIP
79 changes: 79 additions & 0 deletions clip/clipmodel.py
@@ -0,0 +1,79 @@
import torch.nn as nn
from transformers import RobertaModel, RobertaConfig
import timm


class ImageEncoder(nn.Module):
    """Encodes an image into a single feature vector using a timm backbone."""

    def __init__(self, model_name, pretrained):
        super().__init__()
        # num_classes=0 removes the classification head; global average
        # pooling yields one fixed-size vector per image.
        self.model = timm.create_model(
            model_name, pretrained=pretrained, num_classes=0, global_pool="avg"
        )
        for p in self.model.parameters():
            p.requires_grad = True  # fine-tune the full backbone

    def forward(self, x):
        return self.model(x)


class TextEncoder(nn.Module):
    """Encodes text with KLUE RoBERTa and returns the [CLS] token embedding."""

    def __init__(self, pretrained):
        super().__init__()
        if pretrained:
            self.model = RobertaModel.from_pretrained("klue/roberta-base")
        else:
            # Same architecture, randomly initialized weights.
            config = RobertaConfig.from_pretrained("klue/roberta-base")
            self.model = RobertaModel(config)

        for p in self.model.parameters():
            p.requires_grad = True
        self.target_token_idx = 0  # position of the [CLS] token

    def forward(self, input_ids, token_type_ids, attention_mask):
        output = self.model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        last_hidden_state = output.last_hidden_state
        # Use the [CLS] hidden state as the sentence representation.
        return last_hidden_state[:, self.target_token_idx, :]


class ProjectionHead(nn.Module):
    """Linearly projects an encoder output into the shared embedding space."""

    def __init__(self, embedding_dim, projection_dim):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)

    def forward(self, x):
        return self.projection(x)


class CLIPModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.image_encoder = ImageEncoder("efficientnet_b0", pretrained=False)
        self.text_encoder = TextEncoder(pretrained=True)
        # Infer each encoder's output width from the shape of its last
        # parameter, then project both modalities into a shared 512-dim space.
        image_embedding_dim = list(self.image_encoder.parameters())[-1].shape[0]
        text_embedding_dim = list(self.text_encoder.parameters())[-1].shape[0]
        self.image_projection = ProjectionHead(
            embedding_dim=image_embedding_dim, projection_dim=512
        )
        self.text_projection = ProjectionHead(
            embedding_dim=text_embedding_dim, projection_dim=512
        )

    def forward(self, text, image):
        image_features = self.image_encoder(image)
        text_features = self.text_encoder(
            input_ids=text["input_ids"],
            attention_mask=text["attention_mask"],
            token_type_ids=text["token_type_ids"],
        )

        image_embeddings = self.image_projection(image_features)
        text_embeddings = self.text_projection(text_features)

        return text_embeddings, image_embeddings
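A minimal usage sketch for the model above, assuming the klue/roberta-base tokenizer (the same checkpoint hard-coded in TextEncoder); the caption string, batch size, and dummy image tensor are illustrative only:

import torch
from transformers import AutoTokenizer

# Assumed tokenizer: matches the "klue/roberta-base" checkpoint used by
# TextEncoder and returns input_ids, token_type_ids, and attention_mask.
tokenizer = AutoTokenizer.from_pretrained("klue/roberta-base")

model = CLIPModel()
text = tokenizer(
    ["파란 스트라이프 셔츠"],  # "a blue striped shirt" (illustrative caption)
    return_tensors="pt",
    padding=True,
)
image = torch.rand(1, 3, 224, 224)  # dummy RGB batch in [0, 1]

text_emb, image_emb = model(text, image)
print(text_emb.shape, image_emb.shape)  # torch.Size([1, 512]) for both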
41 changes: 41 additions & 0 deletions clip/dataloader.py
@@ -0,0 +1,41 @@
from torch.utils.data import Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
from tqdm import tqdm
import re


class CLIPDataset(Dataset):
    """Pairs raw caption strings with resized image tensors."""

    def __init__(self, texts, images):
        self.texts = texts
        self.images = images
        # Note: ToTensorV2 keeps the uint8 dtype, so pixel values stay in
        # [0, 255]; convert to float downstream before feeding the encoder.
        self.transform = A.Compose([A.Resize(224, 224), ToTensorV2()])

    def __getitem__(self, index):
        t = self.texts[index]
        # str() guards against Path objects, which cv2.imread rejects.
        single_im = cv2.imread(str(self.images[index]))
        single_im = cv2.cvtColor(single_im, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR
        im = self.transform(image=single_im)["image"]
        return t, im

    def __len__(self):
        return len(self.texts)


def get_dataset(text_path, image_path):
    """Collects caption texts and image file paths from two directory trees."""
    image_files = [
        *image_path.glob("**/*[0-9].png"),
        *image_path.glob("**/*[0-9].jpg"),
        *image_path.glob("**/*[0-9].jpeg"),
    ]
    text_files = [*text_path.glob("**/*[0-9].txt")]
    texts = []
    print("Extracting text information!")
    for i in tqdm(range(len(text_files))):
        with open(text_files[i], "r", encoding="utf-8") as f:
            te = f.read()
            # Strip the templated phrase "...스타일에서 스타일은 <word>."
            # ("In the ... style, the style is <word>.") from each caption.
            te = re.sub("스타일에서 스타일은 [가-힣]+.", "", te)
            # Drop the leftover Korean particle "에서" ("in/at").
            te = re.sub("에서", "", te)
            texts.append(te)
    return texts, image_files
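An end-to-end sketch tying the two files together; the data/texts and data/images directories are assumptions about the dataset layout, not paths from this commit:

from pathlib import Path
from torch.utils.data import DataLoader

# Hypothetical directory layout; adjust to the actual dataset location.
texts, image_files = get_dataset(Path("data/texts"), Path("data/images"))
dataset = CLIPDataset(texts, image_files)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

captions, images = next(iter(loader))  # list of strings, uint8 image batch
# ToTensorV2 leaves uint8 pixels, so scale to [0, 1] before the encoder.
images = images.float() / 255.0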