Skip to content

Commit

Permalink
Initial commit for new resnet
Browse files Browse the repository at this point in the history
  • Loading branch information
Eve-ning committed Nov 26, 2023
1 parent 50cbfb7 commit b6ffe40
Show file tree
Hide file tree
Showing 3 changed files with 368 additions and 0 deletions.
231 changes: 231 additions & 0 deletions mixmatch/models/nested_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
#!/usr/bin/env python
"""`nested_dict` provides dictionaries with multiple levels of nested-ness."""
from __future__ import print_function
from __future__ import division

################################################################################
#
# nested_dict.py
#
# Copyright (c) 2009, 2015 Leo Goodstadt
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
#################################################################################


from collections import defaultdict

import sys


def flatten_nested_items(dictionary):
"""
Flatten a nested_dict.
iterate through nested dictionary (with iterkeys() method)
and return with nested keys flattened into a tuple
"""
if sys.hexversion < 0x03000000:
keys = dictionary.iterkeys
keystr = "iterkeys"
else:
keys = dictionary.keys
keystr = "keys"
for key in keys():
value = dictionary[key]
if hasattr(value, keystr):
for keykey, value in flatten_nested_items(value):
yield (key,) + keykey, value
else:
yield (key,), value


class _recursive_dict(defaultdict):
"""
Parent class of nested_dict.
Defined separately for _nested_levels to work
transparently, so dictionaries with a specified (and constant) degree of nestedness
can be created easily.
The "_flat" functions are defined here rather than in nested_dict because they work
recursively.
"""

def iteritems_flat(self):
"""Iterate through items with nested keys flattened into a tuple."""
for key, value in flatten_nested_items(self):
yield key, value

def iterkeys_flat(self):
"""Iterate through keys with nested keys flattened into a tuple."""
for key, value in flatten_nested_items(self):
yield key

def itervalues_flat(self):
"""Iterate through values with nested keys flattened into a tuple."""
for key, value in flatten_nested_items(self):
yield value

items_flat = iteritems_flat
keys_flat = iterkeys_flat
values_flat = itervalues_flat

def to_dict(self, input_dict=None):
"""Convert the nested dictionary to a nested series of standard ``dict`` objects."""
#
# Calls itself recursively to unwind the dictionary.
# Use to_dict() to start at the top level of nesting
plain_dict = dict()
if input_dict is None:
input_dict = self
for key in input_dict.keys():
value = input_dict[key]
if isinstance(value, _recursive_dict):
# print "recurse", value
plain_dict[key] = self.to_dict(value)
else:
# print "plain", value
plain_dict[key] = value
return plain_dict

def __str__(self, indent=None):
"""Representation of self as a string."""
import json
return json.dumps(self.to_dict(), indent=indent)


class _any_type(object):
pass


def _nested_levels(level, nested_type):
"""Helper function to create a specified degree of nested dictionaries."""
if level > 2:
return lambda: _recursive_dict(_nested_levels(level - 1, nested_type))
if level == 2:
if isinstance(nested_type, _any_type):
return lambda: _recursive_dict()
else:
return lambda: _recursive_dict(_nested_levels(level - 1, nested_type))
return nested_type


if sys.hexversion < 0x03000000:
iteritems = dict.iteritems
else:
iteritems = dict.items


# _________________________________________________________________________________________
#
# nested_dict
#
# _________________________________________________________________________________________
def nested_dict_from_dict(orig_dict, nd):
"""Helper to build nested_dict from a dict."""
for key, value in iteritems(orig_dict):
if isinstance(value, (dict,)):
nd[key] = nested_dict_from_dict(value, nested_dict())
else:
nd[key] = value
return nd


def _recursive_update(nd, other):
for key, value in iteritems(other):
#print ("key=", key)
if isinstance(value, (dict,)):

# recursive update if my item is nested_dict
if isinstance(nd[key], (_recursive_dict,)):
#print ("recursive update", key, type(nd[key]))
_recursive_update(nd[key], other[key])

# update if my item is dict
elif isinstance(nd[key], (dict,)):
#print ("update", key, type(nd[key]))
nd[key].update(other[key])

# overwrite
else:
#print ("self not nested dict or dict: overwrite", key)
nd[key] = value
# other not dict: overwrite
else:
#print ("other not dict: overwrite", key)
nd[key] = value
return nd


# _________________________________________________________________________________________
#
# nested_dict
#
# _________________________________________________________________________________________
class nested_dict(_recursive_dict):
"""
Nested dict.
Uses defaultdict to automatically add levels of nested dicts and other types.
"""

def update(self, other):
"""Update recursively."""
_recursive_update(self, other)

def __init__(self, *param, **named_param):
"""
Constructor.
Takes one or two parameters
1) int, [TYPE]
1) dict
"""
if not len(param):
self.factory = nested_dict
defaultdict.__init__(self, self.factory)
return

if len(param) == 1:
# int = level
if isinstance(param[0], int):
self.factory = _nested_levels(param[0], _any_type())
defaultdict.__init__(self, self.factory)
return
# existing dict
if isinstance(param[0], dict):
self.factory = nested_dict
defaultdict.__init__(self, self.factory)
nested_dict_from_dict(param[0], self)
return

if len(param) == 2:
if isinstance(param[0], int):
self.factory = _nested_levels(*param)
defaultdict.__init__(self, self.factory)
return

raise Exception("nested_dict should be initialised with either "
"1) the number of nested levels and an optional type, or "
"2) an existing dict to be converted into a nested dict "
"(factory = %s. len(param) = %d, param = %s"
% (self.factory, len(param), param))
71 changes: 71 additions & 0 deletions mixmatch/models/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import torch
from torch.nn.init import kaiming_normal_
import torch.nn.functional as F
from torch.nn.parallel._functions import Broadcast
from torch.nn.parallel import scatter, parallel_apply, gather
from functools import partial
from mixmatch.models.nested_dict import nested_dict


def cast(params, dtype='float'):
if isinstance(params, dict):
return {k: cast(v, dtype) for k, v in params.items()}
else:
return getattr(params.cuda() if torch.cuda.is_available() else params, dtype)()


def conv_params(ni, no, k=1):
return kaiming_normal_(torch.Tensor(no, ni, k, k))


def linear_params(ni, no):
return {'weight': kaiming_normal_(torch.Tensor(no, ni)), 'bias': torch.zeros(no)}


def bnparams(n):
return {'weight': torch.rand(n),
'bias': torch.zeros(n),
'running_mean': torch.zeros(n),
'running_var': torch.ones(n)}


def data_parallel(f, input, params, device_ids, output_device=None):
assert isinstance(device_ids, list)
if output_device is None:
output_device = device_ids[0]

if len(device_ids) == 1:
return f(input, params)

params_all = Broadcast.apply(device_ids, *params.values())
params_replicas = [{k: params_all[i + j * len(params)] for i, k in enumerate(params.keys())}
for j in range(len(device_ids))]

replicas = [partial(f, params=p)
for p in params_replicas]
inputs = scatter([input], device_ids)
outputs = parallel_apply(replicas, inputs)
return gather(outputs, output_device)


def flatten(params):
return {'.'.join(k): v for k, v in nested_dict(params).items_flat() if v is not None}


def batch_norm(x, params, base):
return F.batch_norm(x, weight=params[base + '.weight'],
bias=params[base + '.bias'],
running_mean=params[base + '.running_mean'],
running_var=params[base + '.running_var'],)


def print_tensor_dict(params):
kmax = max(len(key) for key in params.keys())
for i, (key, v) in enumerate(params.items()):
print(str(i).ljust(5), key.ljust(kmax + 3), str(tuple(v.shape)).ljust(23), torch.typename(v), v.requires_grad)


def set_requires_grad_except_bn_(params):
for k, v in params.items():
if not k.endswith('running_mean') and not k.endswith('running_var'):
v.requires_grad = True
66 changes: 66 additions & 0 deletions mixmatch/models/wideresnet_new.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import torch.nn.functional as F
import mixmatch.models.utils as utils


def resnet(depth, width, num_classes):
assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
n = (depth - 4) // 6
widths = [int(v * width) for v in (16, 32, 64)]

def gen_block_params(ni, no):
return {
'conv0': utils.conv_params(ni, no, 3),
'conv1': utils.conv_params(no, no, 3),
'bn0': utils.bnparams(ni),
'bn1': utils.bnparams(no),
'convdim': utils.conv_params(ni, no, 1) if ni != no else None,
}

def gen_group_params(ni, no, count):
return {'block%d' % i: gen_block_params(ni if i == 0 else no, no)
for i in range(count)}

flat_params = utils.cast(utils.flatten({
'conv0': utils.conv_params(3, 16, 3),
'group0': gen_group_params(16, widths[0], n),
'group1': gen_group_params(widths[0], widths[1], n),
'group2': gen_group_params(widths[1], widths[2], n),
'bn': utils.bnparams(widths[2]),
'fc': utils.linear_params(widths[2], num_classes),
}))

utils.set_requires_grad_except_bn_(flat_params)

def block(x, params, base, stride):
o1 = F.relu(utils.batch_norm(x, params, base + '.bn0'), inplace=True)
y = F.conv2d(o1, params[base + '.conv0'], stride=stride, padding=1)
o2 = F.relu(utils.batch_norm(y, params, base + '.bn1'), inplace=True)
z = F.conv2d(o2, params[base + '.conv1'], stride=1, padding=1)
if base + '.convdim' in params:
return z + F.conv2d(o1, params[base + '.convdim'], stride=stride)
else:
return z + x

def group(o, params, base, stride):
for i in range(n):
o = block(o, params, '%s.block%d' % (base, i), stride if i == 0 else 1)
return o

def f(input, params):
x = F.conv2d(input, params['conv0'], padding=1)
g0 = group(x, params, 'group0', 1)
g1 = group(g0, params, 'group1', 2)
g2 = group(g1, params, 'group2', 2)
o = F.relu(utils.batch_norm(g2, params, 'bn'))
o = F.avg_pool2d(o, 8, 1, 0)
o = o.view(o.size(0), -1)
o = F.linear(o, params['fc.weight'], params['fc.bias'])
return o

return f, flat_params


f, p = resnet(28, 2, 10)
import torch

f(torch.rand(16, 3, 100, 100), p, )

0 comments on commit b6ffe40

Please sign in to comment.