Skip to content

Commit

Permalink
Merge pull request #8 from shervinea/make-visualization-work-again
Browse files Browse the repository at this point in the history
Make visualization.py work again
  • Loading branch information
shervinea authored Sep 9, 2021
2 parents 958613c + 17f142a commit 31d30e0
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 9 deletions.
23 changes: 15 additions & 8 deletions enzynet/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@

# MIT License

from typing import List, Optional, Text, Tuple
from typing import Optional, Text, Tuple

import numpy as np
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt

from enzynet import pdb
Expand Down Expand Up @@ -52,7 +53,11 @@ def plot_volume(vol: np.ndarray, pdb_id: Text, v_size: int, num: int,
plt.rc('font', family='serif')
fig = plt.figure(figsize=(4,4))
ax = fig.gca(projection='3d')
ax.set_aspect('equal')

# Reproduces the functionality of ax.set_aspect('equal').
# Source: https://github.com/matplotlib/matplotlib/issues/17172#issuecomment-830139107
ax.set_box_aspect(
[ub - lb for lb, ub in (getattr(ax, f'get_{a}lim')() for a in 'xyz')])

# Parameters.
len_vol = vol.shape[0]
Expand Down Expand Up @@ -94,9 +99,9 @@ def plot_volume(vol: np.ndarray, pdb_id: Text, v_size: int, num: int,
ax.zaxis._axinfo["grid"]['linewidth'] = 0.1

# Change thickness of ticks.
ax.xaxis._axinfo["tick"]['linewidth'] = 0.1
ax.yaxis._axinfo["tick"]['linewidth'] = 0.1
ax.zaxis._axinfo["tick"]['linewidth'] = 0.1
ax.xaxis._axinfo["tick"]['linewidth'][True] = 0.1
ax.yaxis._axinfo["tick"]['linewidth'][True] = 0.1
ax.zaxis._axinfo["tick"]['linewidth'][True] = 0.1

# Change tick placement.
ax.xaxis._axinfo['tick']['inward_factor'] = 0
Expand All @@ -115,7 +120,7 @@ def plot_volume(vol: np.ndarray, pdb_id: Text, v_size: int, num: int,
def cuboid_data(
pos: Tuple[float, float, float],
size: Tuple[int, int, int]=(1,1,1)
) -> Tuple[List[List[float]], List[List[float]], List[List[float]]]:
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Gets coordinates of cuboid."""
# Gets the (left, outside, bottom) point.
o = [a - b / 2 for a, b in zip(pos, size)]
Expand All @@ -132,15 +137,17 @@ def cuboid_data(
[o[2], o[2], o[2] + h, o[2] + h, o[2]],
[o[2], o[2], o[2] + h, o[2] + h, o[2]]]

return x, y, z
return np.array(x), np.array(y), np.array(z)


def plot_cube_at(pos: Tuple[float, float, float] = (0,0,0),
ax: Optional[plt.gca] = None) -> None:
"""Plots a cube element at position pos."""
lightsource = mcolors.LightSource(azdeg=135, altdeg=0)
if ax != None:
X, Y, Z = cuboid_data(pos)
ax.plot_surface(X, Y, Z, color='g', rstride=1, cstride=1, alpha=1)
ax.plot_surface(X, Y, Z, color='g', rstride=1, cstride=1, alpha=1,
lightsource=lightsource)


def plot_cube_weights_at(pos: Tuple[float, float, float] = (0, 0, 0),
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ keras
tensorflow

# Plot
matplotlib
matplotlib==3.4.1

# Biology
biopython
Expand Down

0 comments on commit 31d30e0

Please sign in to comment.