Addition of unsupervised CNN using forward forward algorithm #196

Open
wants to merge 24 commits into base: main

Commits (24), all by cyclotomicextension:
8c987c2  Update README.md (Jan 16, 2023)
05667f0  Update README.md (Jan 16, 2023)
be1cd24  Merge pull request #1 from cyclotomicextension/cyclotomicextension-pa… (Jan 16, 2023)
6f41a1b  Add files via upload (Jan 17, 2023)
97646bd  Update root_op.py (Feb 6, 2023)
5b87ee3  Update functions.py (Feb 6, 2023)
9834b2f  Update modules.py (Feb 6, 2023)
40173cc  Update build_models.py (Feb 6, 2023)
a17312a  Update trainers.py (Feb 6, 2023)
5ad6249  Update trainers.py (Feb 6, 2023)
337cc4b  Update trainers.py (Feb 11, 2023)
298faed  Delete ffa.py (Feb 11, 2023)
8d27ab1  Delete ffa2.py (Feb 11, 2023)
a239670  Update functions.py (Feb 11, 2023)
80d68fb  Update apps/accelerate/forward_forward/README.md (Feb 12, 2023)
d8293b5  Update trainers.py (Feb 12, 2023)
77a8dd4  Update trainers.py (Feb 12, 2023)
03d365d  Update trainers.py (Feb 12, 2023)
8291c75  Update trainers.py (Feb 12, 2023)
1b3d8df  Update trainers.py (Feb 18, 2023)
e739f3a  Update trainers.py (Feb 18, 2023)
e8e463c  Update trainers.py (Feb 18, 2023)
89766d4  Update apps/accelerate/forward_forward/forward_forward/operations/tra… (Feb 24, 2023)
385996d  Update benchmark.py (Feb 24, 2023)
7 changes: 6 additions & 1 deletion apps/accelerate/forward_forward/README.md
@@ -30,15 +30,20 @@ This process will just install the minimum requirements for running the module.
At the current stage, this implementation supports the main architectures discussed by Hinton in his paper. Each architecture can be trained with the following command:

```python
import os

import torch

from forward_forward import train_with_forward_forward_algorithm

device = "cuda" if torch.cuda.is_available() else "cpu"

trained_model = train_with_forward_forward_algorithm(
    model_type="progressive",
    n_layers=3,
    hidden_size=2000,
    lr=0.03,
    device="cuda",
    device=device,
    epochs=100,
    batch_size=5000,
    theta=2.,
@@ -27,6 +27,9 @@ def train_with_forward_forward_algorithm(
    elif model_type is ForwardForwardModelType.RECURRENT:
        input_size = 28 * 28
        output_size = len(datasets.MNIST.classes)
    elif model_type is ForwardForwardModelType.CNN:
        input_size = (28, 28, 1)
        output_size = len(datasets.MNIST.classes)
    else:  # model_type is ForwardForwardModelType.NLP
        input_size = 10  # number of characters
        output_size = 30  # length of vocabulary
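With this branch in place, the new architecture should be selectable through the same entry point shown in the README. A minimal sketch, assuming `model_type="cnn"` is accepted and the remaining keyword arguments mirror the progressive example (any CNN-specific arguments are not shown in this diff):

```python
import torch

from forward_forward import train_with_forward_forward_algorithm

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical call: apart from model_type and device, the argument set is
# assumed to match the progressive example in the README.
trained_model = train_with_forward_forward_algorithm(
    model_type="cnn",
    lr=0.03,
    device=device,
    epochs=100,
    batch_size=5000,
    theta=2.,
)
```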
@@ -8,6 +8,7 @@
    FCNetFFProgressive,
    RecurrentFCNetFF,
    LMFFNet,
    ConvFFLayer,
)


@@ -112,3 +113,30 @@ def execute(
            predicted_tokens=-1,
        )
        self.model = model

class CNNBuildOperation(BaseModelBuildOperation):
    def __init__(self):
        super().__init__()

    def execute(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        padding: int,
        optimizer_name: str,
        optimizer_params: dict,
        loss_fn_name: str,
    ):
        model = ConvFFLayer(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            optimizer_name=optimizer_name,
            optimizer_kwargs=optimizer_params,
            loss_fn_name=loss_fn_name,
        )
        self.model = model
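For reference, a usage sketch of the new build operation; the import path, hyperparameter values, and the `goodness_loss` name are illustrative assumptions, not values taken from this PR:

```python
# Assumed module path for the build operation added in this PR.
from forward_forward.operations.build_models import CNNBuildOperation

# Build a single convolutional forward-forward layer for 28x28 MNIST images.
build_op = CNNBuildOperation()
build_op.execute(
    in_channels=1,            # grayscale MNIST input
    out_channels=32,          # assumed channel count, not from the diff
    kernel_size=3,
    stride=1,
    padding=1,
    optimizer_name="Adam",
    optimizer_params={"lr": 0.03},
    loss_fn_name="goodness_loss",  # hypothetical loss identifier
)
conv_ff_layer = build_op.model
```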
@@ -1,10 +1,14 @@
from abc import ABC, abstractmethod

import numpy as np
import torch
import torchvision
import torch.optim as optim
import torchvision.transforms as transforms
from sklearn.cluster import KMeans
from torch.utils.data import DataLoader
from torchvision import datasets

from nebullvm.operations.base import Operation
from nebullvm.operations.fetch_operations.local import FetchModelFromLocal

from forward_forward.operations.data import VOCABULARY
from forward_forward.operations.fetch_operations import (
@@ -183,3 +187,69 @@ def _train(
        predictions, _ = model.positive_eval(test_data, theta)
        perplexity = compute_perplexity(predictions)
        self.logger.info(f"Perplexity: {perplexity}")


class UnsupervisedCNNForwardForwardTrainer(BaseForwardForwardTrainer):
    def _train(self, epochs: int, theta: float, device: str, **kwargs):
        model = self.model.to(device)

        unsupervised_loader, kmeans = self.get_unsupervised_label(device)

        for epoch in range(epochs):
            accumulated_goodness = None
            model.train()

            for j, (data, target) in enumerate(unsupervised_loader):
                data = data.to(device)
                target = target.to(device)

                _, goodness = model.ff_train(data, target, theta)
                if accumulated_goodness is None:
                    accumulated_goodness = goodness
                else:
                    accumulated_goodness[0] += goodness[0]
                    accumulated_goodness[1] += goodness[1]
            goodness_ratio = (
                accumulated_goodness[0] - accumulated_goodness[1]
            ) / abs(max(accumulated_goodness))
            self.logger.info(f"Epoch {epoch + 1}")
            self.logger.info(f"Accumulated goodness: {accumulated_goodness}")
            self.logger.info(f"Goodness ratio: {goodness_ratio}")
            model.eval()

            correct = 0
            with torch.no_grad():
                for data, target in self.test_data:
                    data = data.to(device)
                    # KMeans expects a 2D array: one flattened image per row.
                    numpy_data = data.flatten(start_dim=1).cpu().numpy()
                    target = torch.from_numpy(kmeans.predict(numpy_data)).to(device)
                    pred, _ = model.positive_eval(data.unsqueeze(1), theta)
                    correct += pred.eq(target.view_as(pred)).sum().item()
            self.logger.info(
                f"Test accuracy: {correct} / 10000 ({correct / 10000 * 100}%)"
            )

    def get_unsupervised_label(self, device):
        x_train = np.concatenate(
            [data.detach().cpu().numpy() for data, label in self.train_data],
            axis=0,
        )
        # KMeans needs a 2D input: flatten each training image into a row.
        x_train = x_train.reshape(x_train.shape[0], -1)

        kmeans = KMeans(n_clusters=10, random_state=0)
        train_labels = kmeans.fit_predict(x_train)
        train_labels = torch.from_numpy(train_labels).to(device)
        label_injector = LabelsInjector([f"cluster_{i}" for i in range(10)])

        train_bs = self.train_data.batch_size
        progressive_train_dataset = ProgressiveTrainingDataset(
            (
                label_injector.inject_train(
                    x.unsqueeze(1),
                    train_labels[i * train_bs: (i + 1) * train_bs],
                )
                for i, (x, y) in enumerate(self.train_data)
            )
        )
        progressive_train_dataloader = torch.utils.data.DataLoader(
            progressive_train_dataset,
            batch_size=2 * train_bs,
            shuffle=False,
        )
        return progressive_train_dataloader, kmeans
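The clustering step in `get_unsupervised_label` can be reproduced in isolation. A minimal sketch, assuming torchvision MNIST and scikit-learn are available (the dataset path is a placeholder):

```python
import numpy as np
import torch
from sklearn.cluster import KMeans
from torchvision import datasets, transforms

# Load MNIST and flatten each 28x28 image into a 784-dimensional row.
mnist = datasets.MNIST("data", train=True, download=True, transform=transforms.ToTensor())
x = np.stack([img.numpy().reshape(-1) for img, _ in mnist])

# Cluster into 10 groups; the cluster ids act as pseudo-labels for FF training.
kmeans = KMeans(n_clusters=10, random_state=0)
pseudo_labels = torch.from_numpy(kmeans.fit_predict(x))
print(pseudo_labels[:10])
```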
7 changes: 7 additions & 0 deletions apps/accelerate/forward_forward/forward_forward/root_op.py
@@ -6,6 +6,7 @@
    FCNetFFProgressiveBuildOperation,
    RecurrentFCNetFFBuildOperation,
    LMFFNetBuildOperation,
    CNNBuildOperation,
)
from forward_forward.operations.data import (
    MNISTDataLoaderOperation,
@@ -15,13 +16,15 @@
    ForwardForwardTrainer,
    RecurrentForwardForwardTrainer,
    NLPForwardForwardTrainer,
    UnsupervisedCNNForwardForwardTrainer,
)


class ForwardForwardModelType(Enum):
    PROGRESSIVE = "progressive"
    RECURRENT = "recurrent"
    NLP = "nlp"
    CNN = "cnn"


class ForwardForwardRootOp(Operation):
@@ -40,6 +43,10 @@ def __init__(self, model_type: ForwardForwardModelType):
            self.build_model = LMFFNetBuildOperation()
            self.train_model = NLPForwardForwardTrainer()
            self.load_data = AesopFablesDataLoaderOperation()
        elif model_type is ForwardForwardModelType.CNN:
            self.build_model = CNNBuildOperation()
            self.train_model = UnsupervisedCNNForwardForwardTrainer()
            self.load_data = MNISTDataLoaderOperation()

    def execute(
        self,
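For clarity, the wiring this branch produces when the CNN type is requested; a sketch only, since the full `execute` signature is not shown in the diff and the module path is assumed:

```python
from forward_forward.root_op import ForwardForwardModelType, ForwardForwardRootOp

root_op = ForwardForwardRootOp(ForwardForwardModelType.CNN)
# root_op.build_model -> CNNBuildOperation
# root_op.train_model -> UnsupervisedCNNForwardForwardTrainer
# root_op.load_data   -> MNISTDataLoaderOperation
```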
49 changes: 49 additions & 0 deletions apps/accelerate/forward_forward/forward_forward/utils/modules.py
@@ -829,3 +829,52 @@ def positive_eval(self, input_tensor: torch.Tensor, theta: float):
        )
        cumulated_goodness /= self.predicted_tokens
        return prediction, cumulated_goodness


class ConvBNReLU(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding
        )
        self.bn = torch.nn.BatchNorm2d(out_channels)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


class ConvFFLayer(BaseFFLayer):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        optimizer_name,
        optimizer_kwargs,
        loss_fn_name,
    ):
        super().__init__()
        self.layer = ConvBNReLU(
            in_channels, out_channels, kernel_size, stride, padding
        )
        self.optimizer = getattr(torch.optim, optimizer_name)(
            self.layer.parameters(), **optimizer_kwargs
        )
        self.loss_fn = eval(loss_fn_name)

    def forward(self, x):
        return self.layer(x)

    def ff_train(self, x, signs, theta):
        # Detach the input so gradients stay local to this layer.
        new_x = self(x.detach())
        y_pos = new_x[signs == 1]
        y_neg = new_x[signs == -1]
        loss_pos, goodness_pos = self.loss_fn(y_pos, theta, 1)
        loss_neg, goodness_neg = self.loss_fn(y_neg, theta, -1)
        loss = loss_pos + loss_neg
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return new_x, [goodness_pos, goodness_neg]

    @torch.no_grad()
    def positive_eval(self, x, theta):
        new_x = self(x)
        # Goodness: mean squared activation, shifted by the threshold theta.
        goodness = new_x.pow(2).mean(dim=1) - theta
        return new_x, goodness

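`ff_train` assumes `loss_fn` returns a `(loss, goodness)` pair given activations, the threshold `theta`, and a sign. A minimal sketch of a compatible function, shown only to illustrate the expected interface (it is not the loss actually referenced by `loss_fn_name` in this repo):

```python
import torch

def goodness_loss_sketch(y: torch.Tensor, theta: float, sign: int):
    # Goodness: mean squared activation per sample.
    goodness = y.pow(2).flatten(start_dim=1).mean(dim=1)
    # Positive samples (sign=1) should push goodness above theta,
    # negative samples (sign=-1) below it.
    loss = torch.nn.functional.softplus(-sign * (goodness - theta)).mean()
    return loss, goodness.sum().item()
```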
46 changes: 45 additions & 1 deletion nebullvm/tools/benchmark.py
@@ -1,5 +1,6 @@
import time
from abc import abstractmethod, ABC
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from typing import Any, Dict, Type

import numpy as np
@@ -19,7 +20,6 @@
    is_data_subscriptable,
)


def _get_dl_framework(model: Any):
    if isinstance(model, torch.nn.Module) or str(model).startswith("Pytorch"):
        return DeepLearningFramework.PYTORCH
@@ -53,6 +53,50 @@ def _create_model_inputs(

    return input_data

class HuggingFaceBenchmark(BaseBenchmark):
    def __init__(self, model_name, input_texts, device, n_warmup=50, n_runs=1000):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name
        ).to(device)
        self.input_texts = input_texts
        self.device = device
        self.n_warmup = n_warmup
        self.n_runs = n_runs

    def benchmark(self):
        input_tensors = []
        for input_text in self.input_texts:
            encoded = self.tokenizer(
                input_text,
                padding="max_length",
                truncation=True,
                max_length=512,
                return_tensors="pt",
            )
            input_tensors.append(encoded.to(self.device))

        # BatchEncoding has no .shape; read the batch size from input_ids.
        batch_size = input_tensors[0]["input_ids"].shape[0]

        with torch.no_grad():
            for i in tqdm(
                range(self.n_warmup),
                desc=f"Performing warm up on {self.n_warmup} iterations",
            ):
                self.model(
                    **input_tensors[i % min(self.n_warmup, len(input_tensors))]
                )

        if self.device.type is DeviceType.GPU:
            torch.cuda.synchronize()
        timings = []
        with torch.no_grad():
            for i in tqdm(
                range(1, self.n_runs + 1),
                desc=f"Performing benchmark on {self.n_runs} iterations",
            ):
                start_time = time.time()
                self.model(
                    **input_tensors[i % min(self.n_runs, len(input_tensors))]
                )
                if self.device.type is DeviceType.GPU:
                    torch.cuda.synchronize()
                end_time = time.time()
                timings.append(end_time - start_time)

        throughput = batch_size / np.mean(timings)
        latency = np.mean(timings) / batch_size

        return throughput, latency


class BaseBenchmark(ABC):
    def __init__(self, model, input_tensors, n_warmup=50, n_runs=1000):
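A usage sketch for the new `HuggingFaceBenchmark` class; the model name and input sentences are placeholders, and it assumes the `device` argument is something accepted both by `.to()` and by the `DeviceType` check inside `benchmark()`:

```python
import torch

benchmark = HuggingFaceBenchmark(
    model_name="distilbert-base-uncased-finetuned-sst-2-english",
    input_texts=["nebullvm makes benchmarking easy.", "forward-forward on MNIST."],
    device=torch.device("cpu"),
    n_warmup=5,
    n_runs=20,
)
throughput, latency = benchmark.benchmark()
print(f"throughput: {throughput:.2f} samples/s, latency: {latency * 1000:.3f} ms")
```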