#countfrombam.py (forked from TaliaferroLab/LABRAT)
#Count aligned reads (as bam) at pA sites using bedtools
#Only reads that are 300 nt or less upstream from a pA site will be counted for that pA site
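#Example usage (illustrative; the file names below are hypothetical):
#    python countfrombam.py --gff annotation.gff3 --bamdir ./bams --output psis.txt
#The bam files in --bamdir are assumed to be coordinate-sorted and indexed (.bai files present),
#since reads are fetched by region with pysam.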
import gffutils
import os
import sys
import pysam
import pickle
import pandas as pd
from functools import reduce
import argparse
#Does a gene pass filters?
def genefilters(gene, db):
    proteincoding = False

    if 'protein_coding' in gene.attributes['gene_type']:
        proteincoding = True

    if proteincoding == True:
        return True
    else:
        return False
#Does a transcript pass filters?
def transcriptfilters(transcript, db):
    exonnumberpass = False
    TFlengthpass = False
    proteincoding = False
    mrnaendpass = False

    #How many exons does it have?
    if len(list(db.children(transcript, featuretype = 'exon'))) >= 2:
        exonnumberpass = True
    else:
        return False

    #What is the length of the terminal fragment?
    exons = []
    if transcript.strand == '+':
        for exon in db.children(transcript, featuretype = 'exon', order_by = 'start'):
            exons.append([exon.start, exon.end + 1])
    elif transcript.strand == '-':
        for exon in db.children(transcript, featuretype = 'exon', order_by = 'start', reverse = True):
            exons.append([exon.start, exon.end + 1])
    penultimateexonlen = len(range(exons[-2][0], exons[-2][1]))
    lastexonlen = len(range(exons[-1][0], exons[-1][1]))
    TFlength = penultimateexonlen + lastexonlen
    if TFlength > 200:
        TFlengthpass = True

    #Is this transcript protein coding?
    if 'protein_coding' in transcript.attributes['transcript_type']:
        proteincoding = True

    #Are we confident in the 3' end of this transcript?
    if 'tag' not in transcript.attributes or 'mRNA_end_NF' not in transcript.attributes['tag']:
        mrnaendpass = True

    if exonnumberpass and TFlengthpass and proteincoding and mrnaendpass:
        return True
    else:
        return False
#Given an annotation in gff format, get the position factors for all transcripts.
#Any two transcript ends that are less than <lengthfilter> nt away from each other are merged into a single end.
#This is so that you don't end up with unique regions that are only a few nt long,
#which might cause issues when it comes to counting kmers or reads that map to a given region.
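#Worked example (illustrative toy coordinates, not taken from any real annotation):
#a + strand gene with distinct transcript end coords [5000, 5010, 6000] and lengthfilter = 25
#gets m values {5000: 0, 5010: 0, 6000: 1}, because |5010 - 5000| = 10 <= 25 merges the first
#two ends while |6000 - 5010| = 990 > 25 starts a new end. With n = 2 distinct m values, each
#transcript's position factor is m / (n - 1): 0.0 for the proximal end(s) and 1.0 for the distal end.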
def getpositionfactors(gff, lengthfilter):
    lengthfilter = int(lengthfilter)
    genecount = 0
    txends = {} #{ENSMUSG : [strand, [list of distinct transcript end coords]]}
    posfactors = {} #{ENSMUSG : {ENSMUST : positionfactor}}

    #Make gff database
    print('Indexing gff...')
    gff_fn = gff
    db_fn = os.path.abspath(gff_fn) + '.db'
    if os.path.isfile(db_fn) == False:
        gffutils.create_db(gff_fn, db_fn, merge_strategy = 'merge', verbose = True)

    db = gffutils.FeatureDB(db_fn)
    print('Done indexing!')

    #Get number of distinct transcript ends for each gene
    genes = db.features_of_type('gene')
    for gene in genes:
        #Only protein coding genes
        passgenefilters = genefilters(gene, db)
        if passgenefilters == False:
            continue
        genename = str(gene.id).replace('gene:', '')
        ends = []
        if gene.strand == '+':
            for transcript in db.children(gene, featuretype = 'transcript', level = 1, order_by = 'end'):
                #Skip transcripts that do not pass filters
                passtranscriptfilters = transcriptfilters(transcript, db)
                if passtranscriptfilters == False:
                    continue
                if transcript.end not in ends:
                    ends.append(transcript.end)
        elif gene.strand == '-':
            for transcript in db.children(gene, featuretype = 'transcript', level = 1, order_by = 'start', reverse = True):
                #Skip transcripts that do not pass filters
                passtranscriptfilters = transcriptfilters(transcript, db)
                if passtranscriptfilters == False:
                    continue
                if transcript.start not in ends:
                    ends.append(transcript.start)

        if ends: #Sometimes there are no 'transcripts' for a gene, like with pseudogenes, etc.
            txends[genename] = [gene.strand, ends]

    #Sort transcript end coords
    s_txends = {} #{ENSMUSG : [sorted (most upstream to most downstream) tx end coords]}
    for gene in txends:
        strand = txends[gene][0]
        coords = txends[gene][1]
        if strand == '+':
            sortedcoords = sorted(coords)
        elif strand == '-':
            sortedcoords = sorted(coords, reverse = True)
        s_txends[gene] = sortedcoords

    #Get m values (the numerator of the position factor fraction), combining an end that is less than
    #<lengthfilter> nt away from the previous end into the same m value as that end
    mvalues = {} #{ENSMUSG : {txendcoord : mvalue}}
    for gene in s_txends:
        mvalues[gene] = {}
        currentendcoord = s_txends[gene][0]
        currentmvalue = 0
        mvalues[gene][currentendcoord] = 0 #the first one has to have m = 0

        for endcoord in s_txends[gene][1:]:
            #If this endcoord is too close to the last one
            if abs(endcoord - currentendcoord) <= lengthfilter:
                #this end gets the current m value, and we stay on this m value for the next round
                mvalues[gene][endcoord] = currentmvalue
                #update currentendcoord
                currentendcoord = endcoord
            #If this endcoord is sufficiently far away from the last one
            elif abs(endcoord - currentendcoord) > lengthfilter:
                #this end coord gets the next m value
                mvalues[gene][endcoord] = currentmvalue + 1
                #we move on to the next m value for the next round
                currentmvalue = currentmvalue + 1
                #update currentendcoord
                currentendcoord = endcoord

    #Figure out position scaling factor for each transcript (position / (number of total positions - 1)) (m / (n - 1))
    genes = db.features_of_type('gene')
    for gene in genes:
        genecount += 1
        genename = str(gene.id).replace('gene:', '')

        #Only protein coding genes
        passgenefilters = genefilters(gene, db)
        if passgenefilters == False:
            continue
        #If this gene isn't in mvalues or there is only one m value for the entire gene, skip it
        if genename not in mvalues:
            continue
        if len(set(mvalues[genename].values())) == 1:
            continue

        #Get number of different m values for this gene
        n = len(set(mvalues[genename].values()))
        posfactors[genename] = {}
        for transcript in db.children(gene, featuretype = 'transcript', level = 1, order_by = 'end'):
            #Skip transcripts that do not pass filters
            passtranscriptfilters = transcriptfilters(transcript, db)
            if passtranscriptfilters == False:
                continue
            txname = str(transcript.id).replace('transcript:', '')
            if gene.strand == '+':
                m = mvalues[genename][transcript.end]
            elif gene.strand == '-':
                m = mvalues[genename][transcript.start]
            posfactor = m / float(n - 1)
            posfactors[genename][txname] = posfactor

    #Output file of the number of posfactors for each gene
    with open('numberofposfactors.txt', 'w') as outfh:
        outfh.write(('\t').join(['Gene', 'numberofposfactors', 'txids', 'interpolyAdist']) + '\n')
        for gene in posfactors: #{ENSMUSG : {ENSMUST : positionfactor}}
            pfs = []
            for tx in posfactors[gene]:
                pfs.append(posfactors[gene][tx])
            pfs = list(set(pfs))

            #Write distance between polyA sites for those genes that only have 2 pfs
            if len(pfs) == 2:
                g = db[gene]
                for tx in posfactors[gene]:
                    if posfactors[gene][tx] == 0:
                        txpf1 = tx
                    elif posfactors[gene][tx] == 1:
                        txpf2 = tx
                t1 = db[txpf1]
                t2 = db[txpf2]
                if g.strand == '+':
                    interpolyAdist = t2.end - t1.end
                elif g.strand == '-':
                    interpolyAdist = t1.start - t2.start
            elif len(pfs) != 2:
                interpolyAdist = 'NA'

            #Get list of txs that belong to each pf
            txids = {} #{positionfactor : [list of transcriptIDs]}
            for pf in sorted(pfs):
                txids[pf] = []
                for tx in posfactors[gene]:
                    if posfactors[gene][tx] == pf:
                        txids[float(pf)].append(tx)

            alltxids = []
            for pf in sorted(list(txids.keys())):
                alltxids.append((',').join(txids[pf]))
            outfh.write(('\t').join([gene, str(len(pfs)), ('_').join(alltxids), str(interpolyAdist)]) + '\n')

    return posfactors
def getpositionfactorintervals(gff, posfactors):
    #The mean insert size for Lexogen 3' end data is 200-300 nt. Looking at the Corley data
    #produced with this protocol in a genome browser, it looks like it is very often less than that.
    #So almost all of the reads associated with a particular polyA site should be within 300 nt of that site.
    #For each transcript that is assigned to a position factor, define an interval as the last 300 nt of that
    #transcript. Count how many reads lie in that window, and calculate psi values using those read counts.
    #posfactors = {} #{ENSMUSG : {ENSMUST : positionfactor}}
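    #Illustrative sketch of the window logic below (my reading of the code, not from the original comments):
    #for a transcript with more than 300 exonic positions, the window is the last 300 exonic positions in
    #5'->3' order, i.e. txexons[-300] through txexons[-1]; shorter transcripts use the whole transcript.
    #Because txexons is reversed for - strand transcripts, windowstart/windowstop are swapped when stored
    #so that the interval is always given in ascending genomic coordinates for fetching reads.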
    posfactors_intervals = {} #{ENST : [chrm, windowstart, windowstop, strand]}

    print('Indexing gff...')
    gff_fn = gff
    db_fn = os.path.abspath(gff_fn) + '.db'
    if os.path.isfile(db_fn) == False:
        gffutils.create_db(gff_fn, db_fn, merge_strategy = 'merge', verbose = True)

    db = gffutils.FeatureDB(db_fn)
    print('Done indexing!')

    genecounter = 0
    for gene in posfactors:
        genecounter += 1
        if genecounter % 1000 == 0:
            print('Gene {0} of {1}...'.format(genecounter, len(posfactors)))

        for txid in posfactors[gene]:
            txexons = [] #list of exonic coordinates for this transcript
            tx = db[txid]
            for exon in db.children(tx, featuretype = 'exon', order_by = 'start'):
                txexons += list(range(exon.start, exon.end + 1))
            if tx.strand == '-':
                txexons = list(reversed(txexons))

            if len(txexons) > 300:
                windowstart = txexons[-300]
            else:
                windowstart = txexons[0]
            windowstop = txexons[-1]

            if tx.strand == '+':
                posfactors_intervals[txid] = [str(tx.chrom), int(windowstart), int(windowstop), str(tx.strand)]
            elif tx.strand == '-':
                posfactors_intervals[txid] = [str(tx.chrom), int(windowstop), int(windowstart), str(tx.strand)]

    return posfactors_intervals
def calculatepsi(posfactors, posfactors_intervals, bamfile):
    #Take a bam of aligned reads and calculate psi by counting the number of reads
    #in the predefined 300 nt windows for each transcript
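    #Worked example (illustrative counts, not from any real sample): for a gene with two position
    #factors, 0.0 (proximal) and 1.0 (distal), 80 reads in the proximal window and 120 reads in the
    #distal window give psi = (80*0.0 + 120*1.0) / (80 + 120) = 0.6. Genes with fewer than 100 total
    #reads across their windows are reported as 'NA'.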
    txcounts = {} #{transcriptid : readcounts}
    genecounts = {} #{geneid : [transcript counts]} (unscaled)
    posfactorgenecounts = {} #{geneid : [transcript counts scaled by posfactor]} (scaled)
    psis = {} #{geneid : psi}

    bam = pysam.AlignmentFile(bamfile, 'rb')
    print('Calculating psi for {0}...'.format(os.path.basename(bamfile)))

    genecounter = 0
    for gene in posfactors:
        genecounter += 1
        if genecounter % 1000 == 0:
            print('Gene {0} of {1}...'.format(genecounter, len(posfactors)))

        genecounts[gene] = []
        posfactorgenecounts[gene] = []
        for tx in posfactors[gene]:
            posfactor = posfactors[gene][tx]
            window = posfactors_intervals[tx]
            windowchrm = window[0]
            windowstart = window[1]
            windowstop = window[2]
            strand = window[3]

            overlappingreads = bam.fetch(windowchrm, windowstart, windowstop)
            readsinwindow = []
            for read in overlappingreads:
                #Only keep reads on the same strand as the transcript
                if read.is_reverse:
                    readstrand = '-'
                elif not read.is_reverse:
                    readstrand = '+'
                if strand == readstrand:
                    readsinwindow.append(read)

            counts = len(readsinwindow)
            scaledcounts = counts * posfactor
            genecounts[gene].append(counts)
            posfactorgenecounts[gene].append(scaledcounts)

    #Calculate psi values
    for gene in posfactorgenecounts:
        totalcounts = sum(genecounts[gene])
        scaledcounts = sum(posfactorgenecounts[gene])
        if totalcounts < 100:
            #Not enough reads to get a reliable psi for this gene
            psis[gene] = 'NA'
        else:
            psi = scaledcounts / totalcounts
            psis[gene] = round(psi, 3)

    #turn into df
    df = pd.DataFrame.from_dict(psis, orient = 'index', columns = [os.path.basename(bamfile)])
    return df
def domanybams(bamdir, posfactors, posfactors_intervals):
    dfs = []
    for f in os.listdir(bamdir):
        fp = os.path.join(os.path.abspath(bamdir), f)
        if fp.endswith('.bam'):
            df = calculatepsi(posfactors, posfactors_intervals, fp)
            dfs.append(df)

    #merge dfs
    mergeddf = reduce(lambda df1, df2: pd.merge(df1, df2, left_index = True, right_index = True, how = 'inner'), dfs)
    return mergeddf
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gff', help = 'Genome annotation in gff format.')
    parser.add_argument('--bamdir', help = 'Directory of bam files to quantify.')
    parser.add_argument('--output', help = 'Output file for psi values.')
    args = parser.parse_args()

    posfactors = getpositionfactors(args.gff, 25)
    positionfactorintervals = getpositionfactorintervals(args.gff, posfactors)
    df = domanybams(args.bamdir, posfactors, positionfactorintervals)
    df.to_csv(args.output, sep = '\t', header = True, index = True, index_label = 'Gene')
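#Expected output (layout sketched from the code above, not prescribed by the original): a tab-delimited
#table with one row per gene that has at least two position factors, indexed by 'Gene', and one psi column
#per bam file, named by the bam's basename. The per-bam tables are merged with an inner join, so a gene
#appears only if it was quantified for every bam. A 'numberofposfactors.txt' file is also written to the
#working directory by getpositionfactors().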