forked from wangjuan001/hicplus
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix straw issue 74 (aidenlab/straw#74) & pytorch device issue
- Loading branch information
1 parent
65d2639
commit dc61a4a
Showing
17 changed files
with
924 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
|
||
__version__='1.1.0' | ||
__license__='GPLv3+' | ||
Me = __file__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
import torch | ||
from torch.autograd import Variable | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import numpy as np | ||
from torch.utils import data | ||
import gzip | ||
import sys | ||
import torch.optim as optim | ||
conv2d1_filters_numbers = 8 | ||
conv2d1_filters_size = 9 | ||
conv2d2_filters_numbers = 8 | ||
conv2d2_filters_size = 1 | ||
conv2d3_filters_numbers = 1 | ||
conv2d3_filters_size = 5 | ||
|
||
class Net(nn.Module): | ||
def __init__(self, D_in, D_out): | ||
super(Net, self).__init__() | ||
# 1 input image channel, 6 output channels, 5x5 square convolution | ||
# kernel | ||
self.conv1 = nn.Conv2d(1, conv2d1_filters_numbers, conv2d1_filters_size) | ||
self.conv2 = nn.Conv2d(conv2d1_filters_numbers, conv2d2_filters_numbers, conv2d2_filters_size) | ||
self.conv3 = nn.Conv2d(conv2d2_filters_numbers, 1, conv2d3_filters_size) | ||
|
||
def forward(self, x): | ||
#print("start forwardingf") | ||
x = self.conv1(x) | ||
x = F.relu(x) | ||
x = self.conv2(x) | ||
x = F.relu(x) | ||
x = self.conv3(x) | ||
x = F.relu(x) | ||
return x | ||
''' | ||
def num_flat_features(self, x): | ||
size = x.size()[1:] # all dimensions except the batch dimension | ||
num_features = 1 | ||
for s in size: | ||
num_features *= s | ||
return num_features | ||
''' | ||
''' | ||
net = Net(40, 24) | ||
#sys.exit() | ||
#low_resolution_samples = low_resolution_samples.reshape((low_resolution_samples.shape[0], 40, 40)) | ||
#print low_resolution_samples[0:1, :,: ,: ].shape | ||
#low_resolution_samples = torch.from_numpy(low_resolution_samples[0:1, :,: ,: ]) | ||
#X = Variable(low_resolution_samples) | ||
#print X | ||
#Y = Variable(torch.from_numpy(Y[0])) | ||
#X = Variable(torch.randn(1, 1, 40, 40)) | ||
#print X | ||
optimizer = optim.SGD(net.parameters(), lr=0.0001, momentum=0.9) | ||
criterion = nn.MSELoss() | ||
for epoch in range(2): # loop over the dataset multiple times | ||
print "epoch", epoch | ||
running_loss = 0.0 | ||
for i, data in enumerate(train_loader, 0): | ||
# get the inputs | ||
inputs, labels = data | ||
#print(inputs.size()) | ||
#print(labels.size()) | ||
#print type(inputs) | ||
# wrap them in Variable | ||
inputs, labels = Variable(inputs), Variable(labels) | ||
# zero the parameter gradients | ||
optimizer.zero_grad() | ||
# forward + backward + optimize | ||
outputs = net(inputs) | ||
#print outputs | ||
loss = criterion(outputs, labels) | ||
loss.backward() | ||
optimizer.step() | ||
print i | ||
# print statistics | ||
#print type(loss) | ||
#print loss | ||
#print loss.data[0] | ||
#print loss.data | ||
#print type(data), len(data) | ||
#print "the key is ", type(data[0]) | ||
print('Finished Training') | ||
output = net(X) | ||
print(output) | ||
print type(output) | ||
loss = criterion(output, Y) | ||
net.zero_grad() # zeroes the gradient buffers of all parameters | ||
print('conv1.bias.grad before backward') | ||
print(net.conv1.bias.grad) | ||
loss.backward() | ||
print('conv1.bias.grad after backward') | ||
print(net.conv1.weight.grad) | ||
''' | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
import os,sys | ||
from torch.utils import data | ||
from hicplus import model | ||
import torch | ||
import torch.nn as nn | ||
import torch.optim as optim | ||
from torch.autograd import Variable | ||
import straw | ||
from scipy.sparse import csr_matrix, coo_matrix, vstack, hstack | ||
from scipy import sparse | ||
import numpy as np | ||
from hicplus import utils | ||
from time import gmtime, strftime | ||
from datetime import datetime | ||
import argparse | ||
|
||
startTime = datetime.now() | ||
|
||
use_gpu = 0 #opt.cuda | ||
#if use_gpu and not torch.cuda.is_available(): | ||
# raise Exception("No GPU found, please run without --cuda") | ||
|
||
def predict(M,N,inmodel): | ||
|
||
prediction_1 = np.zeros((N, N)) | ||
|
||
for low_resolution_samples, index in utils.divide(M): | ||
|
||
#print(index.shape) | ||
|
||
batch_size = low_resolution_samples.shape[0] #256 | ||
|
||
lowres_set = data.TensorDataset(torch.from_numpy(low_resolution_samples), torch.from_numpy(np.zeros(low_resolution_samples.shape[0]))) | ||
try: | ||
lowres_loader = torch.utils.data.DataLoader(lowres_set, batch_size=batch_size, shuffle=False) | ||
except: | ||
continue | ||
|
||
hires_loader = lowres_loader | ||
|
||
m = model.Net(40, 28) | ||
m.load_state_dict(torch.load(inmodel, map_location=torch.device('cpu'))) | ||
|
||
if torch.cuda.is_available(): | ||
m = m.cuda() | ||
|
||
for i, v1 in enumerate(lowres_loader): | ||
_lowRes, _ = v1 | ||
_lowRes = Variable(_lowRes).float() | ||
if use_gpu: | ||
_lowRes = _lowRes.cuda() | ||
y_prediction = m(_lowRes) | ||
|
||
|
||
y_predict = y_prediction.data.cpu().numpy() | ||
|
||
|
||
# recombine samples | ||
length = int(y_predict.shape[2]) | ||
y_predict = np.reshape(y_predict, (y_predict.shape[0], length, length)) | ||
|
||
|
||
for i in range(0, y_predict.shape[0]): | ||
|
||
x = int(index[i][1]) | ||
y = int(index[i][2]) | ||
#print np.count_nonzero(y_predict[i]) | ||
prediction_1[x+6:x+34, y+6:y+34] = y_predict[i] | ||
|
||
return(prediction_1) | ||
|
||
def chr_pred(hicfile, chrN1, chrN2, binsize, inmodel): | ||
M = utils.matrix_extract(chrN1, chrN2, binsize, hicfile) | ||
#print(M.shape) | ||
N = M.shape[0] | ||
|
||
chr_Mat = predict(M, N, inmodel) | ||
|
||
|
||
# if Ncol > Nrow: | ||
# chr_Mat = chr_Mat[:Ncol, :Nrow] | ||
# chr_Mat = chr_Mat.T | ||
# if Nrow > Ncol: | ||
# chr_Mat = chr_Mat[:Nrow, :Ncol] | ||
# print(dat.head()) | ||
return(chr_Mat) | ||
|
||
|
||
|
||
def writeBed(Mat, outname,binsize, chrN1,chrN2): | ||
with open(outname,'w') as chrom: | ||
r, c = Mat.nonzero() | ||
for i in range(r.size): | ||
contact = int(round(Mat[r[i],c[i]])) | ||
if contact == 0: | ||
continue | ||
#if r[i]*binsize > Len1 or (r[i]+1)*binsize > Len1: | ||
# continue | ||
#if c[i]*binsize > Len2 or (c[i]+1)*binsize > Len2: | ||
# continue | ||
line = [chrN1, r[i]*binsize, (r[i]+1)*binsize, | ||
chrN2, c[i]*binsize, (c[i]+1)*binsize, contact] | ||
chrom.write('chr'+str(line[0])+':'+str(line[1])+'-'+str(line[2])+ | ||
'\t'+'chr'+str(line[3])+':'+str(line[4])+'-'+str(line[5])+'\t'+str(line[6])+'\n') | ||
|
||
def main(args): | ||
chrN1, chrN2 = args.chrN | ||
binsize = args.binsize | ||
inmodel = args.model | ||
hicfile = args.inputfile | ||
#name = os.path.basename(inmodel).split('.')[0] | ||
#outname = 'chr'+str(chrN1)+'_'+name+'_'+str(binsize//1000)+'pred.txt' | ||
outname = args.outputfile | ||
Mat = chr_pred(hicfile,chrN1,chrN2,binsize,inmodel) | ||
print(Mat.shape) | ||
writeBed(Mat, outname, binsize,chrN1, chrN2) | ||
#print(enhM.shape) | ||
if __name__ == '__main__': | ||
main() | ||
|
||
print(datetime.now() - startTime) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
#!/usr/bin/env python | ||
import os,sys | ||
from torch.utils import data | ||
from hicplus import model | ||
import torch | ||
import torch.nn as nn | ||
import torch.optim as optim | ||
from torch.autograd import Variable | ||
import straw | ||
from scipy.sparse import csr_matrix, coo_matrix, vstack, hstack | ||
from scipy import sparse | ||
import numpy as np | ||
from hicplus import utils | ||
from time import gmtime, strftime | ||
from datetime import datetime | ||
import argparse | ||
from hicplus import pred_chromosome | ||
|
||
startTime = datetime.now() | ||
|
||
def pred_genome(hicfile, binsize, inmodel): | ||
hic_info = utils.read_hic_header(hicfile) | ||
chromindex = {} | ||
i = 0 | ||
for c, Len in hic_info['chromsizes'].items(): | ||
chromindex[c] = i | ||
i += 1 | ||
print(hic_info) | ||
|
||
name = os.path.basename(inmodel).split('.')[0] | ||
with open('genome.{}_{}.matrix.txt'.format(int(binsize/1000),name), 'w') as genome: | ||
for c1, Len1 in hic_info['chromsizes'].items(): | ||
for c2, Len2 in hic_info['chromsizes'].items(): | ||
if chromindex[c1] > chromindex[c2]: | ||
continue | ||
if c1 == 'M' or c2 == 'M': | ||
continue | ||
try: | ||
Mat = pred_chromosome.chr_pred(hicfile, c1, c2, binsize, inmodel) | ||
r, c = Mat.nonzero() | ||
for i in range(r.size): | ||
contact = int(round(Mat[r[i],c[i]])) | ||
if contact == 0: | ||
continue | ||
if r[i]*binsize > Len1 or (r[i]+1)*binsize > Len1: | ||
continue | ||
if c[i]*binsize > Len2 or (c[i]+1)*binsize > Len2: | ||
continue | ||
line = [c1, r[i]*binsize, (r[i]+1)*binsize, | ||
c2, c[i]*binsize, (c[i]+1)*binsize, contact] | ||
genome.write('chr'+str(line[0])+':'+str(line[1])+'-'+str(line[2])+ | ||
'\t'+'chr'+str(line[3])+':'+str(line[4])+'-'+str(line[5])+'\t'+str(line[6])+'\n') | ||
except: | ||
pass | ||
|
||
|
||
|
||
|
||
def main(args): | ||
binsize = args.binsize | ||
inmodel = args.model | ||
hicfile = args.inputfile | ||
pred_genome(hicfile, binsize, inmodel) | ||
|
||
if __name__ == '__main__': | ||
main() | ||
|
Oops, something went wrong.