Move finetuning out of main.py
yan-gao-GY committed Nov 30, 2023
1 parent 1031c34 commit 206913a
Showing 13 changed files with 264 additions and 37 deletions.
32 changes: 27 additions & 5 deletions baselines/fedvssl/README.md
@@ -90,28 +90,50 @@ cd ..
```


Finally, we can launch the training:
Finally, we can launch the training.

### Federated SSL pre-training

To run using FedVSSL:
```bash
# run federated SSL training with FedVSSL
python -m fedvssl.main pre_training=true # this will run using the default settings.

# you can override settings directly from the command line
python -m fedvssl.main pre_training=true mix_coeff=1 rounds=100 # will set hyper-parameter alpha to 1 and the number of rounds to 100

# run downstream fine-tuning with pre-trained SSL model
python -m fedvssl.main pre_training=false pretrained_model_path=<CHECKPOINT>.npz # this will run using the default settings.
```

To run using FedAvg:
```bash
# this will run the FedAvg baseline
# this is done to match the experimental setup in the paper
python -m fedvssl.main fedavg=true
python -m fedvssl.main pre_training=true fedavg=true

# this config can also be overridden (see the example below).
```
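
As with FedVSSL, the FedAvg run accepts the same command-line overrides; a minimal sketch (the `rounds` value here is only illustrative):

```bash
# override the number of rounds for the FedAvg baseline (illustrative value)
python -m fedvssl.main pre_training=true fedavg=true rounds=20
```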

### Downstream fine-tuning
To run downstream fine-tuning with the pre-trained SSL model,
we first need to convert the model format:

```bash
python -m fedvssl.finetune_preprocess --pretrained_model_path=<CHECKPOINT>.npz
```
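
The command above uses the default fine-tuning config; `finetune_preprocess.py` (added below) also exposes a `--cfg_path` flag, so the equivalent explicit call is:

```bash
# same as above, with the default --cfg_path spelled out
python -m fedvssl.finetune_preprocess \
    --pretrained_model_path=<CHECKPOINT>.npz \
    --cfg_path=fedvssl/conf/mmcv_conf/finetuning/r3d_18_ucf101/finetune_ucf101.py
```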

Then, launch the fine-tuning using the CtP script:

```bash
bash fedvssl/CtP/tools/dist_train.sh fedvssl/conf/mmcv_conf/finetuning/r3d_18_ucf101/finetune_ucf101.py 1 --work_dir=./finetune_results --data_dir=fedvssl/data
```

Note that the first argument of this script is the path to the config file, while the second is the number of GPUs to use for fine-tuning.
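
For example, a sketch of the same run on a machine with 4 GPUs (assuming 4 are available; the work directory is unchanged):

```bash
bash fedvssl/CtP/tools/dist_train.sh fedvssl/conf/mmcv_conf/finetuning/r3d_18_ucf101/finetune_ucf101.py 4 --work_dir=./finetune_results --data_dir=fedvssl/data
```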

After that, we run the test process:

```bash
bash fedvssl/CtP/tools/dist_test.sh fedvssl/conf/mmcv_conf/finetuning/r3d_18_ucf101/test_ucf101.py 1 --work_dir=./finetune_results --data_dir=fedvssl/data --progress
```

## Expected results

### Pre-training and fine-tuning on UCF-101
2 changes: 1 addition & 1 deletion baselines/fedvssl/fedvssl/conf/base.yaml
@@ -9,7 +9,7 @@ pre_training: false
exp_name: fedvssl_results
data_dir: fedvssl/data
partition_dir: annotations/client_distribution
cfg_path_pretrain: fedvssl/conf/mmcv_conf/r3d_18_ucf101/pretraining_for_ucf.py
cfg_path_pretrain: fedvssl/conf/mmcv_conf/pretraining/r3d_18_ucf101/pretraining_for_ucf.py

# FL settings
pool_size: 5
33 changes: 33 additions & 0 deletions baselines/fedvssl/fedvssl/conf/mmcv_conf/finetuning/model_r3d18.py
@@ -0,0 +1,33 @@
model = dict(
type='TSN',
backbone=dict(
type='R3D',
depth=18,
num_stages=4,
stem=dict(
temporal_kernel_size=3,
temporal_stride=1,
in_channels=3,
with_pool=False,
),
down_sampling=[False, True, True, True],
channel_multiplier=1.0,
bottleneck_multiplier=1.0,
with_bn=True,
zero_init_residual=False,
pretrained=None,
),
st_module=dict(
spatial_type='avg',
temporal_size=2, # 16//8
spatial_size=7),
cls_head=dict(
with_avg_pool=False,
temporal_feature_size=1,
spatial_feature_size=1,
dropout_ratio=0.5,
in_channels=512,
init_std=0.001,
num_classes=101
)
)
10 changes: 10 additions & 0 deletions baselines/fedvssl/fedvssl/conf/mmcv_conf/finetuning/r3d_18_ucf101/finetune_ucf101.py
@@ -0,0 +1,10 @@
_base_ = ['../model_r3d18.py',
'../runtime_ucf101.py']

work_dir = './finetune_ucf101/'

model = dict(
backbone=dict(
pretrained='./model_pretrained.pth',
),
)
10 changes: 10 additions & 0 deletions baselines/fedvssl/fedvssl/conf/mmcv_conf/finetuning/r3d_18_ucf101/test_ucf101.py
@@ -0,0 +1,10 @@
_base_ = ['../model_r3d18.py',
'../runtime_ucf101.py']

work_dir = './finetune_ucf101/'

model = dict(
backbone=dict(
pretrained='/finetune/ucf101/epoch_150.pth',
),
)
133 changes: 133 additions & 0 deletions baselines/fedvssl/fedvssl/conf/mmcv_conf/finetuning/runtime_ucf101.py
@@ -0,0 +1,133 @@
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
syncbn = True

train_cfg = None
test_cfg = None
evaluation = dict(interval=10)

data = dict(
videos_per_gpu=4, # total batch size 8*4 == 32
workers_per_gpu=4,
train=dict(
type='TSNDataset',
name='ucf101_train_split1',
data_source=dict(
type='JsonClsDataSource',
ann_file='ucf101/annotations/train_split_1.json',
),
backend=dict(
type='ZipBackend',
zip_fmt='ucf101/zips/{}.zip',
frame_fmt='img_{:05d}.jpg',
),
frame_sampler=dict(
type='RandomFrameSampler',
num_clips=1,
clip_len=16,
strides=2,
temporal_jitter=False
),
test_mode=False,
transform_cfg=[
dict(type='GroupScale', scales=[(149, 112), (171, 128), (192, 144)]),
dict(type='GroupFlip', flip_prob=0.35),
dict(type='RandomBrightness', prob=0.20, delta=32),
dict(type='RandomContrast', prob=0.20, delta=0.20),
dict(type='RandomHueSaturation', prob=0.20, hue_delta=12, saturation_delta=0.1),
dict(type='GroupRandomCrop', out_size=112),
dict(
type='GroupToTensor',
switch_rgb_channels=True,
div255=True,
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225)
)
]
),
val=dict(
type='TSNDataset',
name='ucf101_test_split1',
data_source=dict(
type='JsonClsDataSource',
ann_file='ucf101/annotations/test_split_1.json',
),
backend=dict(
type='ZipBackend',
zip_fmt='ucf101/zips/{}.zip',
frame_fmt='img_{:05d}.jpg',
),
frame_sampler=dict(
type='UniformFrameSampler',
num_clips=10,
clip_len=16,
strides=2,
temporal_jitter=False
),
test_mode=True,
transform_cfg=[
dict(type='GroupScale', scales=[(171, 128)]),
dict(type='GroupCenterCrop', out_size=112),
dict(
type='GroupToTensor',
switch_rgb_channels=True,
div255=True,
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225)
)
]
),
test=dict(
type='TSNDataset',
name='ucf101_test_split1',
data_source=dict(
type='JsonClsDataSource',
ann_file='ucf101/annotations/test_split_1.json',
),
backend=dict(
type='ZipBackend',
zip_fmt='ucf101/zips/{}.zip',
frame_fmt='img_{:05d}.jpg',
),
frame_sampler=dict(
type='UniformFrameSampler',
num_clips=10,
clip_len=16,
strides=2,
temporal_jitter=False
),
test_mode=True,
transform_cfg=[
dict(type='GroupScale', scales=[(171, 128)]),
dict(type='GroupCenterCrop', out_size=112),
dict(
type='GroupToTensor',
switch_rgb_channels=True,
div255=True,
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225)
)
]
),
)

# optimizer
total_epochs = 150
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=5e-4)
optimizer_config = dict(grad_clip=dict(max_norm=40, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
step=[60, 120]
)
checkpoint_config = dict(interval=1, max_keep_ckpts=1, create_symlink=False)
workflow = [('train', 50)]
log_config = dict(
interval=10,
hooks=[
dict(type='TextLoggerHook'),
dict(type='TensorboardLoggerHook'),
]
)
3 changes: 0 additions & 3 deletions baselines/fedvssl/fedvssl/conf/mmcv_conf/pretraining/r3d_18_ucf101/pretraining_for_ucf.py
@@ -1,9 +1,6 @@
"""Config file used for pre-training on UCF-101 dataset."""

_base_ = "../pretraining_runtime_ucf.py"
# _base_ = '../pretraining_runtime_kinetics.py'

# work_dir = './output/ctp/r3d_18_ucf101/pretraining/'

model = {
"type": "CtP",

This file was deleted.

This file was deleted.

50 changes: 50 additions & 0 deletions baselines/fedvssl/fedvssl/finetune_preprocess.py
@@ -0,0 +1,50 @@
import argparse
from collections import OrderedDict
import numpy as np
import torch
from mmengine.config import Config
from flwr.common import parameters_to_ndarrays
from .CtP.pyvrl.builder import build_model


def args_parser():
"""Parse arguments to pre-process pre-trained SSL model for fine-tuning."""
parser = argparse.ArgumentParser()

parser.add_argument(
"--cfg_path",
default="fedvssl/conf/mmcv_conf/finetuning/r3d_18_ucf101/finetune_ucf101.py",
type=str,
help="Path of config file for fine-tuning.",
)
parser.add_argument(
"--pretrained_model_path",
default="",
type=str,
help="Path of pre-trained SSL model.",
)

args = parser.parse_args()
return args


args = args_parser()

# Load config file
cfg = Config.fromfile(args.cfg_path)
cfg.model.backbone["pretrained"] = None

# Build a model using the config file
model = build_model(cfg.model)

# Conversion of the format of pre-trained SSL model from .npz files to .pth format.
params = np.load(args.pretrained_model_path, allow_pickle=True)
params = params["arr_0"].item()
params = parameters_to_ndarrays(params)
params_dict = zip(model.state_dict().keys(), params)
state_dict = {
"state_dict": OrderedDict(
{k: torch.from_numpy(v) for k, v in params_dict}
)
}
torch.save(state_dict, "./model_pretrained.pth")
