Inconsistency of predictions depending on batch size #22

Open
neverov-am opened this issue May 22, 2024 · 0 comments

neverov-am commented May 22, 2024

Dear Authors,

Thank you for the great tool.

I want to implement an option to predict scores with a batch size larger than 1. During my first tests I noticed that the predictions differ depending on the batch size. Could you check what might be the reason for this behaviour of the model? Below, I provide an example variant (chr12-110435045-G-A) whose score differs between the single-variant prediction and a batch of size 4: 0.5400000214576721 in the original version versus 0.5299999713897705 on the batch. I also provide my code to reproduce the issue. To keep the question compact, the example shows the prediction mismatch for just one of the models.

import torch
import numpy as np
import pyfastx
from pkg_resources import resource_filename
from pangolin.model import *

###############################################################################################
test_variants = [
    'chr12-110435044-T-C',
    'chr12-110435044-T-G',
    'chr12-110435045-G-A',
    'chr12-110435045-G-C',
]

atol = 0.000001 # tolerance value to be used in np.allclose()
d = 50
reference_fasta_path = 'GRCh38.primary_assembly.genome.fa'
###############################################################################################
# the same as in the original version

IN_MAP = np.asarray([[0, 0, 0, 0],
                     [1, 0, 0, 0],
                     [0, 1, 0, 0],
                     [0, 0, 1, 0],
                     [0, 0, 0, 1]])


def one_hot_encode(seq, strand):
    seq = seq.upper().replace('A', '1').replace('C', '2')
    seq = seq.replace('G', '3').replace('T', '4').replace('N', '0')
    if strand == '+':
        seq = np.asarray(list(map(int, list(seq))))
    elif strand == '-':
        seq = np.asarray(list(map(int, list(seq[::-1]))))
        seq = (5 - seq) % 5  # Reverse complement
    return IN_MAP[seq.astype('int8')]
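
# Illustrative sanity check (my addition, not in the Pangolin source): the '-' strand
# branch reverses the sequence and maps bases via (5 - x) % 5, i.e. takes the reverse
# complement, so encoding 'AACG' on '-' must equal encoding 'CGTT' on '+'.
assert (one_hot_encode('AACG', '-') == one_hot_encode('CGTT', '+')).all()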

models = []
for i in [0,2,4,6]:
    for j in range(1,4):
        model = Pangolin(L, W, AR)
        # resolve the bundled weights from the installed pangolin package
        # (the original code calls resource_filename(__name__, ...) from inside the package)
        if torch.cuda.is_available():
            model.cuda()
            weights = torch.load(resource_filename("pangolin", "models/final.%s.%s.3.v2" % (j, i)))
        else:
            weights = torch.load(resource_filename("pangolin", "models/final.%s.%s.3.v2" % (j, i)), map_location=torch.device('cpu'))
        model.load_state_dict(weights)
        model.eval()
        models.append(model)
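# note (my addition): the loop loads 12 models in total, i in [0, 2, 4, 6] times
# j in [1, 2, 3]; models[0], used further below, is final.1.0.3.v2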

###############################################################################################
# process variants

def prepare_variant_for_batch(lnum, chr, pos, ref, alt, fasta, d):
    
    seq = fasta[chr][pos-5001-d:pos+len(ref)+4999+d].seq
    
    ref_seq = seq
    alt_seq = seq[:5000+d] + alt + seq[5000+d+len(ref):]
    
    return ref_seq, alt_seq

fasta = pyfastx.Fasta(reference_fasta_path)

batch_chroms = []
batch_positions = []
batch_refs = []
batch_alts = []

for test_variant in test_variants:
    chr, pos, ref, alt = test_variant.split('-')
    pos = int(pos)
    
    ref_seq, alt_seq = prepare_variant_for_batch(0, chr, pos, ref, alt, fasta, d)
    
    batch_chroms.append(chr)
    batch_positions.append(pos)
    batch_refs.append(ref_seq)
    batch_alts.append(alt_seq)

model = models[0]

strand = '-'

# predict on batch

encoded_refs = [] # store encoded reference sequences in a list
encoded_alts = [] # store encoded alternative sequences in a list
    
for i in range(len(batch_refs)):
    ref_seq = torch.from_numpy(one_hot_encode(batch_refs[i], strand).T).float()
    alt_seq = torch.from_numpy(one_hot_encode(batch_alts[i], strand).T).float()
    encoded_refs.append(ref_seq)
    encoded_alts.append(alt_seq)
        
batch_ref = torch.stack(encoded_refs) # create a tensor with multiple ref sequences
batch_alt = torch.stack(encoded_alts) # create a tensor with multiple alt sequences
    
if torch.cuda.is_available():
    batch_ref = batch_ref.to(torch.device("cuda"))
    batch_alt = batch_alt.to(torch.device("cuda"))

with torch.no_grad():
    # the original single-variant code indexes [0][[1,4,7,10][j],:]; for a batch we
    # keep the batch dimension and select the same score track for every sample
    # (note: j is 3 here, left over from the model-loading loop above)
    pred_ref = model(batch_ref)[:,[1,4,7,10][j],:].cpu().numpy()
    pred_alt = model(batch_alt)[:,[1,4,7,10][j],:].cpu().numpy()

# predict single

i = 2  # index of chr12-110435045-G-A in test_variants, the variant mentioned above

ref_seq = one_hot_encode(batch_refs[i], strand).T
ref_seq = torch.from_numpy(np.expand_dims(ref_seq, axis=0)).float()
alt_seq = one_hot_encode(batch_alts[i], strand).T
alt_seq = torch.from_numpy(np.expand_dims(alt_seq, axis=0)).float()

if torch.cuda.is_available():
    ref_seq = ref_seq.to(torch.device("cuda"))
    alt_seq = alt_seq.to(torch.device("cuda"))

with torch.no_grad():
    pred_ref_single = model(ref_seq)[0][[1,4,7,10][j],:].cpu().numpy()
    pred_alt_single = model(alt_seq)[0][[1,4,7,10][j],:].cpu().numpy()

# compare

print(np.allclose(pred_ref_single, pred_ref[i], atol=atol)) # Switches from True to False between atol=0.00001 and atol=0.000001
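
As a minimal follow-up sketch (an addition for illustration, not part of the original report): since the allclose check above already shows agreement down to roughly 1e-5, the mismatch may come from floating-point accumulation, because cuDNN/BLAS can select different kernels for different batch shapes. The snippet below reuses model, batch_ref, ref_seq, i and j from the script above, turns off autotuned/non-deterministic cuDNN kernel selection, and prints the largest absolute difference directly.

# hypothetical check, reusing model / batch_ref / ref_seq / i / j from the script above
torch.backends.cudnn.benchmark = False     # avoid autotuned, shape-dependent kernel choices
torch.backends.cudnn.deterministic = True  # prefer deterministic convolution algorithms

with torch.no_grad():
    pred_batch = model(batch_ref)[:,[1,4,7,10][j],:].cpu().numpy()
    pred_single = model(ref_seq)[0][[1,4,7,10][j],:].cpu().numpy()

# a maximum difference on the order of 1e-6 to 1e-5 would point at floating-point
# accumulation rather than a logic error in the batching code
print(np.abs(pred_single - pred_batch[i]).max())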