Skip to content

Commit

Permalink
Merge pull request #251 from SysBioChalmers/feat/improve-dlkcat-docker
Browse files Browse the repository at this point in the history
feat: improve dlkcat docker
  • Loading branch information
edkerk authored Mar 5, 2023
2 parents 0b07f29 + c3c3a1a commit 9acd37c
Show file tree
Hide file tree
Showing 23 changed files with 430 additions and 335 deletions.
2 changes: 2 additions & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*.pickle filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
2 changes: 0 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ geckopy/geckopy/data_files/*.xml
*.ppt*
# Only allowed .mat file
!phylDist.mat
# dlkcat folder
dlkcat/

# Packages #
############
Expand Down
4 changes: 1 addition & 3 deletions protocol.m
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,7 @@
% kcat values. If the file already exists, it will not be overwritten, to
% avoid losing existing kcat values (unless 'overwrite' was set as 'true'
% when running writeDLKcatInput.
% runDLKcat will attempt to download, install and run DLKcat, but this
% might not work for all systems. In that case, the user will be directed
% to manually download, install and DLKcat via the GECKO-provided DLKcat package
% runDLKcat will run DLKcat using a Docker image.

writeDLKcatInput(ecModel);
runDLKcat();
Expand Down
333 changes: 333 additions & 0 deletions src/dlkcat-gecko/DLKcat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,333 @@
#!/usr/bin/python
# coding: utf-8

# Author: LE YUAN
# This script is customized for use with GECKO 3

import os
import sys
import math
import pickle
import numpy as np
from rdkit import Chem
from collections import defaultdict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import mean_squared_error,r2_score


def load_pickle(file_name):
with open(file_name, 'rb') as f:
return pickle.load(f)

fingerprint_dict = load_pickle('input/fingerprint_dict.pickle')
atom_dict = load_pickle('input/atom_dict.pickle')
bond_dict = load_pickle('input/bond_dict.pickle')
edge_dict = load_pickle('input/edge_dict.pickle')
word_dict = load_pickle('input/sequence_dict.pickle')

def split_sequence(sequence, ngram):
sequence = '-' + sequence + '='
# print(sequence)
# words = [word_dict[sequence[i:i+ngram]] for i in range(len(sequence)-ngram+1)]

words = list()
for i in range(len(sequence)-ngram+1) :
try :
words.append(word_dict[sequence[i:i+ngram]])
except :
word_dict[sequence[i:i+ngram]] = 0
words.append(word_dict[sequence[i:i+ngram]])

return np.array(words)
# return word_dict

def create_atoms(mol):
"""Create a list of atom (e.g., hydrogen and oxygen) IDs
considering the aromaticity."""
# atom_dict = defaultdict(lambda: len(atom_dict))
atoms = [a.GetSymbol() for a in mol.GetAtoms()]
# print(atoms)
for a in mol.GetAromaticAtoms():
i = a.GetIdx()
atoms[i] = (atoms[i], 'aromatic')
atoms = [atom_dict[a] for a in atoms]
# atoms = list()
# for a in atoms :
# try:
# atoms.append(atom_dict[a])
# except :
# atom_dict[a] = 0
# atoms.append(atom_dict[a])

return np.array(atoms)

def create_ijbonddict(mol):
"""Create a dictionary, which each key is a node ID
and each value is the tuples of its neighboring node
and bond (e.g., single and double) IDs."""
# bond_dict = defaultdict(lambda: len(bond_dict))
i_jbond_dict = defaultdict(lambda: [])
for b in mol.GetBonds():
i, j = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
bond = bond_dict[str(b.GetBondType())]
i_jbond_dict[i].append((j, bond))
i_jbond_dict[j].append((i, bond))
return i_jbond_dict

def extract_fingerprints(atoms, i_jbond_dict, radius):
"""Extract the r-radius subgraphs (i.e., fingerprints)
from a molecular graph using Weisfeiler-Lehman algorithm."""

# fingerprint_dict = defaultdict(lambda: len(fingerprint_dict))
# edge_dict = defaultdict(lambda: len(edge_dict))

if (len(atoms) == 1) or (radius == 0):
fingerprints = [fingerprint_dict[a] for a in atoms]

else:
nodes = atoms
i_jedge_dict = i_jbond_dict

for _ in range(radius):

"""Update each node ID considering its neighboring nodes and edges
(i.e., r-radius subgraphs or fingerprints)."""
fingerprints = []
for i, j_edge in i_jedge_dict.items():
neighbors = [(nodes[j], edge) for j, edge in j_edge]
fingerprint = (nodes[i], tuple(sorted(neighbors)))
# fingerprints.append(fingerprint_dict[fingerprint])
# fingerprints.append(fingerprint_dict.get(fingerprint))
try :
fingerprints.append(fingerprint_dict[fingerprint])
except :
fingerprint_dict[fingerprint] = 0
fingerprints.append(fingerprint_dict[fingerprint])

nodes = fingerprints

"""Also update each edge ID considering two nodes
on its both sides."""
_i_jedge_dict = defaultdict(lambda: [])
for i, j_edge in i_jedge_dict.items():
for j, edge in j_edge:
both_side = tuple(sorted((nodes[i], nodes[j])))
# edge = edge_dict[(both_side, edge)]
# edge = edge_dict.get((both_side, edge))
try :
edge = edge_dict[(both_side, edge)]
except :
edge_dict[(both_side, edge)] = 0
edge = edge_dict[(both_side, edge)]

_i_jedge_dict[i].append((j, edge))
i_jedge_dict = _i_jedge_dict

return np.array(fingerprints)

def create_adjacency(mol):
adjacency = Chem.GetAdjacencyMatrix(mol)
return np.array(adjacency)

def dump_dictionary(dictionary, filename):
with open(filename, 'wb') as file:
pickle.dump(dict(dictionary), file)

def load_tensor(file_name, dtype):
return [dtype(d).to(device) for d in np.load(file_name + '.npy', allow_pickle=True)]

class Predictor(object):
def __init__(self, model):
self.model = model

def predict(self, data):
predicted_value = self.model.forward(data)

return predicted_value

class KcatPrediction(nn.Module):
def __init__(self, device, n_fingerprint, n_word, dim, layer_gnn, window, layer_cnn, layer_output):
super(KcatPrediction, self).__init__()
self.embed_fingerprint = nn.Embedding(n_fingerprint, dim)
self.embed_word = nn.Embedding(n_word, dim)
self.W_gnn = nn.ModuleList([nn.Linear(dim, dim)
for _ in range(layer_gnn)])
self.W_cnn = nn.ModuleList([nn.Conv2d(
in_channels=1, out_channels=1, kernel_size=2*window+1,
stride=1, padding=window) for _ in range(layer_cnn)])
self.W_attention = nn.Linear(dim, dim)
self.W_out = nn.ModuleList([nn.Linear(2*dim, 2*dim)
for _ in range(layer_output)])
# self.W_interaction = nn.Linear(2*dim, 2)
self.W_interaction = nn.Linear(2*dim, 1)

self.device = device
self.dim = dim
self.layer_gnn = layer_gnn
self.window = window
self.layer_cnn = layer_cnn
self.layer_output = layer_output

def gnn(self, xs, A, layer):
for i in range(layer):
hs = torch.relu(self.W_gnn[i](xs))
xs = xs + torch.matmul(A, hs)
# return torch.unsqueeze(torch.sum(xs, 0), 0)
return torch.unsqueeze(torch.mean(xs, 0), 0)

def attention_cnn(self, x, xs, layer):
"""The attention mechanism is applied to the last layer of CNN."""

xs = torch.unsqueeze(torch.unsqueeze(xs, 0), 0)
for i in range(layer):
xs = torch.relu(self.W_cnn[i](xs))
xs = torch.squeeze(torch.squeeze(xs, 0), 0)

h = torch.relu(self.W_attention(x))
hs = torch.relu(self.W_attention(xs))
weights = torch.tanh(F.linear(h, hs))
ys = torch.t(weights) * hs

# return torch.unsqueeze(torch.sum(ys, 0), 0)
return torch.unsqueeze(torch.mean(ys, 0), 0)

def forward(self, inputs):

fingerprints, adjacency, words = inputs

layer_gnn = 3
layer_cnn = 3
layer_output = 3

"""Compound vector with GNN."""
fingerprint_vectors = self.embed_fingerprint(fingerprints)
compound_vector = self.gnn(fingerprint_vectors, adjacency, layer_gnn)

"""Protein vector with attention-CNN."""
word_vectors = self.embed_word(words)
protein_vector = self.attention_cnn(compound_vector,
word_vectors, layer_cnn)

"""Concatenate the above two vectors and output the interaction."""
cat_vector = torch.cat((compound_vector, protein_vector), 1)
for j in range(layer_output):
cat_vector = torch.relu(self.W_out[j](cat_vector))
interaction = self.W_interaction(cat_vector)
# print(interaction)

return interaction

def __call__(self, data, train=True):

inputs, correct_interaction = data[:-1], data[-1]
predicted_interaction = self.forward(inputs)
print(predicted_interaction)

if train:
loss = F.mse_loss(predicted_interaction, correct_interaction)
return loss
else:
correct_values = correct_interaction.to('cpu').data.numpy()
predicted_values = predicted_interaction.to('cpu').data.numpy()[0]
# correct_values = np.concatenate(correct_values)
# predicted_values = np.concatenate(predicted_values)
# ys = F.softmax(predicted_interaction, 1).to('cpu').data.numpy()
# predicted_values = list(map(lambda x: np.argmax(x), ys))
print(correct_values)
print(predicted_values)
# predicted_scores = list(map(lambda x: x[1], ys))
return correct_values, predicted_values

def main() :
inputfile = sys.argv[1:][0]
outputfile = sys.argv[1:][1]
#print(inputfile)

if os.access(inputfile, os.R_OK):
with open(inputfile, 'r') as infile :
lines = infile.readlines()

fingerprint_dict = load_pickle('input/fingerprint_dict.pickle')
atom_dict = load_pickle('input/atom_dict.pickle')
bond_dict = load_pickle('input/bond_dict.pickle')
word_dict = load_pickle('input/sequence_dict.pickle')
n_fingerprint = len(fingerprint_dict)
n_word = len(word_dict)

radius=2
ngram=3

dim=10
layer_gnn=3
side=5
window=11
layer_cnn=3
layer_output=3
lr=1e-3
lr_decay=0.5
decay_interval=10
weight_decay=1e-6
iteration=100

if torch.cuda.is_available():
device = torch.device('cuda')
else:
device = torch.device('cpu')

# torch.manual_seed(1234)
Kcat_model = KcatPrediction(device, n_fingerprint, n_word, 2*dim, layer_gnn, window, layer_cnn, layer_output).to(device)
Kcat_model.load_state_dict(torch.load('input/all--radius2--ngram3--dim20--layer_gnn3--window11--layer_cnn3--layer_output3--lr1e-3--lr_decay0.5--decay_interval10--weight_decay1e-6--iteration50', map_location=device))
# print(state_dict.keys())
# model.eval()
predictor = Predictor(Kcat_model)

with open(outputfile, 'w') as outfile :

for line in lines[1:] :
line_item = list()
data = line.strip().split('\t')
rxn,gene,sub,smiles,sequence,Kcat_value = data

try :
if smiles != None and "." not in smiles and len(smiles) !=0 :

mol = Chem.AddHs(Chem.MolFromSmiles(smiles))
atoms = create_atoms(mol)
i_jbond_dict = create_ijbonddict(mol)

fingerprints = extract_fingerprints(atoms, i_jbond_dict, radius)

adjacency = create_adjacency(mol)

words = split_sequence(sequence,ngram)

fingerprints = torch.LongTensor(fingerprints).to(device)
adjacency = torch.FloatTensor(adjacency).to(device)
words = torch.LongTensor(words).to(device)

inputs = [fingerprints, adjacency, words]

prediction = predictor.predict(inputs)
Kcat_log_value = prediction.item()
Kcat_value = '%.4f' %math.pow(2,Kcat_log_value)
line_item = [rxn,gene,sub,smiles,sequence,Kcat_value]

outfile.write('\t'.join(line_item)+'\n')
else :
Kcat_value = 'None'
line_item = [rxn,gene,sub,smiles,sequence,Kcat_value]
outfile.write('\t'.join(line_item)+'\n')

except :
Kcat_value = 'None'
line_item = [rxn,gene,sub,smiles,sequence,Kcat_value]
outfile.write('\t'.join(line_item)+'\n')
else:
print('DLKcat cannot find the input file ' + inputfile)

if __name__ == '__main__' :
main()
8 changes: 8 additions & 0 deletions src/dlkcat-gecko/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
FROM python:3.9-slim

LABEL org.opencontainers.image.source=https://github.com/sysbiochalmers/gecko
LABEL version="0.1"
LABEL description="Docker image of SysBioChalmers/DKLcat adapted for SysBioChalmers/GECKO version 3"

COPY . .
RUN pip install --no-cache-dir -r requirements.txt torch@https://download.pytorch.org/whl/cpu/torch-1.9.1%2Bcpu-cp39-cp39-linux_x86_64.whl
2 changes: 2 additions & 0 deletions src/dlkcat-gecko/input/.gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*.npy filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
3 changes: 3 additions & 0 deletions src/dlkcat-gecko/input/adjacencies.npy
Git LFS file not shown
Binary file not shown.
3 changes: 3 additions & 0 deletions src/dlkcat-gecko/input/atom_dict.pickle
Git LFS file not shown
3 changes: 3 additions & 0 deletions src/dlkcat-gecko/input/bond_dict.pickle
Git LFS file not shown
3 changes: 3 additions & 0 deletions src/dlkcat-gecko/input/compounds.npy
Git LFS file not shown
3 changes: 3 additions & 0 deletions src/dlkcat-gecko/input/edge_dict.pickle
Git LFS file not shown
3 changes: 3 additions & 0 deletions src/dlkcat-gecko/input/fingerprint_dict.pickle
Git LFS file not shown
3 changes: 3 additions & 0 deletions src/dlkcat-gecko/input/proteins.npy
Git LFS file not shown
Loading

0 comments on commit 9acd37c

Please sign in to comment.