diff --git a/mixmatch/models/nested_dict.py b/mixmatch/models/nested_dict.py
new file mode 100644
index 0000000..9b7b41a
--- /dev/null
+++ b/mixmatch/models/nested_dict.py
@@ -0,0 +1,237 @@
+#!/usr/bin/env python
+"""`nested_dict` provides dictionaries with multiple levels of nested-ness."""
+from __future__ import print_function
+from __future__ import division
+
+################################################################################
+#
+#   nested_dict.py
+#
+#   Copyright (c) 2009, 2015 Leo Goodstadt
+#
+#   Permission is hereby granted, free of charge, to any person obtaining a copy
+#   of this software and associated documentation files (the "Software"), to deal
+#   in the Software without restriction, including without limitation the rights
+#   to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+#   copies of the Software, and to permit persons to whom the Software is
+#   furnished to do so, subject to the following conditions:
+#
+#   The above copyright notice and this permission notice shall be included in
+#   all copies or substantial portions of the Software.
+#
+#   THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+#   IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+#   FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+#   AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+#   LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+#   OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+#   THE SOFTWARE.
+#
+#################################################################################
+
+
+from collections import defaultdict
+
+import sys
+
+
+def flatten_nested_items(dictionary):
+    """
+    Flatten a nested_dict.
+
+    iterate through nested dictionary (with iterkeys() method)
+    and return with nested keys flattened into a tuple
+    """
+    if sys.hexversion < 0x03000000:
+        keys = dictionary.iterkeys
+        keystr = "iterkeys"
+    else:
+        keys = dictionary.keys
+        keystr = "keys"
+    for key in keys():
+        value = dictionary[key]
+        if hasattr(value, keystr):
+            for keykey, value in flatten_nested_items(value):
+                yield (key,) + keykey, value
+        else:
+            yield (key,), value
+
+
+class _recursive_dict(defaultdict):
+    """
+    Parent class of nested_dict.
+
+    Defined separately for _nested_levels to work
+    transparently, so dictionaries with a specified (and constant) degree of nestedness
+    can be created easily.
+
+    The "_flat" functions are defined here rather than in nested_dict because they work
+    recursively.
+
+    """
+
+    def iteritems_flat(self):
+        """Iterate through items with nested keys flattened into a tuple."""
+        for key, value in flatten_nested_items(self):
+            yield key, value
+
+    def iterkeys_flat(self):
+        """Iterate through keys with nested keys flattened into a tuple."""
+        for key, value in flatten_nested_items(self):
+            yield key
+
+    def itervalues_flat(self):
+        """Iterate through values with nested keys flattened into a tuple."""
+        for key, value in flatten_nested_items(self):
+            yield value
+
+    items_flat = iteritems_flat
+    keys_flat = iterkeys_flat
+    values_flat = itervalues_flat
+
+    def to_dict(self, input_dict=None):
+        """Convert the nested dictionary to a nested series of standard ``dict`` objects."""
+        #
+        # Calls itself recursively to unwind the dictionary.
+        # Use to_dict() to start at the top level of nesting
+        plain_dict = dict()
+        if input_dict is None:
+            input_dict = self
+        for key in input_dict.keys():
+            value = input_dict[key]
+            if isinstance(value, _recursive_dict):
+                # print "recurse", value
+                plain_dict[key] = self.to_dict(value)
+            else:
+                # print "plain", value
+                plain_dict[key] = value
+        return plain_dict
+
+    def __str__(self, indent=None):
+        """Representation of self as a string."""
+        import json
+        return json.dumps(self.to_dict(), indent=indent)
+
+
+class _any_type(object):
+    """Marker whose instances tell _nested_levels that no leaf type was given."""
+    pass
+
+
+def _nested_levels(level, nested_type):
+    """Helper function to create a specified degree of nested dictionaries."""
+    if level > 2:
+        return lambda: _recursive_dict(_nested_levels(level - 1, nested_type))
+    if level == 2:
+        if isinstance(nested_type, _any_type):
+            return lambda: _recursive_dict()
+        else:
+            return lambda: _recursive_dict(_nested_levels(level - 1, nested_type))
+    return nested_type
+
+
+if sys.hexversion < 0x03000000:
+    iteritems = dict.iteritems
+else:
+    iteritems = dict.items
+
+
+# _________________________________________________________________________________________
+#
+#   nested_dict
+#
+# _________________________________________________________________________________________
+def nested_dict_from_dict(orig_dict, nd):
+    """Helper to build nested_dict from a dict."""
+    for key, value in iteritems(orig_dict):
+        if isinstance(value, (dict,)):
+            nd[key] = nested_dict_from_dict(value, nested_dict())
+        else:
+            nd[key] = value
+    return nd
+
+
+def _recursive_update(nd, other):
+    """Merge ``other`` into ``nd`` in place, recursing into nested dict values; returns ``nd``."""
+    for key, value in iteritems(other):
+        #print ("key=", key)
+        if isinstance(value, (dict,)):
+
+            # recursive update if my item is nested_dict
+            if isinstance(nd[key], (_recursive_dict,)):
+                #print ("recursive update", key, type(nd[key]))
+                _recursive_update(nd[key], other[key])
+
+            # update if my item is dict
+            elif isinstance(nd[key], (dict,)):
+                #print ("update", key, type(nd[key]))
+                nd[key].update(other[key])
+
+            # overwrite
+            else:
+                #print ("self not nested dict or dict: overwrite", key)
+                nd[key] = value
+        # other not dict: overwrite
+        else:
+            #print ("other not dict: overwrite", key)
+            nd[key] = value
+    return nd
+
+
+# _________________________________________________________________________________________
+#
+#   nested_dict
+#
+# _________________________________________________________________________________________
+class nested_dict(_recursive_dict):
+    """
+    Nested dict.
+
+    Uses defaultdict to automatically add levels of nested dicts and other types.
+    """
+
+    def update(self, other):
+        """Update recursively."""
+        _recursive_update(self, other)
+
+    def __init__(self, *param, **named_param):
+        """
+        Constructor.
+
+        Takes one or two parameters
+            1) int, [TYPE]
+            2) dict
+
+        Raises TypeError for any other combination of positional parameters.
+        """
+        if not len(param):
+            self.factory = nested_dict
+            defaultdict.__init__(self, self.factory)
+            return
+
+        if len(param) == 1:
+            # int = level
+            if isinstance(param[0], int):
+                self.factory = _nested_levels(param[0], _any_type())
+                defaultdict.__init__(self, self.factory)
+                return
+            # existing dict
+            if isinstance(param[0], dict):
+                self.factory = nested_dict
+                defaultdict.__init__(self, self.factory)
+                nested_dict_from_dict(param[0], self)
+                return
+
+        if len(param) == 2:
+            if isinstance(param[0], int):
+                self.factory = _nested_levels(*param)
+                defaultdict.__init__(self, self.factory)
+                return
+
+        # No accepted signature matched. ``self.factory`` has not been assigned on
+        # this path, so it must not be referenced while building the message.
+        raise TypeError("nested_dict should be initialised with either "
+                        "1) the number of nested levels and an optional type, or "
+                        "2) an existing dict to be converted into a nested dict "
+                        "(len(param) = %d, param = %s)"
+                        % (len(param), param))
diff --git a/mixmatch/models/utils.py b/mixmatch/models/utils.py
new file mode 100644
index 0000000..c1f974f
--- /dev/null
+++ b/mixmatch/models/utils.py
@@ -0,0 +1,96 @@
+import torch
+from torch.nn.init import kaiming_normal_
+import torch.nn.functional as F
+from torch.nn.parallel._functions import Broadcast
+from torch.nn.parallel import scatter, parallel_apply, gather
+from functools import partial
+from mixmatch.models.nested_dict import nested_dict
+
+
+def cast(params, dtype='float'):
+    """Convert a tensor, or a (nested) dict of tensors, to ``dtype``.
+
+    Each leaf is moved to CUDA when ``torch.cuda.is_available()`` and then the
+    tensor method named by ``dtype`` (e.g. ``'float'``, ``'half'``) is called on it.
+    """
+    if isinstance(params, dict):
+        return {k: cast(v, dtype) for k, v in params.items()}
+    else:
+        return getattr(params.cuda() if torch.cuda.is_available() else params, dtype)()
+
+
+def conv_params(ni, no, k=1):
+    """Return a Kaiming-normal initialised tensor of shape ``(no, ni, k, k)``."""
+    return kaiming_normal_(torch.Tensor(no, ni, k, k))
+
+
+def linear_params(ni, no):
+    """Return ``{'weight', 'bias'}``: a Kaiming-normal ``(no, ni)`` weight and a zero bias."""
+    return {'weight': kaiming_normal_(torch.Tensor(no, ni)), 'bias': torch.zeros(no)}
+
+
+def bnparams(n):
+    """Return batch-norm tensors of length ``n``: random weight, zero bias/mean, unit variance."""
+    return {'weight': torch.rand(n),
+            'bias': torch.zeros(n),
+            'running_mean': torch.zeros(n),
+            'running_var': torch.ones(n)}
+
+
+def data_parallel(f, input, params, device_ids, output_device=None):
+    """Run ``f(input, params)`` across ``device_ids`` and gather the outputs.
+
+    With a single device ``f`` is called directly.  Otherwise the parameter
+    tensors are broadcast to every device, ``input`` is scattered, one replica
+    of ``f`` per device is applied in parallel, and the results are gathered on
+    ``output_device`` (default: ``device_ids[0]``).
+    """
+    assert isinstance(device_ids, list)
+    if output_device is None:
+        output_device = device_ids[0]
+
+    if len(device_ids) == 1:
+        return f(input, params)
+
+    # The indexing below treats the broadcast result as one flat tuple grouped by device.
+    params_all = Broadcast.apply(device_ids, *params.values())
+    params_replicas = [{k: params_all[i + j * len(params)] for i, k in enumerate(params.keys())}
+                       for j in range(len(device_ids))]
+
+    replicas = [partial(f, params=p)
+                for p in params_replicas]
+    inputs = scatter([input], device_ids)
+    outputs = parallel_apply(replicas, inputs)
+    return gather(outputs, output_device)
+
+
+def flatten(params):
+    """Flatten a nested dict into ``{'a.b.c': value}``, dropping ``None`` leaves."""
+    return {'.'.join(k): v for k, v in nested_dict(params).items_flat() if v is not None}
+
+
+def batch_norm(x, params, base, training=False):
+    """Apply ``F.batch_norm`` to ``x`` using the ``base + '.*'`` entries of ``params``.
+
+    ``training`` is forwarded to ``F.batch_norm`` and defaults to ``False``.
+    """
+    return F.batch_norm(x, weight=params[base + '.weight'],
+                        bias=params[base + '.bias'],
+                        running_mean=params[base + '.running_mean'],
+                        running_var=params[base + '.running_var'],
+                        training=training)
+
+
+def print_tensor_dict(params):
+    """Print index, name, shape, type and ``requires_grad`` for each tensor in ``params``."""
+    # ``default=0`` keeps an empty dict from raising ValueError in ``max``.
+    kmax = max((len(key) for key in params.keys()), default=0)
+    for i, (key, v) in enumerate(params.items()):
+        print(str(i).ljust(5), key.ljust(kmax + 3), str(tuple(v.shape)).ljust(23), torch.typename(v), v.requires_grad)
+
+
+def set_requires_grad_except_bn_(params):
+    """Set ``requires_grad = True`` in place on every tensor except batch-norm running statistics."""
+    for k, v in params.items():
+        if not k.endswith(('running_mean', 'running_var')):
+            v.requires_grad = True
diff --git a/mixmatch/models/wideresnet_new.py b/mixmatch/models/wideresnet_new.py
new file mode 100644
index 0000000..9a16894
--- /dev/null
+++ b/mixmatch/models/wideresnet_new.py
@@ -0,0 +1,80 @@
+import torch
+import torch.nn.functional as F
+import mixmatch.models.utils as utils
+
+
+def resnet(depth, width, num_classes):
+    """Build a functional wide ResNet.
+
+    Returns ``(f, flat_params)``: ``f(input, params)`` runs the forward pass and
+    ``flat_params`` maps dotted names to tensors.  ``depth`` must be ``6n+4``;
+    ``width`` scales the base group widths (16, 32, 64).
+    """
+    assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
+    n = (depth - 4) // 6
+    widths = [int(v * width) for v in (16, 32, 64)]
+
+    def gen_block_params(ni, no):
+        return {
+            'conv0': utils.conv_params(ni, no, 3),
+            'conv1': utils.conv_params(no, no, 3),
+            'bn0': utils.bnparams(ni),
+            'bn1': utils.bnparams(no),
+            'convdim': utils.conv_params(ni, no, 1) if ni != no else None,
+        }
+
+    def gen_group_params(ni, no, count):
+        return {'block%d' % i: gen_block_params(ni if i == 0 else no, no)
+                for i in range(count)}
+
+    flat_params = utils.cast(utils.flatten({
+        'conv0': utils.conv_params(3, 16, 3),
+        'group0': gen_group_params(16, widths[0], n),
+        'group1': gen_group_params(widths[0], widths[1], n),
+        'group2': gen_group_params(widths[1], widths[2], n),
+        'bn': utils.bnparams(widths[2]),
+        'fc': utils.linear_params(widths[2], num_classes),
+    }))
+
+    utils.set_requires_grad_except_bn_(flat_params)
+
+    def block(x, params, base, stride):
+        o1 = F.relu(utils.batch_norm(x, params, base + '.bn0'), inplace=True)
+        y = F.conv2d(o1, params[base + '.conv0'], stride=stride, padding=1)
+        o2 = F.relu(utils.batch_norm(y, params, base + '.bn1'), inplace=True)
+        z = F.conv2d(o2, params[base + '.conv1'], stride=1, padding=1)
+        # 'convdim' exists only where the channel count changes (None entries are
+        # dropped by utils.flatten); otherwise the shortcut is the identity.
+        if base + '.convdim' in params:
+            return z + F.conv2d(o1, params[base + '.convdim'], stride=stride)
+        else:
+            return z + x
+
+    def group(o, params, base, stride):
+        for i in range(n):
+            o = block(o, params, '%s.block%d' % (base, i), stride if i == 0 else 1)
+        return o
+
+    def f(input, params):
+        x = F.conv2d(input, params['conv0'], padding=1)
+        g0 = group(x, params, 'group0', 1)
+        g1 = group(g0, params, 'group1', 2)
+        g2 = group(g1, params, 'group2', 2)
+        o = F.relu(utils.batch_norm(g2, params, 'bn'))
+        # Global average pooling to 1x1 so the flattened size equals the fc
+        # layer's input width for any input resolution.
+        o = F.adaptive_avg_pool2d(o, 1)
+        o = o.view(o.size(0), -1)
+        o = F.linear(o, params['fc.weight'], params['fc.bias'])
+        return o
+
+    return f, flat_params
+
+
+f, p = resnet(28, 2, 10)
+
+
+if __name__ == '__main__':
+    # Smoke test only when run as a script; the input is created on the same
+    # device that utils.cast moved the parameters to.
+    f(torch.rand(16, 3, 100, 100, device=p['conv0'].device), p)