diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py index 1ed065f..a090a0d 100644 --- a/torchsummary/torchsummary.py +++ b/torchsummary/torchsummary.py @@ -7,11 +7,11 @@ def summary(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None): - result, params_info = summary_string( + result, summary_info = summary_string( model, input_size, batch_size, device, dtypes) print(result) - return params_info + return summary_info def summary_string(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None): @@ -19,7 +19,7 @@ def summary_string(model, input_size, batch_size=-1, device=torch.device('cuda:0 dtypes = [torch.FloatTensor]*len(input_size) summary_str = '' - + summary_info = {"params_info": tuple(), "size_info": tuple()} def register_hook(module): def hook(module, input, output): class_name = str(module.__class__).split(".")[-1].split("'")[0] @@ -117,4 +117,8 @@ def hook(module, input, output): summary_str += "Estimated Total Size (MB): %0.2f" % total_size + "\n" summary_str += "----------------------------------------------------------------" + "\n" # return summary - return summary_str, (total_params, trainable_params) + + summary_info['params_info'] = (total_params, trainable_params) + summary_info['size_info'] = (total_input_size, total_output_size, total_params_size, total_size) + + return summary_str, summary_info