
Commit

Merge remote-tracking branch 'origin/master'
MGlauer committed Feb 16, 2023
2 parents d774567 + 8fcf54b commit 2ecc159
Showing 3 changed files with 33 additions and 17 deletions.
28 changes: 17 additions & 11 deletions chebai/models/electra.py
@@ -111,6 +111,12 @@ def __call__(self, target, input):
        return gen_loss + disc_loss


+def filter_dict(d, filter_key):
+    return {str(k)[len(filter_key):]: v for k, v in
+            d.items() if
+            str(k).startswith(filter_key)}


class Electra(JCIBaseNet):
    NAME = "Electra"

@@ -151,26 +157,26 @@ def __init__(self, **kwargs):
        self.config = ElectraConfig(**kwargs["config"], output_attentions=True)
        self.word_dropout = nn.Dropout(kwargs["config"].get("word_dropout", 0))
        model_prefix = kwargs.get("load_prefix", None)
-        if pretrained_checkpoint:
-            with open(pretrained_checkpoint, "rb") as fin:
-                model_dict = torch.load(fin, map_location=self.device)
-                if model_prefix:
-                    state_dict = {str(k)[len(model_prefix):]: v for k, v in model_dict["state_dict"].items() if str(k).startswith(model_prefix)}
-                else:
-                    state_dict = model_dict["state_dict"]
-                self.electra = ElectraModel.from_pretrained(None, state_dict=state_dict, config=self.config)
-        else:
-            self.electra = ElectraModel(config=self.config)

        in_d = self.config.hidden_size

        self.output = nn.Sequential(
            nn.Dropout(self.config.hidden_dropout_prob),
            nn.Linear(in_d, in_d),
            nn.GELU(),
            nn.Dropout(self.config.hidden_dropout_prob),
            nn.Linear(in_d, self.config.num_labels),
        )
+        if pretrained_checkpoint:
+            with open(pretrained_checkpoint, "rb") as fin:
+                model_dict = torch.load(fin, map_location=self.device)
+                if model_prefix:
+                    state_dict = filter_dict(model_dict["state_dict"], model_prefix)
+                else:
+                    state_dict = model_dict["state_dict"]
+                self.electra = ElectraModel.from_pretrained(None, state_dict={k: v for (k, v) in state_dict.items() if k.startswith("electra.")}, config=self.config)
+                self.output.load_state_dict(filter_dict(state_dict, "output."))
+        else:
+            self.electra = ElectraModel(config=self.config)

    def _get_data_for_loss(self, model_output, labels):
        mask = model_output.get("target_mask")
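For orientation, here is a small self-contained sketch (not part of the commit; the checkpoint keys and placeholder values are made up) of what the new filter_dict helper does with a Lightning-style state_dict: the updated __init__ keeps the "electra."-prefixed weights for ElectraModel.from_pretrained and strips the "output." prefix so the remaining keys line up with self.output.load_state_dict.

# Illustrative only -- filter_dict copied from the diff above, fake checkpoint keys.
def filter_dict(d, filter_key):
    return {str(k)[len(filter_key):]: v for k, v in
            d.items() if
            str(k).startswith(filter_key)}

state_dict = {
    "electra.embeddings.word_embeddings.weight": "W_emb",
    "electra.encoder.layer.0.attention.self.query.weight": "W_q",
    "output.1.weight": "W_linear_1",
    "output.4.weight": "W_linear_2",
}

# Backbone weights keep their "electra." prefix, as in the from_pretrained call:
backbone = {k: v for (k, v) in state_dict.items() if k.startswith("electra.")}

# Head weights have the "output." prefix stripped so they match the keys of the
# nn.Sequential stored in self.output (indices 1 and 4 are its Linear layers):
head = filter_dict(state_dict, "output.")
print(sorted(head))  # -> ['1.weight', '4.weight']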
18 changes: 13 additions & 5 deletions chebai/result/molplot.py
@@ -20,8 +20,9 @@
from chebai.result.base import ResultProcessor


-class AttentionMolPlot(abc.ABC):
-    def plot_attentions(self, smiles, attention, threshold, labels):
+class AttentionMolPlot:
+
+    def draw_attention_molecule(self, smiles, attention):
        pmol = self.read_smiles_with_index(smiles)
        rdmol = Chem.MolFromSmiles(smiles)
        if not rdmol:
@@ -34,26 +35,33 @@ def plot_attentions(self, smiles, attention, threshold, labels):
        }
        d = rdMolDraw2D.MolDraw2DCairo(500, 500)
        cmap = cm.ScalarMappable(cmap=cm.Greens)
-        attention_colors = cmap.to_rgba(attention, norm=False)

        aggr_attention_colors = cmap.to_rgba(
            np.max(attention[2:, :], axis=0), norm=False
        )
        cols = {
            token_to_node_map[token_index]: tuple(
                aggr_attention_colors[token_index].tolist()
            )
-            for node, token_index in nx.get_node_attributes(pmol, "token_index").items()
+            for node, token_index in
+            nx.get_node_attributes(pmol, "token_index").items()
        }
        highlight_atoms = [
            token_to_node_map[token_index]
-            for node, token_index in nx.get_node_attributes(pmol, "token_index").items()
+            for node, token_index in
+            nx.get_node_attributes(pmol, "token_index").items()
        ]
        rdMolDraw2D.PrepareAndDrawMolecule(
            d, rdmol, highlightAtoms=highlight_atoms, highlightAtomColors=cols
        )

        d.FinishDrawing()
+        return d

+    def plot_attentions(self, smiles, attention, threshold, labels):
+        d = self.draw_attention_molecule(smiles, attention)
+        cmap = cm.ScalarMappable(cmap=cm.Greens)
+        attention_colors = cmap.to_rgba(attention, norm=False)
        num_tokens = sum(1 for _ in _tokenize(smiles))

        fig = plt.figure(figsize=(15, 15), facecolor="w")
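The refactor splits the old plot_attentions into a reusable draw_attention_molecule that returns the finished MolDraw2DCairo drawer, which plot_attentions then builds on. Below is a minimal standalone sketch of just the drawing step; it uses the same RDKit and matplotlib calls as the diff, but the SMILES string, the random attention array, and the output file name are illustrative, and the token-to-atom mapping is collapsed to an identity mapping.

# Illustrative sketch, not part of the commit: colour atoms by aggregated
# attention and write a PNG, mirroring draw_attention_molecule after the
# token-to-node mapping. Tokens are assumed to map 1:1 onto atoms here.
import numpy as np
from matplotlib import cm
from rdkit import Chem
from rdkit.Chem.Draw import rdMolDraw2D

rdmol = Chem.MolFromSmiles("c1ccccc1O")              # phenol, as a toy input
attention = np.random.rand(4, rdmol.GetNumAtoms())   # fake (layers, tokens) scores
cmap = cm.ScalarMappable(cmap=cm.Greens)
aggr_colors = cmap.to_rgba(np.max(attention[2:, :], axis=0), norm=False)

d = rdMolDraw2D.MolDraw2DCairo(500, 500)
atom_ids = list(range(rdmol.GetNumAtoms()))
cols = {i: tuple(aggr_colors[i].tolist()) for i in atom_ids}
rdMolDraw2D.PrepareAndDrawMolecule(d, rdmol, highlightAtoms=atom_ids, highlightAtomColors=cols)
d.FinishDrawing()
with open("attention_phenol.png", "wb") as f:
    f.write(d.GetDrawingText())                      # MolDraw2DCairo yields PNG bytes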
4 changes: 3 additions & 1 deletion setup.py
@@ -1,7 +1,7 @@
from setuptools import setup

setup(
name="ChEBI-learn",
name="chebai",
version="0.0.0",
packages=["chebai", "chebai.models"],
url="",
@@ -39,6 +39,8 @@
"scikit-network",
"svgutils",
"matplotlib",
"rdkit",
"selfies"
],
extras_require={"dev": ["black", "isort", "pre-commit"]},
)
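Since the distribution is renamed to chebai and rdkit/selfies are now listed in install_requires, a quick smoke test after installing the package (for example with an editable install) could look like the following; this is a hypothetical check, not part of the commit, and the SMILES string is arbitrary.

# Hypothetical post-install check: confirm the two newly listed dependencies
# import and behave as expected.
from rdkit import Chem
import selfies

mol = Chem.MolFromSmiles("CCO")          # ethanol
print(Chem.MolToSmiles(mol))             # -> CCO
print(selfies.encoder("CCO"))            # -> [C][C][O]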
