-
Notifications
You must be signed in to change notification settings - Fork 0
/
PlotUtils.py
72 lines (62 loc) · 2.71 KB
/
PlotUtils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
"""
File name: plot_utils.py
Author: Benjamin Planche
Date created: 14.02.2019
Date last modified: 14:42 14.02.2019
Python Version: 3.6
Copyright = "Copyright (C) 2018-2019 of Packt"
Credits = ["Eliot Andres, Benjamin Planche"]
License = "MIT"
Version = "1.0.0"
Maintainer = "non"
Status = "Prototype" # "Prototype", "Development", or "Production"
"""
#==============================================================================
# Imported Modules
#==============================================================================
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
#==============================================================================
# Function Definitions
#==============================================================================
def plot_image_grid(images, titles=None, figure=None,
                    grayscale=False, transpose=False):
    """
    Plot a grid of images, drawing one nested sub-grid per entry of `images`.

    `images` is indexed as `images[j][i]`: the outer index `j` selects a group
    of images and the inner index `i` an image inside that group. By default
    each group is drawn as a vertical strip and the groups are laid out side by
    side; with `transpose=True` each group is a horizontal strip and the groups
    are stacked on top of each other.

    The number of images per group is taken from `images[0]`, and the aspect
    ratio used to size a new figure is taken from `images[0][0].shape`
    (`shape[1] / shape[0]`), so all groups are expected to hold at least that
    many images. An empty `images` list (or an empty first group) fails on
    that initial indexing.

    :param images:      Nested sequence of images, indexed as `images[j][i]`;
                        each image must expose a `.shape` with at least 2 entries
    :param titles:      (opt.) Sequence of labels, indexed by group (`titles[j]`).
                        The label is applied to every axes of group `j`: as the
                        axes title by default, as the y-label when transposed
    :param figure:      (opt.) Pyplot figure (if None, will be created). The
                        computed figure size is only used for a new figure
    :param grayscale:   (opt.) Flag to draw the images with the 'gray' colormap
    :param transpose:   (opt.) Flag to transpose the grid
    :return:            Pyplot figure filled with the images
    """
    num_cols, num_rows = len(images), len(images[0])
    first_shape = images[0][0].shape
    img_ratio = first_shape[1] / first_shape[0]

    if transpose:
        vert_grid_shape, hori_grid_shape = (1, num_rows), (num_cols, 1)
        num_across, num_down = num_rows, num_cols
        wspace, hspace = 0.2, 0.
    else:
        vert_grid_shape, hori_grid_shape = (num_rows, 1), (1, num_cols)
        num_across, num_down = num_cols, num_rows
        hspace, wspace = 0.2, 0.

    # int() truncates toward zero, so a grid of narrow images could end up with
    # a 0-inch-wide figure, which cannot be drawn; keep the width at >= 1 inch.
    figsize = (max(1, int(num_across * 5 * img_ratio)), num_down * 5)

    if figure is None:
        figure = plt.figure(figsize=figsize)

    imshow_params = {'cmap': plt.get_cmap('gray')} if grayscale else {}

    # Outer grid: one cell per group, without spacing between the groups.
    grid_spec = gridspec.GridSpec(*hori_grid_shape, wspace=0, hspace=0)
    for j, group in enumerate(images):
        # Inner grid: one cell per image of the group, nested in the outer cell.
        grid_spec_j = gridspec.GridSpecFromSubplotSpec(
            *vert_grid_shape, subplot_spec=grid_spec[j],
            wspace=wspace, hspace=hspace)

        for i in range(num_rows):
            ax_img = figure.add_subplot(grid_spec_j[i])
            # Ticks are removed instead of calling axis('off'), so that the
            # y-label used for the titles in transposed mode stays visible.
            ax_img.set_yticks([])
            ax_img.set_xticks([])
            if titles is not None:
                if transpose:
                    ax_img.set_ylabel(titles[j], fontsize=25)
                else:
                    ax_img.set_title(titles[j], fontsize=15)
            ax_img.imshow(group[i], **imshow_params)

    figure.tight_layout()
    return figure