Skip to content

Commit

Permalink
Add --per_res_only option. When this flag is provided, it'll write pe…
Browse files Browse the repository at this point in the history
…r-residue accuracy estimation only
  • Loading branch information
Minkyung Baek committed Jan 9, 2021
1 parent c4257af commit 21c1775
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 21 deletions.
44 changes: 31 additions & 13 deletions DeepAccNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ def main():
action="store_true",
default=False,
help="Writing results to a csv file (Default: False)")

parser.add_argument("--per_res_only",
"-pr",
action="store_true",
default=False,
help="Store per-residue accuracy only (Default: False)")

parser.add_argument("--leaveTempFile",
"-lt",
Expand Down Expand Up @@ -242,15 +248,23 @@ def main():

if not args.csv:
if args.ensemble:
np.savez_compressed(join(args.output, s+"_"+modelname[:-4]+".npz"),
lddt = lddt.cpu().detach().numpy().astype(np.float16),
estogram = estogram.cpu().detach().numpy().astype(np.float16),
mask = mask.cpu().detach().numpy().astype(np.float16))
if args.per_res_only:
np.savez_compressed(join(args.output, s+"_"+modelname[:-4]+".npz"),
lddt = lddt.cpu().detach().numpy().astype(np.float16))
else:
np.savez_compressed(join(args.output, s+"_"+modelname[:-4]+".npz"),
lddt = lddt.cpu().detach().numpy().astype(np.float16),
estogram = estogram.cpu().detach().numpy().astype(np.float16),
mask = mask.cpu().detach().numpy().astype(np.float16))
else:
np.savez_compressed(join(args.output, s+".npz"),
lddt = lddt.cpu().detach().numpy().astype(np.float16),
estogram = estogram.cpu().detach().numpy().astype(np.float16),
mask = mask.cpu().detach().numpy().astype(np.float16))
if args.per_res_only:
np.savez_compressed(join(args.output, s+".npz"),
lddt = lddt.cpu().detach().numpy().astype(np.float16))
else:
np.savez_compressed(join(args.output, s+".npz"),
lddt = lddt.cpu().detach().numpy().astype(np.float16),
estogram = estogram.cpu().detach().numpy().astype(np.float16),
mask = mask.cpu().detach().numpy().astype(np.float16))
except:
print("Failed to predict for", join(args.output, s+"_"+modelname[:-4]+".npz"))

Expand Down Expand Up @@ -311,10 +325,14 @@ def main():
val = torch.Tensor(val).to(device)

estogram, mask, lddt, dmy = model(idx, val, f1d, f2d)
np.savez_compressed(outsamplename+".npz",
lddt = lddt.cpu().detach().numpy().astype(np.float16),
estogram = estogram.cpu().detach().numpy().astype(np.float16),
mask = mask.cpu().detach().numpy().astype(np.float16))
if args.per_res_only:
np.savez_compressed(outsamplename+".npz",
lddt = lddt.cpu().detach().numpy().astype(np.float16))
else:
np.savez_compressed(outsamplename+".npz",
lddt = lddt.cpu().detach().numpy().astype(np.float16),
estogram = estogram.cpu().detach().numpy().astype(np.float16),
mask = mask.cpu().detach().numpy().astype(np.float16))

if not args.leaveTempFile:
dan.clean([outsamplename],
Expand All @@ -326,4 +344,4 @@ def main():


if __name__== "__main__":
main()
main()
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ optional arguments:
-h, --help show this help message and exit
--pdb, -pdb Running on a single pdb file instead of a folder (Default: False)
--csv, -csv Writing results to a csv file (Default: False)
--per_res_only, -pr Writing per-residue accuracy only (Default: False)
--leaveTempFile, -lt Leaving temporary files (Default: False)
--process PROCESS, -p PROCESS
Specifying # of cpus to use for featurization (Default: 1)
Expand Down
23 changes: 15 additions & 8 deletions deepAccNet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def seqsep(psize, normalizer=100, axis=-1):
ret[i,j] = abs(i-j)*1.0/100-1.0
return np.expand_dims(ret, axis)

def merge(samples, outfolder, verbose=False):
def merge(samples, outfolder, per_res_only=False, verbose=False):
for j in range(len(samples)):
try:
if verbose: print("Merging", samples[j])
Expand All @@ -108,19 +108,26 @@ def merge(samples, outfolder, verbose=False):
for i in ["best", "second", "third", "fourth"]:
temp = np.load(join(outfolder, samples[j]+"_"+i+".npz"))
lddt.append(temp["lddt"])
if per_res_only:
continue
estogram.append(temp["estogram"])
mask.append(temp["mask"])

# Averaging
lddt = np.mean(lddt, axis=0)
estogram = np.mean(estogram, axis=0)
mask = np.mean(mask, axis=0)
if not per_res_only:
estogram = np.mean(estogram, axis=0)
mask = np.mean(mask, axis=0)

# Saving
np.savez_compressed(join(outfolder, samples[j]+".npz"),
lddt = lddt.astype(np.float16),
estogram = estogram.astype(np.float16),
mask = mask.astype(np.float16))
if per_res_only:
np.savez_compressed(join(outfolder, samples[j]+".npz"),
lddt = lddt.astype(np.float16))
else:
np.savez_compressed(join(outfolder, samples[j]+".npz"),
lddt = lddt.astype(np.float16),
estogram = estogram.astype(np.float16),
mask = mask.astype(np.float16))
except:
print("Failed to merge for", join(outfolder, samples[j]+".npz"))

Expand All @@ -140,4 +147,4 @@ def clean(samples, outfolder, ensemble=False, verbose=False):
if isfile(join(outfolder, samples[i]+"_"+j+".npz")):
os.remove(join(outfolder, samples[i]+"_"+j+".npz"))
except:
print("Failed to clean for", samples[i])
print("Failed to clean for", samples[i])

0 comments on commit 21c1775

Please sign in to comment.