From a7d5869f9a9e0863e4e29a6c8485b108c9a43435 Mon Sep 17 00:00:00 2001 From: lala8 Date: Mon, 21 Oct 2024 23:37:02 +0000 Subject: [PATCH 1/6] added attribution scoring of FIMO hits --- src/grelu/interpret/motifs.py | 129 ++++++++++++++++++++++++++++++---- tests/test_interpret.py | 9 +++ 2 files changed, 123 insertions(+), 15 deletions(-) diff --git a/src/grelu/interpret/motifs.py b/src/grelu/interpret/motifs.py index 182e95d..31c1d31 100644 --- a/src/grelu/interpret/motifs.py +++ b/src/grelu/interpret/motifs.py @@ -112,8 +112,9 @@ def scan_sequences( seq_ids: Optional[List[str]] = None, pthresh: float = 1e-3, rc: bool = True, - bin_size=0.1, - eps=0.0001, + bin_size: float = 0.1, + eps: float = 0.0001, + attrs: Optional[np.ndarray] = None, ): """ Scan a DNA sequence using motifs. Based on @@ -134,6 +135,9 @@ def scan_sequences( available to support it. Default is 0.1. eps: A small pseudocount to add to the motif PPMs before taking the log. Default is 0.0001. + attrs: An optional numpy array of shape (B, 4, L) containing attributions + for the input sequences. If provided, the results will include site + attribution and motif attribution scores for each FIMO hit. Returns: pd.DataFrame containing columns 'motif', 'sequence', 'start', 'end', @@ -152,10 +156,10 @@ def scan_sequences( motifs = {k: Tensor(v) for k, v in motifs.items()} # Scan each sequence in seqs - results = pd.DataFrame() - for seq, seq_id in zip(seqs, seq_ids): + sites = pd.DataFrame() + for i, (seq, seq_id) in enumerate(zip(seqs, seq_ids)): one_hot = strings_to_one_hot(seq, add_batch_axis=True) - curr_results = fimo( + curr_sites = fimo( motifs, sequences=one_hot, alphabet=["A", "C", "G", "T"], @@ -165,16 +169,18 @@ def scan_sequences( reverse_complement=rc, dim=1, ) - if len(curr_results) == 1: - curr_results = curr_results[0] - curr_results["sequence"] = seq_id - curr_results["matched_seq"] = curr_results.apply( + if len(curr_sites) == 1: + curr_sites = curr_sites[0] + curr_sites["seq_idx"] = i + curr_sites["sequence"] = seq_id + curr_sites["matched_seq"] = curr_sites.apply( lambda row: seq[row.start : row.end], axis=1 ) - curr_results = curr_results[ + curr_sites = curr_sites[ [ "motif_name", "sequence", + "seq_idx", "start", "end", "strand", @@ -183,13 +189,106 @@ def scan_sequences( "matched_seq", ] ] - results = pd.concat([results, curr_results]) + sites = pd.concat([sites, curr_sites]) # Concatenate results from all sequences - if len(results) > 0: - results = results.reset_index(drop=True) - results = results.rename(columns={"motif_name": "motif"}) - return results + if len(sites) > 0: + sites = sites.reset_index(drop=True) + sites = sites.rename(columns={"motif_name": "motif"}) + + # Add attribution scores + if attrs is not None: + sites = score_sites(sites, attrs, seqs) + sites = score_motifs(sites, attrs, motifs) + + return sites + + +def score_sites( + sites: pd.DataFrame, attrs: np.ndarray, seqs: Union[str, List[str]] +) -> pd.DataFrame: + """ + Given a dataframe of motif matching sites identified by FIMO and a set of attributions, this + function assigns each site a 'site attribution score' corresponding to the average attribution value + for all nucleotides within the site. This score gives the importance of the sequence region but does + not reflect the similarity between the PWM and the shape of the attributions. + + Args: + sites: A dataframe containing the output of scan_sequences + attrs: An optional numpy array of shape (B, 4, L) containing attributions + for the sequences. + seqs: A string or a list of DNA sequences as strings, which were the input to scan_sequences. + + Returns: + pd.DataFrame containing columns 'motif', 'sequence', 'start', 'end', + 'strand', 'score', 'pval', 'matched_seq', and 'site_attr_score'. + """ + df = sites.copy() + + # Format sequences + seqs = make_list(seqs) + + # Format attributions + if attrs.ndim == 2: + attrs = np.expand_dims(attrs, 0) + assert attrs.shape[0] == len(seqs) + + # Score sites for each sequence + df["site_attr_score"] = df.apply( + lambda row: attrs[row.seq_idx, :, row.start : row.end].mean(), axis=1 + ) + + return df + + +def score_motifs( + sites: pd.DataFrame, attrs: np.ndarray, motifs: Union[str, List[str]] +) -> pd.DataFrame: + """ + Given a dataframe of motif matching sites identified by FIMO and a set of attributions, this + function assigns each site a 'motif attribution score' which is the sum of the element-wise + product of the motif and the attributions. This score is higher when the shape of the motif + matches the shape of the attribution profile, and is particularly useful for ranking multiple + motifs that all match to the same sequence region. + + Args: + sites: A dataframe containing the output of scan_sequences + attrs: An optional numpy array of shape (B, 4, L) containing attributions + for the sequences. + motifs: A dictionary whose values are Position Probability Matrices + (PPMs) of shape (4, L), or the path to a MEME file. This should be the + same as the input passed to scan_sequences. + + Returns: + pd.DataFrame containing columns 'motif', 'sequence', 'start', 'end', + 'strand', 'score', 'pval', 'matched_seq', and 'motif_attr_score'. + """ + df = sites.copy() + + # Format attributions + if attrs.ndim == 2: + attrs = np.expand_dims(attrs, 0) + + # Format motifs + if isinstance(motifs, str): + motifs = read_meme_file(motifs) + motifs = { + motif_name: {"+": ppm, "-": np.flip(ppm, (0, 1))} + for motif_name, ppm in motifs.items() + } + + # Calculate attr x PWM product for each site + df["motif_attr_score"] = df.apply( + lambda row: np.multiply( + attrs[row.seq_idx, :, row.start : row.end], + motifs[row.motif][row.strand], + ) + .sum(0) + .mean(), + axis=1, + ) + + return df def marginalize_patterns( diff --git a/tests/test_interpret.py b/tests/test_interpret.py index 2eb4c7c..a84c715 100644 --- a/tests/test_interpret.py +++ b/tests/test_interpret.py @@ -133,6 +133,8 @@ def test_get_attention_scores(): def test_scan_sequences(): seqs = ["TCACGTGA", "CCTGCGTGA", "CACGCAGG"] + + # No reverse complement out = scan_sequences(seqs, motifs=meme_file, rc=False, pthresh=1e-3) assert out.motif.tolist() == ["MA0004.1 Arnt", "MA0006.1 Ahr::Arnt"] assert out.sequence.tolist() == ["0", "1"] @@ -141,6 +143,7 @@ def test_scan_sequences(): assert out.strand.tolist() == ["+", "+"] assert out.matched_seq.tolist() == ["CACGTG", "TGCGTG"] + # Allow reverse complement out = scan_sequences(seqs, motifs=meme_file, rc=True, pthresh=1e-3) assert out.motif.tolist() == [ "MA0004.1 Arnt", @@ -153,3 +156,9 @@ def test_scan_sequences(): assert out.end.tolist() == [7, 7, 8, 6] assert out.strand.tolist() == ["+", "-", "+", "-"] assert out.matched_seq.tolist() == ["CACGTG", "CACGTG", "TGCGTG", "CACGCA"] + + # Reverse complement with attributions + attrs = get_attributions(model, seqs, method="inputxgradient") + out = scan_sequences(seqs, motifs=meme_file, rc=True, pthresh=1e-3, attrs=attrs) + assert np.allclose(out.site_attr_score, [0.0, 0.0, -0.009259, 0.009259]) + assert np.allclose(out.motif_attr_score, [0.003704, 0.0, -0.035494, 0.0]) From 07504c98ce881c80c2b4378fdc3f4090411e8bc9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 21 Oct 2024 23:38:23 +0000 Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_models.py b/tests/test_models.py index cade4af..f66d32f 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,6 +1,6 @@ import torch - import wandb + from grelu.model.models import ( BorzoiModel, BorzoiPretrainedModel, From 85823e71b4824fa6709c7afbf2b7b822c69bb11d Mon Sep 17 00:00:00 2001 From: lala8 Date: Mon, 21 Oct 2024 23:47:34 +0000 Subject: [PATCH 3/6] fixed test --- tests/test_interpret.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/test_interpret.py b/tests/test_interpret.py index a84c715..c8c2a3f 100644 --- a/tests/test_interpret.py +++ b/tests/test_interpret.py @@ -132,7 +132,7 @@ def test_get_attention_scores(): def test_scan_sequences(): - seqs = ["TCACGTGA", "CCTGCGTGA", "CACGCAGG"] + seqs = ["TCACGTGAA", "CCTGCGTGA", "CACGCAGGA"] # No reverse complement out = scan_sequences(seqs, motifs=meme_file, rc=False, pthresh=1e-3) @@ -160,5 +160,7 @@ def test_scan_sequences(): # Reverse complement with attributions attrs = get_attributions(model, seqs, method="inputxgradient") out = scan_sequences(seqs, motifs=meme_file, rc=True, pthresh=1e-3, attrs=attrs) - assert np.allclose(out.site_attr_score, [0.0, 0.0, -0.009259, 0.009259]) - assert np.allclose(out.motif_attr_score, [0.003704, 0.0, -0.035494, 0.0]) + assert np.allclose(out.site_attr_score, [0.0, 0.0, -0.009259, 0.009259], rtol=0.001) + assert np.allclose( + out.motif_attr_score, [0.003704, 0.0, -0.035494, 0.0], rtol=0.001 + ) From cf55e24614c0afa1eff13c3b93ae1e2fdc20e2e1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 18 Nov 2024 22:40:25 +0000 Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_interpret.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_interpret.py b/tests/test_interpret.py index 1a5a34a..1cc7741 100644 --- a/tests/test_interpret.py +++ b/tests/test_interpret.py @@ -170,6 +170,7 @@ def test_scan_sequences(): out.motif_attr_score, [0.003704, 0.0, -0.035494, 0.0], rtol=0.001 ) + def test_run_tomtom(): motifs = { From 4e054a1a0a45e54d6053ea4ce9b196bdcc3c8ee2 Mon Sep 17 00:00:00 2001 From: avantikalal Date: Tue, 19 Nov 2024 00:19:09 +0000 Subject: [PATCH 5/6] fixed numpy -> torch issue --- src/grelu/interpret/motifs.py | 6 ++---- tests/test_models.py | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/grelu/interpret/motifs.py b/src/grelu/interpret/motifs.py index 0a0e3ba..04417f6 100644 --- a/src/grelu/interpret/motifs.py +++ b/src/grelu/interpret/motifs.py @@ -157,14 +157,12 @@ def scan_sequences( if isinstance(motifs, str): motifs = read_meme_file(motifs) - motifs = {k: Tensor(v) for k, v in motifs.items()} - # Scan each sequence in seqs sites = pd.DataFrame() for i, (seq, seq_id) in enumerate(zip(seqs, seq_ids)): one_hot = strings_to_one_hot(seq, add_batch_axis=True) curr_sites = fimo( - motifs, + motifs={k: Tensor(v) for k, v in motifs.items()}, sequences=one_hot, alphabet=["A", "C", "G", "T"], bin_size=bin_size, @@ -246,7 +244,7 @@ def score_sites( def score_motifs( - sites: pd.DataFrame, attrs: np.ndarray, motifs: Union[str, List[str]] + sites: pd.DataFrame, attrs: np.ndarray, motifs: Union[Dict[str, np.ndarray], str] ) -> pd.DataFrame: """ Given a dataframe of motif matching sites identified by FIMO and a set of attributions, this diff --git a/tests/test_models.py b/tests/test_models.py index f66d32f..cade4af 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,6 +1,6 @@ import torch -import wandb +import wandb from grelu.model.models import ( BorzoiModel, BorzoiPretrainedModel, From cc289363626d101b61511ace43d4cee34e648fbb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 19 Nov 2024 00:19:33 +0000 Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_models.py b/tests/test_models.py index cade4af..f66d32f 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,6 +1,6 @@ import torch - import wandb + from grelu.model.models import ( BorzoiModel, BorzoiPretrainedModel,