WIP: Score FIMO hits with attributions #73

Open
wants to merge 9 commits into base: main
133 changes: 115 additions & 18 deletions src/grelu/interpret/motifs.py
@@ -114,8 +114,9 @@ def scan_sequences(
seq_ids: Optional[List[str]] = None,
pthresh: float = 1e-3,
rc: bool = True,
bin_size=0.1,
eps=0.0001,
bin_size: float = 0.1,
eps: float = 0.0001,
attrs: Optional[np.ndarray] = None,
):
"""
Scan a DNA sequence using motifs. Based on
@@ -136,6 +137,9 @@ def scan_sequences(
available to support it. Default is 0.1.
eps: A small pseudocount to add to the motif PPMs before taking the log.
Default is 0.0001.
attrs: An optional numpy array of shape (B, 4, L) containing attributions
for the input sequences. If provided, the results will include site
attribution and motif attribution scores for each FIMO hit.

Returns:
pd.DataFrame containing columns 'motif', 'sequence', 'start', 'end',
@@ -153,14 +157,12 @@
if isinstance(motifs, str):
motifs = read_meme_file(motifs)

motifs = {k: Tensor(v) for k, v in motifs.items()}

# Scan each sequence in seqs
results = pd.DataFrame()
for seq, seq_id in zip(seqs, seq_ids):
sites = pd.DataFrame()
for i, (seq, seq_id) in enumerate(zip(seqs, seq_ids)):
one_hot = strings_to_one_hot(seq, add_batch_axis=True)
curr_results = fimo(
motifs,
curr_sites = fimo(
motifs={k: Tensor(v) for k, v in motifs.items()},
sequences=one_hot,
alphabet=["A", "C", "G", "T"],
bin_size=bin_size,
@@ -169,16 +171,18 @@
reverse_complement=rc,
dim=1,
)
if len(curr_results) == 1:
curr_results = curr_results[0]
curr_results["sequence"] = seq_id
curr_results["matched_seq"] = curr_results.apply(
if len(curr_sites) == 1:
curr_sites = curr_sites[0]
curr_sites["seq_idx"] = i
curr_sites["sequence"] = seq_id
curr_sites["matched_seq"] = curr_sites.apply(
lambda row: seq[row.start : row.end], axis=1
)
curr_results = curr_results[
curr_sites = curr_sites[
[
"motif_name",
"sequence",
"seq_idx",
"start",
"end",
"strand",
@@ -187,13 +191,106 @@
"matched_seq",
]
]
results = pd.concat([results, curr_results])
sites = pd.concat([sites, curr_sites])

# Concatenate results from all sequences
if len(results) > 0:
results = results.reset_index(drop=True)
results = results.rename(columns={"motif_name": "motif"})
return results
if len(sites) > 0:
sites = sites.reset_index(drop=True)
sites = sites.rename(columns={"motif_name": "motif"})

# Add attribution scores
if attrs is not None:
sites = score_sites(sites, attrs, seqs)
sites = score_motifs(sites, attrs, motifs)

return sites
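
For context, a minimal usage sketch of the new attrs argument, mirroring the updated test further down; model and meme_file are placeholders, and the import path for get_attributions is an assumption:

# Sketch only: `model` is a hypothetical trained model, `meme_file` a MEME-format
# motif file, and the import location of get_attributions is assumed.
from grelu.interpret.motifs import scan_sequences
from grelu.interpret.score import get_attributions

seqs = ["TCACGTGAA", "CCTGCGTGA", "CACGCAGGA"]
attrs = get_attributions(model, seqs, method="inputxgradient")  # shape (B, 4, L)
hits = scan_sequences(seqs, motifs=meme_file, rc=True, pthresh=1e-3, attrs=attrs)
# hits now includes 'site_attr_score' and 'motif_attr_score' for each FIMO hit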


def score_sites(
sites: pd.DataFrame, attrs: np.ndarray, seqs: Union[str, List[str]]
) -> pd.DataFrame:
"""
Given a dataframe of motif-matching sites identified by FIMO and a set of attributions, this
function assigns each site a 'site attribution score': the mean attribution value over all
bases and positions within the site. This score reflects the importance of the sequence region
but does not reflect the similarity between the PWM and the shape of the attributions.

Args:
sites: A dataframe containing the output of scan_sequences
attrs: A numpy array of shape (B, 4, L) containing attributions
for the sequences.
seqs: A string or a list of DNA sequences as strings, which were the input to scan_sequences.

Returns:
pd.DataFrame containing columns 'motif', 'sequence', 'seq_idx', 'start', 'end',
'strand', 'score', 'pval', 'matched_seq', and 'site_attr_score'.
"""
df = sites.copy()

# Format sequences
seqs = make_list(seqs)

# Format attributions
if attrs.ndim == 2:
attrs = np.expand_dims(attrs, 0)
assert attrs.shape[0] == len(seqs)

# Score sites for each sequence
df["site_attr_score"] = df.apply(
lambda row: attrs[row.seq_idx, :, row.start : row.end].mean(), axis=1
)

return df
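
To make the scoring concrete, a minimal numpy sketch of what 'site_attr_score' computes for a single hypothetical hit (the attribution values are invented):

import numpy as np

# One sequence with a (4, 10) attribution matrix; values are made up.
attrs = np.zeros((1, 4, 10))
attrs[0, 1, 2:8] = 0.5  # signal on the C channel at positions 2-7

# A hypothetical hit on sequence 0 spanning positions [2, 8).
seq_idx, start, end = 0, 2, 8
site_attr_score = attrs[seq_idx, :, start:end].mean()  # mean over a (4, 6) slice
print(site_attr_score)  # 0.5 * 6 / 24 = 0.125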


def score_motifs(
sites: pd.DataFrame, attrs: np.ndarray, motifs: Union[Dict[str, np.ndarray], str]
) -> pd.DataFrame:
"""
Given a dataframe of motif-matching sites identified by FIMO and a set of attributions, this
function assigns each site a 'motif attribution score': the element-wise product of the motif
PPM and the attributions, summed over bases and averaged over positions. This score is higher
when the shape of the motif matches the shape of the attribution profile, and is particularly
useful for ranking multiple motifs that match the same sequence region.

Args:
sites: A dataframe containing the output of scan_sequences
attrs: A numpy array of shape (B, 4, L) containing attributions
for the sequences.
motifs: A dictionary whose values are Position Probability Matrices
(PPMs) of shape (4, L), or the path to a MEME file. This should be the
same as the input passed to scan_sequences.

Returns:
pd.DataFrame containing columns 'motif', 'sequence', 'seq_idx', 'start', 'end',
'strand', 'score', 'pval', 'matched_seq', and 'motif_attr_score'.
"""
df = sites.copy()

# Format attributions
if attrs.ndim == 2:
attrs = np.expand_dims(attrs, 0)

# Format motifs
if isinstance(motifs, str):
motifs = read_meme_file(motifs)
motifs = {
motif_name: {"+": ppm, "-": np.flip(ppm, (0, 1))}
for motif_name, ppm in motifs.items()
}

# Calculate attr x PWM product for each site
df["motif_attr_score"] = df.apply(
lambda row: np.multiply(
attrs[row.seq_idx, :, row.start : row.end],
motifs[row.motif][row.strand],
)
.sum(0)
.mean(),
axis=1,
)

return df
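
Likewise, a minimal sketch of the 'motif_attr_score' computation for a single hit, including the reverse-strand case (the toy PPM and attribution values are invented):

import numpy as np

# Toy 4 x 3 PPM (rows A, C, G, T); each column sums to 1.
ppm = np.array([
    [0.7, 0.1, 0.1],
    [0.1, 0.7, 0.1],
    [0.1, 0.1, 0.7],
    [0.1, 0.1, 0.1],
])
attr_slice = np.random.default_rng(0).normal(size=(4, 3))  # attrs[seq_idx, :, start:end]

# Forward strand: element-wise product, summed over bases, averaged over positions.
fwd_score = np.multiply(attr_slice, ppm).sum(0).mean()

# Reverse strand: use the reverse complement of the PPM (flip both the base and
# position axes), matching the "-" entry built in score_motifs.
rev_score = np.multiply(attr_slice, np.flip(ppm, (0, 1))).sum(0).mean()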


def marginalize_patterns(
13 changes: 12 additions & 1 deletion tests/test_interpret.py
@@ -137,7 +137,9 @@ def test_get_attention_scores():


def test_scan_sequences():
seqs = ["TCACGTGA", "CCTGCGTGA", "CACGCAGG"]
seqs = ["TCACGTGAA", "CCTGCGTGA", "CACGCAGGA"]

# No reverse complement
out = scan_sequences(seqs, motifs=meme_file, rc=False, pthresh=1e-3)
assert out.motif.tolist() == ["MA0004.1 Arnt", "MA0006.1 Ahr::Arnt"]
assert out.sequence.tolist() == ["0", "1"]
@@ -146,6 +148,7 @@ def test_scan_sequences():
assert out.strand.tolist() == ["+", "+"]
assert out.matched_seq.tolist() == ["CACGTG", "TGCGTG"]

# Allow reverse complement
out = scan_sequences(seqs, motifs=meme_file, rc=True, pthresh=1e-3)
assert out.motif.tolist() == [
"MA0004.1 Arnt",
@@ -159,6 +162,14 @@
assert out.strand.tolist() == ["+", "-", "+", "-"]
assert out.matched_seq.tolist() == ["CACGTG", "CACGTG", "TGCGTG", "CACGCA"]

# Reverse complement with attributions
attrs = get_attributions(model, seqs, method="inputxgradient")
out = scan_sequences(seqs, motifs=meme_file, rc=True, pthresh=1e-3, attrs=attrs)
assert np.allclose(out.site_attr_score, [0.0, 0.0, -0.009259, 0.009259], rtol=0.001)
assert np.allclose(
out.motif_attr_score, [0.003704, 0.0, -0.035494, 0.0], rtol=0.001
)


def test_run_tomtom():
