diff --git a/src/grelu/interpret/motifs.py b/src/grelu/interpret/motifs.py
index 7e232a9..04417f6 100644
--- a/src/grelu/interpret/motifs.py
+++ b/src/grelu/interpret/motifs.py
@@ -114,8 +114,9 @@ def scan_sequences(
     seq_ids: Optional[List[str]] = None,
     pthresh: float = 1e-3,
     rc: bool = True,
-    bin_size=0.1,
-    eps=0.0001,
+    bin_size: float = 0.1,
+    eps: float = 0.0001,
+    attrs: Optional[np.ndarray] = None,
 ):
     """
     Scan a DNA sequence using motifs. Based on
@@ -136,6 +137,9 @@ def scan_sequences(
             available to support it. Default is 0.1.
         eps: A small pseudocount to add to the motif PPMs before taking the log.
             Default is 0.0001.
+        attrs: An optional numpy array of shape (B, 4, L) containing attributions
+            for the input sequences. If provided, the results will include site
+            attribution and motif attribution scores for each FIMO hit.
 
     Returns:
         pd.DataFrame containing columns 'motif', 'sequence', 'start', 'end',
@@ -153,14 +157,16 @@ def scan_sequences(
     if isinstance(motifs, str):
         motifs = read_meme_file(motifs)
 
-    motifs = {k: Tensor(v) for k, v in motifs.items()}
+    # Convert the PPMs to tensors once, outside the loop. The arrays in `motifs`
+    # are left untouched because score_motifs needs them.
+    motif_tensors = {k: Tensor(v) for k, v in motifs.items()}
 
     # Scan each sequence in seqs
-    results = pd.DataFrame()
-    for seq, seq_id in zip(seqs, seq_ids):
+    sites = pd.DataFrame()
+    for i, (seq, seq_id) in enumerate(zip(seqs, seq_ids)):
         one_hot = strings_to_one_hot(seq, add_batch_axis=True)
-        curr_results = fimo(
-            motifs,
+        curr_sites = fimo(
+            motifs=motif_tensors,
             sequences=one_hot,
             alphabet=["A", "C", "G", "T"],
             bin_size=bin_size,
@@ -169,16 +175,18 @@ def scan_sequences(
             reverse_complement=rc,
             dim=1,
         )
-        if len(curr_results) == 1:
-            curr_results = curr_results[0]
-        curr_results["sequence"] = seq_id
-        curr_results["matched_seq"] = curr_results.apply(
+        if len(curr_sites) == 1:
+            curr_sites = curr_sites[0]
+        curr_sites["seq_idx"] = i
+        curr_sites["sequence"] = seq_id
+        curr_sites["matched_seq"] = curr_sites.apply(
             lambda row: seq[row.start : row.end], axis=1
         )
-        curr_results = curr_results[
+        curr_sites = curr_sites[
             [
                 "motif_name",
                 "sequence",
+                "seq_idx",
                 "start",
                 "end",
                 "strand",
@@ -187,13 +195,125 @@ def scan_sequences(
                 "matched_seq",
             ]
         ]
-        results = pd.concat([results, curr_results])
+        sites = pd.concat([sites, curr_sites])
 
     # Concatenate results from all sequences
-    if len(results) > 0:
-        results = results.reset_index(drop=True)
-        results = results.rename(columns={"motif_name": "motif"})
-    return results
+    if len(sites) > 0:
+        sites = sites.reset_index(drop=True)
+        sites = sites.rename(columns={"motif_name": "motif"})
+
+    # Add attribution scores
+    if attrs is not None:
+        sites = score_sites(sites, attrs, seqs)
+        sites = score_motifs(sites, attrs, motifs)
+
+    return sites
+
+
+def score_sites(
+    sites: pd.DataFrame, attrs: np.ndarray, seqs: Union[str, List[str]]
+) -> pd.DataFrame:
+    """
+    Given a dataframe of motif matching sites identified by FIMO and a set of attributions, this
+    function assigns each site a 'site attribution score' corresponding to the average attribution value
+    for all nucleotides within the site. This score gives the importance of the sequence region but does
+    not reflect the similarity between the PWM and the shape of the attributions.
+
+    Args:
+        sites: A dataframe containing the output of scan_sequences. Unless it is
+            empty, it must contain the columns 'seq_idx', 'start' and 'end'.
+        attrs: A numpy array of shape (B, 4, L) containing attributions
+            for the sequences. An array of shape (4, L) is treated as a single sequence.
+        seqs: A string or a list of DNA sequences as strings, which were the input to scan_sequences.
+
+    Returns:
+        A copy of `sites` with the additional column 'site_attr_score'.
+
+    Raises:
+        ValueError: If attrs and seqs contain different numbers of sequences.
+    """
+    df = sites.copy()
+
+    # Format sequences
+    seqs = make_list(seqs)
+
+    # Format attributions
+    if attrs.ndim == 2:
+        attrs = np.expand_dims(attrs, 0)
+    if attrs.shape[0] != len(seqs):
+        raise ValueError(
+            f"Expected attributions for {len(seqs)} sequences, got {attrs.shape[0]}."
+        )
+
+    # With no sites, DataFrame.apply(axis=1) returns a DataFrame that cannot be
+    # assigned to a single column, so add an empty column explicitly.
+    if df.empty:
+        df["site_attr_score"] = pd.Series(dtype=float)
+        return df
+
+    # Score sites for each sequence
+    df["site_attr_score"] = [
+        attrs[seq_idx, :, start:end].mean()
+        for seq_idx, start, end in zip(df["seq_idx"], df["start"], df["end"])
+    ]
+
+    return df
+
+
+def score_motifs(
+    sites: pd.DataFrame, attrs: np.ndarray, motifs: Union[Dict[str, np.ndarray], str]
+) -> pd.DataFrame:
+    """
+    Given a dataframe of motif matching sites identified by FIMO and a set of attributions, this
+    function assigns each site a 'motif attribution score' which is the sum of the element-wise
+    product of the motif and the attributions. This score is higher when the shape of the motif
+    matches the shape of the attribution profile, and is particularly useful for ranking multiple
+    motifs that all match to the same sequence region.
+
+    Args:
+        sites: A dataframe containing the output of scan_sequences. Unless it is empty,
+            it must contain the columns 'motif', 'seq_idx', 'start', 'end' and 'strand'.
+        attrs: A numpy array of shape (B, 4, L) containing attributions
+            for the sequences. An array of shape (4, L) is treated as a single sequence.
+        motifs: A dictionary whose values are Position Probability Matrices
+            (PPMs) of shape (4, L), or the path to a MEME file. This should be the
+            same as the input passed to scan_sequences.
+
+    Returns:
+        A copy of `sites` with the additional column 'motif_attr_score'.
+    """
+    df = sites.copy()
+
+    # With no sites there is nothing to score; add an empty column and return
+    # before reading any MEME file.
+    if df.empty:
+        df["motif_attr_score"] = pd.Series(dtype=float)
+        return df
+
+    # Format attributions
+    if attrs.ndim == 2:
+        attrs = np.expand_dims(attrs, 0)
+
+    # Format motifs
+    if isinstance(motifs, str):
+        motifs = read_meme_file(motifs)
+
+    # Flipping both axes of a PPM gives its reverse complement, given that the
+    # rows follow the A, C, G, T order used by scan_sequences.
+    ppms = {
+        motif_name: {"+": ppm, "-": np.flip(ppm, (0, 1))}
+        for motif_name, ppm in motifs.items()
+    }
+
+    # Calculate attr x PWM product for each site
+    df["motif_attr_score"] = [
+        np.multiply(attrs[seq_idx, :, start:end], ppms[motif][strand]).sum(0).mean()
+        for seq_idx, start, end, motif, strand in zip(
+            df["seq_idx"], df["start"], df["end"], df["motif"], df["strand"]
+        )
+    ]
+
+    return df
 
 
 def marginalize_patterns(
diff --git a/tests/test_interpret.py b/tests/test_interpret.py
index 5c5d797..1cc7741 100644
--- a/tests/test_interpret.py
+++ b/tests/test_interpret.py
@@ -137,7 +137,9 @@ def test_get_attention_scores():
 
 
 def test_scan_sequences():
-    seqs = ["TCACGTGA", "CCTGCGTGA", "CACGCAGG"]
+    seqs = ["TCACGTGAA", "CCTGCGTGA", "CACGCAGGA"]
+
+    # No reverse complement
     out = scan_sequences(seqs, motifs=meme_file, rc=False, pthresh=1e-3)
     assert out.motif.tolist() == ["MA0004.1 Arnt", "MA0006.1 Ahr::Arnt"]
     assert out.sequence.tolist() == ["0", "1"]
@@ -146,6 +148,7 @@ def test_scan_sequences():
     assert out.strand.tolist() == ["+", "+"]
     assert out.matched_seq.tolist() == ["CACGTG", "TGCGTG"]
 
+    # Allow reverse complement
     out = scan_sequences(seqs, motifs=meme_file, rc=True, pthresh=1e-3)
     assert out.motif.tolist() == [
         "MA0004.1 Arnt",
@@ -159,6 +162,14 @@ def test_scan_sequences():
     assert out.strand.tolist() == ["+", "-", "+", "-"]
     assert out.matched_seq.tolist() == ["CACGTG", "CACGTG", "TGCGTG", "CACGCA"]
 
+    # Reverse complement with attributions
+    attrs = get_attributions(model, seqs, method="inputxgradient")
+    out = scan_sequences(seqs, motifs=meme_file, rc=True, pthresh=1e-3, attrs=attrs)
+    assert np.allclose(out.site_attr_score, [0.0, 0.0, -0.009259, 0.009259], rtol=0.001)
+    assert np.allclose(
+        out.motif_attr_score, [0.003704, 0.0, -0.035494, 0.0], rtol=0.001
+    )
+
 
 def test_run_tomtom():