From 925ec7452d158af56eee58540d8211c31511fd5b Mon Sep 17 00:00:00 2001
From: Haifeng Jin
Date: Wed, 28 Nov 2018 02:01:35 -0600
Subject: [PATCH] Bug fix (#347)

* bug fix

* new release
---
 autokeras/bayesian.py             |  4 +-
 autokeras/net_transformer.py      | 12 +-----
 autokeras/nn/graph.py             | 68 +++++++++++++++++--------------
 autokeras/nn/layer_transformer.py | 17 ++++++--
 setup.py                          |  4 +-
 tests/nn/test_graph.py            | 23 ++++++++++-
 6 files changed, 78 insertions(+), 50 deletions(-)

diff --git a/autokeras/bayesian.py b/autokeras/bayesian.py
index 0b172f94e..74fdf0daf 100644
--- a/autokeras/bayesian.py
+++ b/autokeras/bayesian.py
@@ -342,11 +342,11 @@ def generate(self, descriptors, timeout, multiprocessing_queue):
             pq.put(elem_class(metric_value, model_id, graph))
 
         t = 1.0
-        # t_min = self.t_min
+        t_min = self.t_min
         alpha = 0.9
         opt_acq = self._get_init_opt_acq_value()
         remaining_time = timeout
-        while not pq.empty() and remaining_time > 0:
+        while not pq.empty() and remaining_time > 0 and t > t_min:
             if multiprocessing_queue.qsize() != 0:
                 break
             elem = pq.get()
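The bayesian.py hunk restores the simulated-annealing lower bound: `t_min` had been commented out, so the acquisition loop could only stop on an empty queue or an expired timeout, never because the temperature had cooled. A minimal sketch of the loop shape after the fix; the queue contents, acquisition function, and time bookkeeping here are hypothetical stand-ins for the real optimizer state:

import queue

def annealed_search(pq, evaluate_acq, timeout, t_min=0.01, alpha=0.9):
    # Pop candidates until the queue drains, time runs out,
    # or the temperature cools past t_min (the restored bound).
    t, remaining = 1.0, timeout
    best = None
    while not pq.empty() and remaining > 0 and t > t_min:
        value = evaluate_acq(pq.get())
        best = value if best is None else min(best, value)
        t *= alpha      # anneal: each iteration cools the temperature
        remaining -= 1  # stand-in for the real elapsed-time tracking
    return best

pq = queue.PriorityQueue()
for v in [3, 1, 2]:
    pq.put(v)
print(annealed_search(pq, evaluate_acq=lambda x: x, timeout=100))  # 1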
diff --git a/autokeras/net_transformer.py b/autokeras/net_transformer.py
index 6bd042025..2f78639fb 100644
--- a/autokeras/net_transformer.py
+++ b/autokeras/net_transformer.py
@@ -1,11 +1,9 @@
 from copy import deepcopy
-from operator import itemgetter
 from random import randrange, sample
 
 from autokeras.nn.graph import NetworkDescriptor
 from autokeras.constant import Constant
-from autokeras.nn.layer_transformer import init_dense_weight, init_conv_weight, init_bn_weight
 from autokeras.nn.layers import is_layer, StubDense, get_dropout_class, StubReLU, get_conv_class, \
     get_batch_norm_class, get_pooling_class
 
@@ -29,22 +27,14 @@ def to_wider_graph(graph):
 
 def to_skip_connection_graph(graph):
     # The last conv layer cannot be widen since wider operator cannot be done over the two sides of flatten.
     weighted_layer_ids = graph.skip_connection_layer_ids()
-    descriptor = graph.extract_descriptor()
-    sorted_skips = sorted(descriptor.skip_connections, key=itemgetter(2, 0, 1))
-    p = 0
     valid_connection = []
     for skip_type in sorted([NetworkDescriptor.ADD_CONNECT, NetworkDescriptor.CONCAT_CONNECT]):
         for index_a in range(len(weighted_layer_ids)):
             for index_b in range(len(weighted_layer_ids))[index_a + 1:]:
-                if p < len(sorted_skips) and sorted_skips[p] == (index_a + 1, index_b + 1, skip_type):
-                    p += 1
-                else:
-                    valid_connection.append((index_a, index_b, skip_type))
+                valid_connection.append((index_a, index_b, skip_type))
 
     if len(valid_connection) < 1:
         return graph
-    # n_skip_connection = randint(1, len(valid_connection))
-    # for index_a, index_b, skip_type in sample(valid_connection, n_skip_connection):
     for index_a, index_b, skip_type in sample(valid_connection, 1):
         a_id = weighted_layer_ids[index_a]
         b_id = weighted_layer_ids[index_b]
diff --git a/autokeras/nn/graph.py b/autokeras/nn/graph.py
index 92c4040c9..306abc387 100644
--- a/autokeras/nn/graph.py
+++ b/autokeras/nn/graph.py
@@ -10,7 +10,7 @@
     wider_pre_conv, add_noise, init_dense_weight, init_conv_weight, init_bn_weight
 from autokeras.nn.layers import StubConcatenate, StubAdd, is_layer, layer_width, \
     to_real_keras_layer, set_torch_weight_to_stub, set_stub_weight_to_torch, set_stub_weight_to_keras, \
-    set_keras_weight_to_stub, get_conv_class, get_pooling_class, StubReLU
+    set_keras_weight_to_stub, get_conv_class, StubReLU
 
 
 class NetworkDescriptor:
@@ -205,6 +205,7 @@ def _redirect_edge(self, u_id, v_id, new_v_id):
             if edge_tuple[0] == v_id:
                 layer_id = edge_tuple[1]
                 self.adj_list[u_id][index] = (new_v_id, layer_id)
+                self.layer_list[layer_id].output = self.node_list[new_v_id]
                 break
 
         for index, edge_tuple in enumerate(self.reverse_adj_list[v_id]):
@@ -261,9 +262,9 @@ def _get_pooling_layers(self, start_node_id, end_node_id):
         for layer_id in layer_list:
             layer = self.layer_list[layer_id]
             if is_layer(layer, 'Pooling'):
-                ret.append((layer.kernel_size, layer.stride, layer.padding))
+                ret.append(layer)
             elif is_layer(layer, 'Conv') and layer.stride != 1:
-                ret.append((int((layer.kernel_size + 1) / 2), layer.stride, 0))
+                ret.append(layer)
         return ret
 
     def _depth_first_search(self, target_id, layer_id_list, node_list):
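The _redirect_edge hunk is the heart of the fix: when an edge u -> v is rerouted to u -> new_v, the traversed layer's `output` node now follows, so the stub layer and the adjacency list can no longer disagree about where the layer writes its result (the stale pointer is what corrupted shapes after repeated skip transforms). Relatedly, _get_pooling_layers now returns the layer objects themselves rather than lossy (kernel, stride, padding) tuples, so callers can clone them faithfully. A toy sketch of the pointer invariant, with hypothetical Node/Layer stand-ins for the real stubs:

class Node:
    def __init__(self, node_id):
        self.node_id = node_id

class Layer:
    def __init__(self, output):
        self.output = output

node_list = [Node(0), Node(1), Node(2)]
layer_list = [Layer(output=node_list[1])]
adj_list = {0: [(1, 0)]}   # u -> [(v, layer_id), ...]

def redirect_edge(u_id, v_id, new_v_id):
    for index, (dst, layer_id) in enumerate(adj_list[u_id]):
        if dst == v_id:
            adj_list[u_id][index] = (new_v_id, layer_id)
            # the fix: keep the layer's output node in sync with the edge
            layer_list[layer_id].output = node_list[new_v_id]
            break

redirect_edge(0, 1, 2)
assert layer_list[0].output is node_list[2]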
@@ -417,12 +418,12 @@ def to_add_skip_model(self, start_id, end_id):
         self.operation_history.append(('to_add_skip_model', start_id, end_id))
         filters_end = self.layer_list[end_id].output.shape[-1]
         filters_start = self.layer_list[start_id].output.shape[-1]
-        conv_block_input_id = self.layer_id_to_output_node_ids[start_id][0]
+        start_node_id = self.layer_id_to_output_node_ids[start_id][0]
 
-        block_last_layer_input_id = self.layer_id_to_input_node_ids[end_id][0]
-        block_last_layer_output_id = self.layer_id_to_output_node_ids[end_id][0]
+        pre_end_node_id = self.layer_id_to_input_node_ids[end_id][0]
+        end_node_id = self.layer_id_to_output_node_ids[end_id][0]
 
-        skip_output_id = self._insert_pooling_layer_chain(block_last_layer_input_id, conv_block_input_id)
+        skip_output_id = self._insert_pooling_layer_chain(start_node_id, end_node_id)
 
         # Add the conv layer
         new_conv_layer = get_conv_class(self.n_dim)(filters_start,
@@ -431,15 +432,15 @@ def to_add_skip_model(self, start_id, end_id):
         skip_output_id = self.add_layer(new_conv_layer, skip_output_id)
 
         # Add the add layer.
-        add_input_node_id = self._add_node(deepcopy(self.node_list[block_last_layer_output_id]))
+        add_input_node_id = self._add_node(deepcopy(self.node_list[end_node_id]))
         add_layer = StubAdd()
 
-        self._redirect_edge(block_last_layer_input_id, block_last_layer_output_id, add_input_node_id)
-        self._add_edge(add_layer, add_input_node_id, block_last_layer_output_id)
-        self._add_edge(add_layer, skip_output_id, block_last_layer_output_id)
+        self._redirect_edge(pre_end_node_id, end_node_id, add_input_node_id)
+        self._add_edge(add_layer, add_input_node_id, end_node_id)
+        self._add_edge(add_layer, skip_output_id, end_node_id)
         add_layer.input = [self.node_list[add_input_node_id], self.node_list[skip_output_id]]
-        add_layer.output = self.node_list[block_last_layer_output_id]
-        self.node_list[block_last_layer_output_id].shape = add_layer.output_shape
+        add_layer.output = self.node_list[end_node_id]
+        self.node_list[end_node_id].shape = add_layer.output_shape
 
         # Set weights to the additional conv layer.
         if self.weighted:
@@ -458,15 +459,15 @@ def to_concat_skip_model(self, start_id, end_id):
         self.operation_history.append(('to_concat_skip_model', start_id, end_id))
         filters_end = self.layer_list[end_id].output.shape[-1]
         filters_start = self.layer_list[start_id].output.shape[-1]
-        conv_block_input_id = self.layer_id_to_output_node_ids[start_id][0]
+        start_node_id = self.layer_id_to_output_node_ids[start_id][0]
 
-        block_last_layer_input_id = self.layer_id_to_input_node_ids[end_id][0]
-        block_last_layer_output_id = self.layer_id_to_output_node_ids[end_id][0]
+        pre_end_node_id = self.layer_id_to_input_node_ids[end_id][0]
+        end_node_id = self.layer_id_to_output_node_ids[end_id][0]
 
-        skip_output_id = self._insert_pooling_layer_chain(block_last_layer_input_id, conv_block_input_id)
+        skip_output_id = self._insert_pooling_layer_chain(start_node_id, end_node_id)
 
-        concat_input_node_id = self._add_node(deepcopy(self.node_list[block_last_layer_output_id]))
-        self._redirect_edge(block_last_layer_input_id, block_last_layer_output_id, concat_input_node_id)
+        concat_input_node_id = self._add_node(deepcopy(self.node_list[end_node_id]))
+        self._redirect_edge(pre_end_node_id, end_node_id, concat_input_node_id)
 
         concat_layer = StubConcatenate()
         concat_layer.input = [self.node_list[concat_input_node_id], self.node_list[skip_output_id]]
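Beyond the renames (which also swap the reversed start/end arguments previously passed to _insert_pooling_layer_chain), these two methods splice a branch from the output node of start_id to the output node of end_id: a 1x1 conv projects the skip tensor to the right filter count, then an Add (or a Concatenate followed by a 1x1 conv) merges it back into the main chain. A hypothetical PyTorch module illustrating the resulting add-skip topology; this shows the shape of the result, not the stub graph the library actually builds:

import torch
import torch.nn as nn

class AddSkip(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.main = nn.Conv2d(channels, channels, 3, padding=1)
        # the inserted branch: a 1x1 conv so filter counts match at the add
        self.skip_proj = nn.Conv2d(channels, channels, 1)

    def forward(self, x):
        return self.main(x) + self.skip_proj(x)

out = AddSkip(3)(torch.randn(2, 3, 28, 28))
print(out.shape)  # torch.Size([2, 3, 28, 28])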
@@ -479,10 +480,10 @@ def to_concat_skip_model(self, start_id, end_id):
         # Add the concatenate layer.
         new_conv_layer = get_conv_class(self.n_dim)(filters_start + filters_end,
                                                     filters_end, 1)
-        self._add_edge(new_conv_layer, concat_output_node_id, block_last_layer_output_id)
+        self._add_edge(new_conv_layer, concat_output_node_id, end_node_id)
         new_conv_layer.input = self.node_list[concat_output_node_id]
-        new_conv_layer.output = self.node_list[block_last_layer_output_id]
-        self.node_list[block_last_layer_output_id].shape = new_conv_layer.output_shape
+        new_conv_layer.output = self.node_list[end_node_id]
+        self.node_list[end_node_id].shape = new_conv_layer.output_shape
 
         if self.weighted:
             filter_shape = (1,) * self.n_dim
@@ -497,12 +498,16 @@ def to_concat_skip_model(self, start_id, end_id):
             bias = np.zeros(filters_end)
             new_conv_layer.set_weights((add_noise(weights, np.array([0, 1])), add_noise(bias, np.array([0, 1]))))
 
-    def _insert_pooling_layer_chain(self, block_last_layer_input_id, conv_block_input_id):
-        skip_output_id = conv_block_input_id
-        for kernel_size, stride, padding in self._get_pooling_layers(conv_block_input_id, block_last_layer_input_id):
-            skip_output_id = self.add_layer(get_pooling_class(self.n_dim)(kernel_size,
-                                                                          stride=stride,
-                                                                          padding=padding), skip_output_id)
+    def _insert_pooling_layer_chain(self, start_node_id, end_node_id):
+        skip_output_id = start_node_id
+        for layer in self._get_pooling_layers(start_node_id, end_node_id):
+            new_layer = deepcopy(layer)
+            if is_layer(new_layer, 'Conv'):
+                filters = self.node_list[start_node_id].shape[-1]
+                new_layer = get_conv_class(self.n_dim)(filters, filters, 1, layer.stride)
+            else:
+                new_layer = deepcopy(layer)
+            skip_output_id = self.add_layer(new_layer, skip_output_id)
         skip_output_id = self.add_layer(StubReLU(), skip_output_id)
         return skip_output_id
 
@@ -577,7 +582,7 @@ def get_main_chain_layers(self):
         """Return a list of layer IDs in the main chain."""
         main_chain = self.get_main_chain()
         ret = []
-        for u in range(self.n_nodes):
+        for u in main_chain:
             for v, layer_id in self.adj_list[u]:
                 if v in main_chain and u in main_chain:
                     ret.append(layer_id)
@@ -592,8 +597,11 @@ def _dense_layer_ids_in_order(self):
     def deep_layer_ids(self):
         ret = []
         for layer_id in self.get_main_chain_layers():
-            if is_layer(self.layer_list[layer_id], 'GlobalAveragePooling'):
+            layer = self.layer_list[layer_id]
+            if is_layer(layer, 'GlobalAveragePooling'):
                 break
+            if is_layer(layer, 'Add') or is_layer(layer, 'Concatenate'):
+                continue
             ret.append(layer_id)
         return ret
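With _get_pooling_layers now yielding real layers, _insert_pooling_layer_chain clones each downsampling layer on the main path (replacing strided convs with 1x1 convs of matching stride) so the skip branch shrinks its tensor exactly as the main chain does. The deep_layer_ids change is complementary: deepening walks only the main chain and passes over merge layers, since an Add or Concatenate has no weights to deepen around. A small sketch of that filtering logic, with a hypothetical layer list in place of the real stub graph:

main_chain_layers = [
    ('conv1', 'Conv'),
    ('add1', 'Add'),
    ('conv2', 'Conv'),
    ('gap', 'GlobalAveragePooling'),
    ('dense1', 'Dense'),
]

def deep_layer_ids(layers):
    # Collect layers eligible for deepening: stop at global pooling,
    # skip parameter-free merge layers.
    ret = []
    for layer_id, kind in layers:
        if kind == 'GlobalAveragePooling':
            break
        if kind in ('Add', 'Concatenate'):
            continue
        ret.append(layer_id)
    return ret

print(deep_layer_ids(main_chain_layers))  # ['conv1', 'conv2']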
diff --git a/autokeras/nn/layer_transformer.py b/autokeras/nn/layer_transformer.py
index 4b04659f8..602897331 100644
--- a/autokeras/nn/layer_transformer.py
+++ b/autokeras/nn/layer_transformer.py
@@ -1,6 +1,6 @@
 import numpy as np
 
-from autokeras.nn.layers import StubDense, StubReLU, get_n_dim, get_conv_class, get_batch_norm_class
+from autokeras.nn.layers import StubDense, get_n_dim, get_conv_class, get_batch_norm_class
 
 NOISE_RATIO = 1e-4
 
@@ -51,7 +51,10 @@ def wider_pre_conv(layer, n_add_filters, weighted=True):
             new_weight = new_weight[np.newaxis, ...]
         student_w = np.concatenate((student_w, new_weight), axis=0)
         student_b = np.append(student_b, teacher_b[teacher_index])
-    new_pre_layer = get_conv_class(n_dim)(layer.input_channel, n_pre_filters + n_add_filters, layer.kernel_size)
+    new_pre_layer = get_conv_class(n_dim)(layer.input_channel,
+                                          n_pre_filters + n_add_filters,
+                                          kernel_size=layer.kernel_size,
+                                          stride=layer.stride)
     new_pre_layer.set_weights((add_noise(student_w, teacher_w), add_noise(student_b, teacher_b)))
 
     return new_pre_layer
@@ -59,7 +62,10 @@ def wider_pre_conv(layer, n_add_filters, weighted=True):
 def wider_next_conv(layer, start_dim, total_dim, n_add, weighted=True):
     n_dim = get_n_dim(layer)
     if not weighted:
-        return get_conv_class(n_dim)(layer.input_channel + n_add, layer.filters, kernel_size=layer.kernel_size)
+        return get_conv_class(n_dim)(layer.input_channel + n_add,
+                                     layer.filters,
+                                     kernel_size=layer.kernel_size,
+                                     stride=layer.stride)
     n_filters = layer.filters
     teacher_w, teacher_b = layer.get_weights()
@@ -70,7 +76,10 @@ def wider_next_conv(layer, start_dim, total_dim, n_add, weighted=True):
     student_w = np.concatenate((teacher_w[:, :start_dim, ...].copy(),
                                 add_noise(new_weight, teacher_w),
                                 teacher_w[:, start_dim:total_dim, ...].copy()), axis=1)
-    new_layer = get_conv_class(n_dim)(layer.input_channel + n_add, n_filters, layer.kernel_size)
+    new_layer = get_conv_class(n_dim)(layer.input_channel + n_add,
+                                      n_filters,
+                                      kernel_size=layer.kernel_size,
+                                      stride=layer.stride)
     new_layer.set_weights((student_w, teacher_b))
 
     return new_layer
diff --git a/setup.py b/setup.py
index a9b731c33..6d8be9cfc 100644
--- a/setup.py
+++ b/setup.py
@@ -15,12 +15,12 @@
                         'imageio==2.4.1',
                         'requests==2.20.1',
                         'GPUtil==1.3.0'],
-      version='0.3.2',
+      version='0.3.3',
       description='AutoML for deep learning',
       author='DATA Lab at Texas A&M University',
       author_email='jhfjhfj1@gmail.com',
       url='http://autokeras.com',
-      download_url='https://github.com/jhfjhfj1/autokeras/archive/0.3.2.tar.gz',
+      download_url='https://github.com/jhfjhfj1/autokeras/archive/0.3.3.tar.gz',
       keywords=['AutoML', 'keras'],  # arbitrary keywords
      classifiers=[]
      )
diff --git a/tests/nn/test_graph.py b/tests/nn/test_graph.py
index 7af4a9462..b6884c979 100644
--- a/tests/nn/test_graph.py
+++ b/tests/nn/test_graph.py
@@ -1,6 +1,5 @@
 from autokeras.nn.generator import CnnGenerator, ResNetGenerator
 from autokeras.nn.graph import *
-from autokeras.nn.layers import StubBatchNormalization
 from tests.common import get_conv_data, get_add_skip_model, get_conv_dense_model, get_pooling_model, \
     get_concat_skip_model
 
@@ -195,3 +194,25 @@ def test_long_transform():
     model = graph.produce_model()
     model(torch.Tensor(np.random.random((10, 1, 28, 28))))
+
+
+def test_long_transform2():
+    graph = CnnGenerator(10, (28, 28, 1)).generate()
+    graph.to_add_skip_model(2, 3)
+    graph.to_concat_skip_model(2, 3)
+    model = graph.produce_model()
+    model(torch.Tensor(np.random.random((10, 1, 28, 28))))
+
+
+def test_long_transform4():
+    graph = ResNetGenerator(10, (28, 28, 1)).generate()
+    graph.to_concat_skip_model(57, 68)
+    model = graph.produce_model()
+    model(torch.Tensor(np.random.random((10, 1, 28, 28))))
+
+
+def test_long_transform5():
+    graph = ResNetGenerator(10, (28, 28, 1)).generate()
+    graph.to_concat_skip_model(19, 60)
+    graph.to_wider_model(52, 256)
+    model = graph.produce_model()
+    model(torch.Tensor(np.random.random((10, 1, 28, 28))))
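The remaining hunks travel together: every conv rebuilt by wider_pre_conv / wider_next_conv now receives the original layer.stride instead of silently falling back to the default, so a strided conv stays strided after widening and downstream shapes still line up; setup.py bumps the release to 0.3.3; and the new tests above drive the repaired paths end to end (stacked skip connections on CnnGenerator and ResNetGenerator graphs, plus widening across a concat skip). A hypothetical numpy sketch of the weight side of the widening step, duplicating randomly chosen output filters in the Net2WiderNet style the diff uses; names and shapes here are illustrative, not the library's API:

import numpy as np

def wider_conv_weights(w, b, n_add, rng=np.random.default_rng(0)):
    # w: (out_channels, in_channels, k, k); copy n_add random output filters.
    teacher_index = rng.integers(0, w.shape[0], size=n_add)
    new_w = np.concatenate((w, w[teacher_index]), axis=0)
    new_b = np.append(b, b[teacher_index])
    return new_w, new_b

w = np.ones((4, 3, 3, 3))
b = np.zeros(4)
new_w, new_b = wider_conv_weights(w, b, n_add=2)
print(new_w.shape, new_b.shape)  # (6, 3, 3, 3) (6,)

The stride argument itself carries no weights, which is why the old code type-checked and ran: the widened layer was structurally valid, just wrong, and the error only surfaced later as a shape mismatch between the main chain and a skip branch.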