From f4d95043fb807c02eb0c7d3b0848d7bbdacd86b9 Mon Sep 17 00:00:00 2001 From: alexsweeten Date: Sun, 6 Oct 2024 11:56:02 -0500 Subject: [PATCH 1/5] Add grid functionality --- src/moddotplot/moddotplot.py | 168 +++++++---- src/moddotplot/static_plots.py | 504 ++++++++++++++++++++++++++++++++- 2 files changed, 612 insertions(+), 60 deletions(-) diff --git a/src/moddotplot/moddotplot.py b/src/moddotplot/moddotplot.py index 368573f..cf96ca9 100644 --- a/src/moddotplot/moddotplot.py +++ b/src/moddotplot/moddotplot.py @@ -21,7 +21,7 @@ import argparse import math -from moddotplot.static_plots import read_df_from_file, create_plots +from moddotplot.static_plots import read_df_from_file, create_plots, create_grid import json import numpy as np import pickle @@ -342,6 +342,12 @@ def get_parser(): help="Plot comparative plots in an NxN grid like format.", ) + static_parser.add_argument( + "--grid-only", + action="store_true", + help="Plot comparative plots in an NxN grid like format, skipping individual plots", + ) + # TODO: Implement static mode logging options return parser @@ -428,6 +434,13 @@ def main(): args.bin_freq = config.get("bin_freq", args.bin_freq) # TODO: Include logging options here + # Check if conflicitng command line args + if args.grid or args.grid_only: + if not (args.compare or args.compare_only) and not args.bed: + print( + f"Option --grid was selected, but no comparative plots will be produced. Please rerun ModDotPlot with the `--compare` or `--compare-only` option.\n" + ) + sys.exit(10) # -----------INPUT COMMAND VALIDATION----------- # TODO: More tests! @@ -447,10 +460,16 @@ def main(): # -----------BEDFILE INPUT FOR STATIC MODE----------- if hasattr(args, "bed") and args.bed: + if args.grid or args.grid_only: + single_vals = [] + double_vals = [] + single_val_name = [] + double_val_name = [] try: for bed in args.bed: # If args.bed is provided as input, run static mode directly from the bed file. Skip counting input k-mers. df = read_df_from_file(bed) + unique_query_names = df["#query_name"].unique() unique_reference_names = df["reference_name"].unique() assert len(unique_query_names) == len(unique_reference_names) @@ -458,60 +477,87 @@ def main(): assert len(unique_reference_names) == 1 self_id_scores = df[df["#query_name"] == df["reference_name"]] pairwise_id_scores = df[df["#query_name"] != df["reference_name"]] - print( - f"Input bed file {bed} read successfully! Creating plots... \n" - ) + if not args.grid_only: + print( + f"Input bed file {bed} read successfully! Creating plots... \n" + ) # Create directory if not args.output_dir: args.output_dir = os.getcwd() if not os.path.exists(args.output_dir): os.makedirs(args.output_dir) if len(self_id_scores) > 1: - create_plots( - sdf=None, - directory=args.output_dir if args.output_dir else ".", - name_x=unique_query_names[0], - name_y=unique_query_names[0], - palette=args.palette, - palette_orientation=args.palette_orientation, - no_hist=args.no_hist, - width=args.width, - dpi=args.dpi, - is_freq=args.bin_freq, - xlim=args.axes_limits, - custom_colors=args.colors, - custom_breakpoints=args.breakpoints, - from_file=df, - is_pairwise=False, - axes_labels=args.axes_ticks, - ) + if not args.grid_only: + create_plots( + sdf=None, + directory=args.output_dir if args.output_dir else ".", + name_x=unique_query_names[0], + name_y=unique_query_names[0], + palette=args.palette, + palette_orientation=args.palette_orientation, + no_hist=args.no_hist, + width=args.width, + dpi=args.dpi, + is_freq=args.bin_freq, + xlim=args.axes_limits, + custom_colors=args.colors, + custom_breakpoints=args.breakpoints, + from_file=df, + is_pairwise=False, + axes_labels=args.axes_ticks, + ) + if args.grid or args.grid_only: + single_vals.append(df) + single_val_name.append(unique_query_names[0]) # Case 2: Pairwise bed file if len(pairwise_id_scores) > 1: - create_plots( - sdf=None, - directory=args.output_dir if args.output_dir else ".", - name_x=unique_query_names[0], - name_y=unique_reference_names[0], - palette=args.palette, - palette_orientation=args.palette_orientation, - no_hist=args.no_hist, - width=args.width, - dpi=args.dpi, - is_freq=args.bin_freq, - xlim=args.axes_limits, # TODO: Get xlim working - custom_colors=args.colors, - custom_breakpoints=args.breakpoints, - from_file=df, - is_pairwise=True, - axes_labels=args.axes_ticks, - ) + if not args.grid_only: + create_plots( + sdf=None, + directory=args.output_dir if args.output_dir else ".", + name_x=unique_query_names[0], + name_y=unique_reference_names[0], + palette=args.palette, + palette_orientation=args.palette_orientation, + no_hist=args.no_hist, + width=args.width, + dpi=args.dpi, + is_freq=args.bin_freq, + xlim=args.axes_limits, + custom_colors=args.colors, + custom_breakpoints=args.axes_ticks, + from_file=df, + is_pairwise=True, + axes_labels=args.axes_ticks, + ) + if args.grid or args.grid_only: + double_vals.append(df) + double_val_name.append([unique_query_names[0],unique_reference_names[0]]) # Exit once all bed files have been iterated through + if args.grid or args.grid_only: + create_grid( + singles=single_vals, + doubles=double_vals, + directory=args.output_dir if args.output_dir else ".", + palette=args.palette, + palette_orientation=args.palette_orientation, + no_hist=args.no_hist, + single_names=single_val_name, + double_names=double_val_name, + is_freq=args.bin_freq, + xlim=args.axes_limits, + custom_colors=args.colors, + custom_breakpoints=args.axes_ticks, + axes_label=args.axes_ticks, + is_bed=True + ) sys.exit(0) except Exception as e: # Exit code 7: Error getting info from bed file: # TODO: Change to logs print(f"Error in bed file: {e}") sys.exit(7) + # -----------INPUT SEQUENCE VALIDATION----------- seq_list = [] @@ -531,7 +577,6 @@ def main(): for i in args.fasta: kmer_list.append(readKmersFromFile(i, args.kmer, False)) k_list = [item for sublist in kmer_list for item in sublist] - # Throw error if compare only selected with one sequence. if len(k_list) < 2 and args.compare_only: print( @@ -817,8 +862,12 @@ def main(): # -----------SETUP STATIC MODE----------- elif args.command == "static": # -----------SET SPARSITY VALUE----------- - # TODO: this is not sorting correctly + if args.grid or args.grid_only: + grid_val_singles = [] + grid_val_single_names = [] sequences = list(zip(seq_list, k_list)) + if len(sequences) > 6 and (args.grid or args.grid_only): + print("Too many sequences to create a grid. Skipping. \n") # Create output directory, if doesn't exist: if (args.output_dir) and not os.path.exists(args.output_dir): @@ -835,6 +884,7 @@ def main(): else: win = math.ceil(seq_length / args.resolution) + print(win) if win < args.modimizer: args.modimizer = win if win < 10: @@ -884,6 +934,9 @@ def main(): bed = convertMatrixToBed( self_mat, win, args.identity, seq_list[i], seq_list[i], True ) + if args.grid or args.grid_only: + grid_val_singles.append(bed) + grid_val_single_names.append(sequences[i][0]) if not args.no_bed: # Log saving bed file @@ -938,8 +991,9 @@ def main(): seq_sparsity = 2 ** (int(math.log2(seq_sparsity - 1)) + 1) expectation = round(win / seq_sparsity) - if args.grid: - grid_vals = [] + if args.grid or args.grid_only: + grid_val_doubles = [] + grid_val_double_names = [] for i in range(len(sequences)): larger_seq = sequences[i][1] @@ -987,7 +1041,7 @@ def main(): expectation, ) # Throw error if the matrix is empty - if np.all(pair_mat == 0): + if np.all(pair_mat == 0) and (not (args.grid or args.grid_only)): print( f"The pairwise identity matrix for {sequences[i][0]} and {sequences[j][0]} is empty. Skipping.\n" ) @@ -1000,8 +1054,9 @@ def main(): seq_list[j], False, ) - if args.grid: - grid_vals.append(bed) + if args.grid or args.grid_only: + grid_val_doubles.append(bed) + grid_val_double_names.append([larger_seq_name,smaller_seq_name]) if not args.no_bed: # Log saving bed file @@ -1045,8 +1100,23 @@ def main(): axes_labels=args.axes_ticks, ) - """if args.grid: - print(grid_vals)""" + if args.grid or args.grid_only: + create_grid( + singles=grid_val_singles, + doubles=grid_val_doubles, + directory=args.output_dir if args.output_dir else ".", + palette=args.palette, + palette_orientation=args.palette_orientation, + no_hist=args.no_hist, + single_names=grid_val_single_names, + double_names=grid_val_double_names, + is_freq=args.bin_freq, + xlim=args.axes_limits, + custom_colors=args.colors, + custom_breakpoints=args.axes_ticks, + axes_label=args.axes_ticks, + is_bed=False + ) if __name__ == "__main__": diff --git a/src/moddotplot/static_plots.py b/src/moddotplot/static_plots.py index 3724c3a..9687ab6 100644 --- a/src/moddotplot/static_plots.py +++ b/src/moddotplot/static_plots.py @@ -21,6 +21,11 @@ element_text, theme_light, theme_void, + geom_blank, + annotate, + element_rect, + coord_flip, + theme_minimal ) import pandas as pd import numpy as np @@ -28,6 +33,8 @@ import patchworklib as pw import math import os +import sys +from moddotplot.parse_fasta import printProgressBar from moddotplot.const import ( DIVERGING_PALETTES, @@ -37,6 +44,55 @@ from typing import List from palettable.colorbrewer import qualitative, sequential, diverging +def is_plot_empty(p): + # Check if the plot has data or any layers + return len(p.layers) == 0 and p.data.empty + +def check_pascal(single_val, double_val): + try: + if len(single_val) == 2: + assert len(double_val) == 1 + elif len(single_val) == 3: + assert len(double_val) == 3 + elif len(single_val) == 4: + assert len(double_val) == 6 + elif len(single_val) == 5: + assert len(double_val) == 10 + elif len(single_val) == 6: + assert len(double_val) == 15 + elif len(single_val) == 0: + assert len(double_val) == (1 or 3 or 6 or 10 or 15) + except AssertionError as e: + print(f"Missing bed files required to create grid. Please verify all bed files are included.") + sys.exit(8) + +def reverse_pascal(double_vals): + if len(double_vals) == 1: + return 2 + elif len(double_vals) == 3: + return 3 + elif len(double_vals) == 6: + return 4 + elif len(double_vals) == 10: + return 5 + elif len(double_vals) == 15: + return 6 + else: + sys.exit(9) + +# Hardcoding for now, I have the formula just lazy +def transpose_order(double_vals): + if len(double_vals) == 1: + return [0] + elif len(double_vals) == 3: + return [2,0,1] + elif len(double_vals) == 6: + return [5,2,0,4,1,3] + elif len(double_vals) == 10: + return [9,5,2,0,8,4,1,7,3,6] + elif len(double_vals) == 15: + return [14,9,5,2,0,13,8,4,1,12,7,3,11,6,10] + def check_st_en_equality(df): unequal_rows = df[(df["q_st"] != df["r_st"]) | (df["q_en"] != df["r_en"])] @@ -110,7 +166,10 @@ def overlap_axis(rotated_plot, filename, prefix): def get_colors(sdf, ncolors, is_freq, custom_breakpoints): assert ncolors > 2 and ncolors < 12 - bot = math.floor(min(sdf["perID_by_events"])) + try: + bot = math.floor(min(sdf["perID_by_events"])) + except ValueError: + bot = 0 top = 100.0 interval = (top - bot) / ncolors breaks = [] @@ -200,7 +259,10 @@ def read_df( ) # Calculate the window size - window = max(df["q_en"] - df["q_st"]) + try: + window = max(df["q_en"] - df["q_st"]) + except ValueError: + window = 0 # Calculate the position of the first and second intervals df["first_pos"] = df["q_st"] / window @@ -243,7 +305,15 @@ def make_dot(sdf, title_name, palette, palette_orientation, colors, breaks, xlim if colors: new_hexcodes = colors max_val = max(sdf["q_en"].max(), sdf["r_en"].max(), xlim) - window = max(sdf["q_en"] - sdf["q_st"]) + try: + window = max(sdf["q_en"] - sdf["q_st"]) + except: + p = (ggplot(aes(x=[], y=[])) + + theme_minimal() # Use a minimal theme with gridlines + + theme(panel_grid_major=element_blank(), # Remove major gridlines (optional) + panel_grid_minor=element_blank()) # Remove minor gridlines (optional) + ) + return p if max_val < 100000: x_label = "Genomic Position (Kbp)" elif max_val < 100000000: @@ -287,6 +357,181 @@ def make_dot(sdf, title_name, palette, palette_orientation, colors, breaks, xlim return p +def make_dot2(sdf, title_name, palette, palette_orientation, colors, breaks, xlim): + if not breaks: + breaks = True + else: + breaks = [float(number) for number in breaks] + if not xlim: + xlim = 0 + hexcodes = [] + new_hexcodes = [] + if palette in DIVERGING_PALETTES: + function_name = getattr(diverging, palette) + hexcodes = function_name.hex_colors + if palette_orientation == "+": + palette_orientation = "-" + else: + palette_orientation = "+" + elif palette in QUALITATIVE_PALETTES: + function_name = getattr(qualitative, palette) + hexcodes = function_name.hex_colors + elif palette in SEQUENTIAL_PALETTES: + function_name = getattr(sequential, palette) + hexcodes = function_name.hex_colors + else: + function_name = getattr(sequential, "Spectral_11") + palette_orientation = "-" + hexcodes = function_name.hex_colors + + if palette_orientation == "-": + new_hexcodes = hexcodes[::-1] + else: + new_hexcodes = hexcodes + if colors: + new_hexcodes = colors + max_val = max(sdf["q_en"].max(), sdf["r_en"].max(), xlim) + try: + window = max(sdf["q_en"] - sdf["q_st"]) + except: + p = (ggplot(aes(x=[], y=[])) + + theme_minimal() # Use a minimal theme with gridlines + + theme(panel_grid_major=element_blank(), # Remove major gridlines (optional) + panel_grid_minor=element_blank()) # Remove minor gridlines (optional) + ) + return p + if max_val < 100000: + x_label = "Genomic Position (Kbp)" + elif max_val < 100000000: + x_label = "Genomic Position (Mbp)" + else: + x_label = "Genomic Position (Gbp)" + p = ( + ggplot(sdf) + + geom_tile( + aes(x="q_st", y="r_st", fill="discrete", height=window, width=window) + ) + + scale_color_discrete(guide=False) + + scale_fill_manual( + values=new_hexcodes, + guide=False, + ) + + theme( + legend_position="none", + panel_grid_major=element_blank(), + panel_grid_minor=element_blank(), + plot_background=element_blank(), + panel_background=element_blank(), + axis_line=element_line(color="black"), # Adjust axis line size + axis_text=element_text( + family=["Dejavu Sans"] + ), # Change axis text font and size + axis_ticks_major=element_line(), + title=element_text( + family=["Dejavu Sans"], # Change title font family + ), + ) + + scale_x_continuous(labels=make_scale, limits=[0, max_val], breaks=breaks) + + scale_y_continuous(labels=make_scale, limits=[0, max_val], breaks=breaks) + + coord_fixed(ratio=1) + + labs(x=None, y=None, title=title_name) + ) + + # Adjust x-axis label size + p += theme(axis_title_x=element_blank()) + p += theme(axis_title_y=element_blank()) + #p += theme(title=element_blank()) + + return p + +def make_dot3(sdf, title_name, palette, palette_orientation, colors, breaks, xlim): + #sdf = sdf.transpose() + if not breaks: + breaks = True + else: + breaks = [float(number) for number in breaks] + if not xlim: + xlim = 0 + hexcodes = [] + new_hexcodes = [] + if palette in DIVERGING_PALETTES: + function_name = getattr(diverging, palette) + hexcodes = function_name.hex_colors + if palette_orientation == "+": + palette_orientation = "-" + else: + palette_orientation = "+" + elif palette in QUALITATIVE_PALETTES: + function_name = getattr(qualitative, palette) + hexcodes = function_name.hex_colors + elif palette in SEQUENTIAL_PALETTES: + function_name = getattr(sequential, palette) + hexcodes = function_name.hex_colors + else: + function_name = getattr(sequential, "Spectral_11") + palette_orientation = "-" + hexcodes = function_name.hex_colors + + if palette_orientation == "-": + new_hexcodes = hexcodes[::-1] + else: + new_hexcodes = hexcodes + if colors: + new_hexcodes = colors + max_val = max(sdf["q_en"].max(), sdf["r_en"].max(), xlim) + try: + window = max(sdf["q_en"] - sdf["q_st"]) + except: + p = (ggplot(aes(x=[], y=[])) + + theme_minimal() # Use a minimal theme with gridlines + + theme(panel_grid_major=element_blank(), # Remove major gridlines (optional) + panel_grid_minor=element_blank()) # Remove minor gridlines (optional) + ) + return p + if max_val < 100000: + x_label = "Genomic Position (Kbp)" + elif max_val < 100000000: + x_label = "Genomic Position (Mbp)" + else: + x_label = "Genomic Position (Gbp)" + p = ( + ggplot(sdf) + + geom_tile( + aes(x="r_st", y="q_st", fill="discrete", height=window, width=window) + ) + + scale_color_discrete(guide=False) + + scale_fill_manual( + values=new_hexcodes, + guide=False, + ) + + theme( + legend_position="none", + panel_grid_major=element_blank(), + panel_grid_minor=element_blank(), + plot_background=element_blank(), + panel_background=element_blank(), + axis_line=element_line(color="black"), # Adjust axis line size + axis_text=element_text( + family=["Dejavu Sans"] + ), # Change axis text font and size + axis_ticks_major=element_line(), + title=element_text( + family=["Dejavu Sans"], # Change title font family + ), + ) + + scale_x_continuous(labels=make_scale, limits=[0, max_val], breaks=breaks) + + scale_y_continuous(labels=make_scale, limits=[0, max_val], breaks=breaks) + + coord_fixed(ratio=1) + + labs(x=None, y=None, title=title_name) + ) + + # Adjust x-axis label size + p += theme(axis_title_x=element_blank()) + p += theme(axis_title_y=element_blank()) + #p += theme(title=element_blank()) + + return p + def make_tri(sdf, title_name, palette, palette_orientation, colors, breaks, xlim): if not breaks: @@ -477,7 +722,10 @@ def make_hist(sdf, palette, palette_orientation, custom_colors, custom_breakpoin if custom_colors: new_hexcodes = custom_colors - bot = np.quantile(sdf["perID_by_events"], q=0.001) + try: + bot = np.quantile(sdf["perID_by_events"], q=0.001) + except IndexError: + bot = 0 count = sdf.shape[0] extra = "" @@ -503,22 +751,245 @@ def create_grid( singles, doubles, directory, - name_x, - name_y, palette, palette_orientation, no_hist, - width, - dpi, + single_names, + double_names, is_freq, xlim, custom_colors, custom_breakpoints, - from_file, - is_pairwise, axes_label, + is_bed ): - print(singles) + new_index = [] + transpose_index = [] + check_pascal(singles,doubles) + #Singles can be empty if not selected + for i in range(len(single_names)): + for j in range(i+1, len(single_names)): + try: + index = double_names.index([single_names[i],single_names[j]]) + transpose_index.append(0) + except: + index = double_names.index([single_names[j],single_names[i]]) + transpose_index.append(1) + + new_index.append(index) + + single_list = [] + double_list = [] + single_heatmap_list = [] + normal_heatmap_list = [] + transpose_heatmap_list = [] + for matrix in singles: + if is_bed: + df = read_df( + None, + palette, + palette_orientation, + is_freq, + custom_colors, + custom_breakpoints, + matrix, + ) + else: + df = read_df( + [matrix], + palette, + palette_orientation, + is_freq, + custom_colors, + custom_breakpoints, + None, + ) + single_list.append(df) + for matrix in doubles: + if is_bed: + df = read_df( + None, + palette, + palette_orientation, + is_freq, + custom_colors, + custom_breakpoints, + matrix, + ) + else: + df = read_df( + [matrix], + palette, + palette_orientation, + is_freq, + custom_colors, + custom_breakpoints, + None, + ) + double_list.append(df) + for plot in single_list: + heatmap = make_dot2( + plot, + plot['q'].iloc[1], + palette, + palette_orientation, + custom_colors, + axes_label, + xlim, + ) + single_heatmap_list.append(heatmap) + for indie in new_index: + xd = new_index.index(indie) + if transpose_index[xd] == 0: + heatmap = make_dot2( + double_list[indie], + f"{double_names[indie][0]}_{double_names[indie][1]}", + palette, + palette_orientation, + custom_colors, + axes_label, + xlim, + ) + normal_heatmap_list.append(heatmap) + heatmap_t = make_dot3( + double_list[indie], + f"{double_names[indie][1]}_{double_names[indie][0]}", + palette, + palette_orientation, + custom_colors, + axes_label, + xlim, + ) + transpose_heatmap_list.append(heatmap_t) + else: + heatmap = make_dot3( + double_list[indie], + f"{double_names[indie][0]}_{double_names[indie][1]}", + palette, + palette_orientation, + custom_colors, + axes_label, + xlim, + ) + normal_heatmap_list.append(heatmap) + heatmap_t = make_dot2( + double_list[indie], + f"{double_names[indie][1]}_{double_names[indie][0]}", + palette, + palette_orientation, + custom_colors, + axes_label, + xlim, + ) + transpose_heatmap_list.append(heatmap_t) + + assert len(transpose_heatmap_list) == len(normal_heatmap_list) + single_length = len(single_heatmap_list) + if single_length == 0: + single_length = reverse_pascal(len(normal_heatmap_list)) + + normal_counter = 0 + trans_counter = 0 + start_grid = pw.Brick(figsize=(9,9)) + n = single_length * single_length + + printProgressBar(0, n, prefix="Progress:", suffix="Complete", length=40) + tots = 0 + col_names = pw.Brick(figsize=(2,9)) + row_names = pw.Brick(figsize=(9,2)) + for i in range(single_length): + row_grid = pw.Brick(figsize=(9,9)) + for j in range(single_length): + if i == j: + if len(single_heatmap_list) == 0: + g1 = pw.Brick(figsize=(9,9)) + else: + g1 = pw.load_ggplot(single_heatmap_list[i], figsize=(9,9)) + + elif i < j: + g1 = pw.load_ggplot(normal_heatmap_list[normal_counter], figsize=(9,9)) + normal_counter += 1 + elif i > j: + g1 = pw.load_ggplot(transpose_heatmap_list[trans_counter], figsize=(9,9)) + trans_counter += 1 + if j == 0: + row_grid = g1 + else: + row_grid = (row_grid|g1) + tots += 1 + printProgressBar(tots, n, prefix="Progress:", suffix="Complete", length=40) + if i == 0: + start_grid = row_grid + else: + start_grid = (row_grid/start_grid) + for w in range(single_length): + p1 = ( + ggplot() + + geom_blank() + # Use geom_blank to create a plot with no data + annotate('text', x=0, y=0, label=single_names[w], size=32, angle=90, ha='center', va='center') + + theme( + # Center the plot area and make backgrounds transparent + axis_title_x=element_blank(), + axis_title_y=element_blank(), + axis_ticks=element_blank(), + axis_text=element_blank(), + plot_background=element_rect(fill='none'), # Transparent plot background + panel_background=element_rect(fill='none'), # Transparent panel background + panel_grid=element_blank(), + aspect_ratio=0.5 # Adjust the aspect ratio for the desired width/height + ) + + coord_flip() # Rotate the plot 90 degrees counterclockwise + ) + p2 = ( + ggplot() + + geom_blank() + # Use geom_blank to create a plot with no data + annotate('text', x=0, y=0, label=single_names[w], size=32, ha='center', va='center') + + theme( + # Center the plot area and make backgrounds transparent + axis_title_x=element_blank(), + axis_title_y=element_blank(), + axis_ticks=element_blank(), + axis_text=element_blank(), + plot_background=element_rect(fill='none'), # Transparent plot background + panel_background=element_rect(fill='none'), # Transparent panel background + panel_grid=element_blank(), + aspect_ratio=0.5 # Adjust the aspect ratio for the desired width/height + ) + ) + g1 = pw.load_ggplot(p1, figsize=(2,9)) + g2 = pw.load_ggplot(p2, figsize=(9,2)) + + if w == 0: + col_names = g1 + row_names = g2 + else: + col_names = g1/col_names + row_names = row_names|g2 + #Create a ghost 2x2 + pghost = ( + ggplot() + + geom_blank() + # Use geom_blank to create a plot with no data + annotate('text', x=0, y=0, label="", size=32, ha='center', va='center') + + theme( + # Center the plot area and make backgrounds transparent + axis_title_x=element_blank(), + axis_title_y=element_blank(), + axis_ticks=element_blank(), + axis_text=element_blank(), + plot_background=element_rect(fill='none'), # Transparent plot background + panel_background=element_rect(fill='none'), # Transparent panel background + panel_grid=element_blank(), + aspect_ratio=0.5 # Adjust the aspect ratio for the desired width/height + ) + ) + ghosty = pw.load_ggplot(pghost, figsize=(2,2)) + col_names = ghosty/col_names + start_grid = (col_names|(row_names/start_grid)) + gridname = f"{single_length}x{single_length}_GRID" + print(f"\nGrid complete! Saving to {directory}/{gridname}...\n") + start_grid.savefig(f"{directory}/{gridname}.png") + start_grid.savefig(f"{directory}/{gridname}.pdf", format="pdf") + print(f"Grid saved successfully!\n") def create_plots( @@ -586,6 +1057,17 @@ def create_plots( filename=f"{plot_filename}_COMPARE.png", verbose=False, ) + try: + if not heatmap.data: + print( + f"{plot_filename}_COMPARE.pdf and {plot_filename}_COMPARE.png saved sucessfully. \n" + ) + return 0 + except ValueError: + print( + f"{plot_filename}_COMPARE.pdf and {plot_filename}_COMPARE.png saved sucessfully. \n" + ) + return 0 if no_hist: print( f"{plot_filename}_COMPARE.pdf and {plot_filename}_COMPARE.png saved sucessfully. \n" From 363ed501be25cb8ae30eca947576ec73457e2882 Mon Sep 17 00:00:00 2001 From: alexsweeten Date: Sun, 6 Oct 2024 13:30:01 -0500 Subject: [PATCH 2/5] Update grid functionality --- src/moddotplot/moddotplot.py | 6 ++++++ src/moddotplot/static_plots.py | 14 +++++++++----- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/src/moddotplot/moddotplot.py b/src/moddotplot/moddotplot.py index 3c967b9..07d67ee 100644 --- a/src/moddotplot/moddotplot.py +++ b/src/moddotplot/moddotplot.py @@ -535,6 +535,9 @@ def main(): double_val_name.append([unique_query_names[0],unique_reference_names[0]]) # Exit once all bed files have been iterated through if args.grid or args.grid_only: + print( + f"Creating a {len(sequences)}x{len(sequences)} grid.\n" + ) create_grid( singles=single_vals, doubles=double_vals, @@ -1101,6 +1104,9 @@ def main(): ) if args.grid or args.grid_only: + print( + f"Creating a {len(sequences)}x{len(sequences)} grid.\n" + ) create_grid( singles=grid_val_singles, doubles=grid_val_doubles, diff --git a/src/moddotplot/static_plots.py b/src/moddotplot/static_plots.py index 9687ab6..516e2f0 100644 --- a/src/moddotplot/static_plots.py +++ b/src/moddotplot/static_plots.py @@ -85,13 +85,13 @@ def transpose_order(double_vals): if len(double_vals) == 1: return [0] elif len(double_vals) == 3: - return [2,0,1] + return [0,1,2] elif len(double_vals) == 6: - return [5,2,0,4,1,3] + return [0,1,3,2,4,5] elif len(double_vals) == 10: - return [9,5,2,0,8,4,1,7,3,6] + return [0,1,4,2,5,7,3,6,8,9] elif len(double_vals) == 15: - return [14,9,5,2,0,13,8,4,1,12,7,3,11,6,10] + return [0,1,5,2,6,9,3,7,10,12,4,8,11,13,14] def check_st_en_equality(df): @@ -890,9 +890,13 @@ def create_grid( normal_counter = 0 trans_counter = 0 + trans_to_use = transpose_order(normal_heatmap_list) start_grid = pw.Brick(figsize=(9,9)) n = single_length * single_length + if n > 4: + print(f"This might take a while\n...\n") + printProgressBar(0, n, prefix="Progress:", suffix="Complete", length=40) tots = 0 col_names = pw.Brick(figsize=(2,9)) @@ -910,7 +914,7 @@ def create_grid( g1 = pw.load_ggplot(normal_heatmap_list[normal_counter], figsize=(9,9)) normal_counter += 1 elif i > j: - g1 = pw.load_ggplot(transpose_heatmap_list[trans_counter], figsize=(9,9)) + g1 = pw.load_ggplot(transpose_heatmap_list[trans_to_use[trans_counter]], figsize=(9,9)) trans_counter += 1 if j == 0: row_grid = g1 From 9563b4abb30f3cfa545b359e5abf430f32728171 Mon Sep 17 00:00:00 2001 From: alexsweeten Date: Sun, 6 Oct 2024 13:30:45 -0500 Subject: [PATCH 3/5] lint with black --- src/moddotplot/moddotplot.py | 23 ++-- src/moddotplot/static_plots.py | 204 ++++++++++++++++++++------------- 2 files changed, 136 insertions(+), 91 deletions(-) diff --git a/src/moddotplot/moddotplot.py b/src/moddotplot/moddotplot.py index 07d67ee..54e1fdc 100644 --- a/src/moddotplot/moddotplot.py +++ b/src/moddotplot/moddotplot.py @@ -469,7 +469,7 @@ def main(): for bed in args.bed: # If args.bed is provided as input, run static mode directly from the bed file. Skip counting input k-mers. df = read_df_from_file(bed) - + unique_query_names = df["#query_name"].unique() unique_reference_names = df["reference_name"].unique() assert len(unique_query_names) == len(unique_reference_names) @@ -532,12 +532,12 @@ def main(): ) if args.grid or args.grid_only: double_vals.append(df) - double_val_name.append([unique_query_names[0],unique_reference_names[0]]) + double_val_name.append( + [unique_query_names[0], unique_reference_names[0]] + ) # Exit once all bed files have been iterated through if args.grid or args.grid_only: - print( - f"Creating a {len(sequences)}x{len(sequences)} grid.\n" - ) + print(f"Creating a {len(sequences)}x{len(sequences)} grid.\n") create_grid( singles=single_vals, doubles=double_vals, @@ -552,7 +552,7 @@ def main(): custom_colors=args.colors, custom_breakpoints=args.axes_ticks, axes_label=args.axes_ticks, - is_bed=True + is_bed=True, ) sys.exit(0) except Exception as e: @@ -560,7 +560,6 @@ def main(): # TODO: Change to logs print(f"Error in bed file: {e}") sys.exit(7) - # -----------INPUT SEQUENCE VALIDATION----------- seq_list = [] @@ -1059,7 +1058,9 @@ def main(): ) if args.grid or args.grid_only: grid_val_doubles.append(bed) - grid_val_double_names.append([larger_seq_name,smaller_seq_name]) + grid_val_double_names.append( + [larger_seq_name, smaller_seq_name] + ) if not args.no_bed: # Log saving bed file @@ -1104,9 +1105,7 @@ def main(): ) if args.grid or args.grid_only: - print( - f"Creating a {len(sequences)}x{len(sequences)} grid.\n" - ) + print(f"Creating a {len(sequences)}x{len(sequences)} grid.\n") create_grid( singles=grid_val_singles, doubles=grid_val_doubles, @@ -1121,7 +1120,7 @@ def main(): custom_colors=args.colors, custom_breakpoints=args.axes_ticks, axes_label=args.axes_ticks, - is_bed=False + is_bed=False, ) diff --git a/src/moddotplot/static_plots.py b/src/moddotplot/static_plots.py index 516e2f0..dc9523b 100644 --- a/src/moddotplot/static_plots.py +++ b/src/moddotplot/static_plots.py @@ -25,7 +25,7 @@ annotate, element_rect, coord_flip, - theme_minimal + theme_minimal, ) import pandas as pd import numpy as np @@ -44,10 +44,12 @@ from typing import List from palettable.colorbrewer import qualitative, sequential, diverging + def is_plot_empty(p): # Check if the plot has data or any layers return len(p.layers) == 0 and p.data.empty + def check_pascal(single_val, double_val): try: if len(single_val) == 2: @@ -63,9 +65,12 @@ def check_pascal(single_val, double_val): elif len(single_val) == 0: assert len(double_val) == (1 or 3 or 6 or 10 or 15) except AssertionError as e: - print(f"Missing bed files required to create grid. Please verify all bed files are included.") + print( + f"Missing bed files required to create grid. Please verify all bed files are included." + ) sys.exit(8) + def reverse_pascal(double_vals): if len(double_vals) == 1: return 2 @@ -80,18 +85,19 @@ def reverse_pascal(double_vals): else: sys.exit(9) + # Hardcoding for now, I have the formula just lazy def transpose_order(double_vals): if len(double_vals) == 1: return [0] elif len(double_vals) == 3: - return [0,1,2] + return [0, 1, 2] elif len(double_vals) == 6: - return [0,1,3,2,4,5] + return [0, 1, 3, 2, 4, 5] elif len(double_vals) == 10: - return [0,1,4,2,5,7,3,6,8,9] + return [0, 1, 4, 2, 5, 7, 3, 6, 8, 9] elif len(double_vals) == 15: - return [0,1,5,2,6,9,3,7,10,12,4,8,11,13,14] + return [0, 1, 5, 2, 6, 9, 3, 7, 10, 12, 4, 8, 11, 13, 14] def check_st_en_equality(df): @@ -308,11 +314,14 @@ def make_dot(sdf, title_name, palette, palette_orientation, colors, breaks, xlim try: window = max(sdf["q_en"] - sdf["q_st"]) except: - p = (ggplot(aes(x=[], y=[])) + p = ( + ggplot(aes(x=[], y=[])) + theme_minimal() # Use a minimal theme with gridlines - + theme(panel_grid_major=element_blank(), # Remove major gridlines (optional) - panel_grid_minor=element_blank()) # Remove minor gridlines (optional) - ) + + theme( + panel_grid_major=element_blank(), # Remove major gridlines (optional) + panel_grid_minor=element_blank(), + ) # Remove minor gridlines (optional) + ) return p if max_val < 100000: x_label = "Genomic Position (Kbp)" @@ -357,6 +366,7 @@ def make_dot(sdf, title_name, palette, palette_orientation, colors, breaks, xlim return p + def make_dot2(sdf, title_name, palette, palette_orientation, colors, breaks, xlim): if not breaks: breaks = True @@ -394,11 +404,14 @@ def make_dot2(sdf, title_name, palette, palette_orientation, colors, breaks, xli try: window = max(sdf["q_en"] - sdf["q_st"]) except: - p = (ggplot(aes(x=[], y=[])) + p = ( + ggplot(aes(x=[], y=[])) + theme_minimal() # Use a minimal theme with gridlines - + theme(panel_grid_major=element_blank(), # Remove major gridlines (optional) - panel_grid_minor=element_blank()) # Remove minor gridlines (optional) - ) + + theme( + panel_grid_major=element_blank(), # Remove major gridlines (optional) + panel_grid_minor=element_blank(), + ) # Remove minor gridlines (optional) + ) return p if max_val < 100000: x_label = "Genomic Position (Kbp)" @@ -440,12 +453,13 @@ def make_dot2(sdf, title_name, palette, palette_orientation, colors, breaks, xli # Adjust x-axis label size p += theme(axis_title_x=element_blank()) p += theme(axis_title_y=element_blank()) - #p += theme(title=element_blank()) + # p += theme(title=element_blank()) return p + def make_dot3(sdf, title_name, palette, palette_orientation, colors, breaks, xlim): - #sdf = sdf.transpose() + # sdf = sdf.transpose() if not breaks: breaks = True else: @@ -482,11 +496,14 @@ def make_dot3(sdf, title_name, palette, palette_orientation, colors, breaks, xli try: window = max(sdf["q_en"] - sdf["q_st"]) except: - p = (ggplot(aes(x=[], y=[])) + p = ( + ggplot(aes(x=[], y=[])) + theme_minimal() # Use a minimal theme with gridlines - + theme(panel_grid_major=element_blank(), # Remove major gridlines (optional) - panel_grid_minor=element_blank()) # Remove minor gridlines (optional) - ) + + theme( + panel_grid_major=element_blank(), # Remove major gridlines (optional) + panel_grid_minor=element_blank(), + ) # Remove minor gridlines (optional) + ) return p if max_val < 100000: x_label = "Genomic Position (Kbp)" @@ -528,7 +545,7 @@ def make_dot3(sdf, title_name, palette, palette_orientation, colors, breaks, xli # Adjust x-axis label size p += theme(axis_title_x=element_blank()) p += theme(axis_title_y=element_blank()) - #p += theme(title=element_blank()) + # p += theme(title=element_blank()) return p @@ -761,19 +778,19 @@ def create_grid( custom_colors, custom_breakpoints, axes_label, - is_bed + is_bed, ): new_index = [] transpose_index = [] - check_pascal(singles,doubles) - #Singles can be empty if not selected + check_pascal(singles, doubles) + # Singles can be empty if not selected for i in range(len(single_names)): - for j in range(i+1, len(single_names)): + for j in range(i + 1, len(single_names)): try: - index = double_names.index([single_names[i],single_names[j]]) + index = double_names.index([single_names[i], single_names[j]]) transpose_index.append(0) except: - index = double_names.index([single_names[j],single_names[i]]) + index = double_names.index([single_names[j], single_names[i]]) transpose_index.append(1) new_index.append(index) @@ -830,7 +847,7 @@ def create_grid( for plot in single_list: heatmap = make_dot2( plot, - plot['q'].iloc[1], + plot["q"].iloc[1], palette, palette_orientation, custom_colors, @@ -887,11 +904,11 @@ def create_grid( single_length = len(single_heatmap_list) if single_length == 0: single_length = reverse_pascal(len(normal_heatmap_list)) - + normal_counter = 0 trans_counter = 0 trans_to_use = transpose_order(normal_heatmap_list) - start_grid = pw.Brick(figsize=(9,9)) + start_grid = pw.Brick(figsize=(9, 9)) n = single_length * single_length if n > 4: @@ -899,96 +916,125 @@ def create_grid( printProgressBar(0, n, prefix="Progress:", suffix="Complete", length=40) tots = 0 - col_names = pw.Brick(figsize=(2,9)) - row_names = pw.Brick(figsize=(9,2)) + col_names = pw.Brick(figsize=(2, 9)) + row_names = pw.Brick(figsize=(9, 2)) for i in range(single_length): - row_grid = pw.Brick(figsize=(9,9)) + row_grid = pw.Brick(figsize=(9, 9)) for j in range(single_length): if i == j: if len(single_heatmap_list) == 0: - g1 = pw.Brick(figsize=(9,9)) + g1 = pw.Brick(figsize=(9, 9)) else: - g1 = pw.load_ggplot(single_heatmap_list[i], figsize=(9,9)) - + g1 = pw.load_ggplot(single_heatmap_list[i], figsize=(9, 9)) + elif i < j: - g1 = pw.load_ggplot(normal_heatmap_list[normal_counter], figsize=(9,9)) + g1 = pw.load_ggplot(normal_heatmap_list[normal_counter], figsize=(9, 9)) normal_counter += 1 elif i > j: - g1 = pw.load_ggplot(transpose_heatmap_list[trans_to_use[trans_counter]], figsize=(9,9)) + g1 = pw.load_ggplot( + transpose_heatmap_list[trans_to_use[trans_counter]], figsize=(9, 9) + ) trans_counter += 1 if j == 0: row_grid = g1 else: - row_grid = (row_grid|g1) + row_grid = row_grid | g1 tots += 1 printProgressBar(tots, n, prefix="Progress:", suffix="Complete", length=40) if i == 0: start_grid = row_grid else: - start_grid = (row_grid/start_grid) + start_grid = row_grid / start_grid for w in range(single_length): p1 = ( - ggplot() + - geom_blank() + # Use geom_blank to create a plot with no data - annotate('text', x=0, y=0, label=single_names[w], size=32, angle=90, ha='center', va='center') + - theme( + ggplot() + + geom_blank() + + annotate( # Use geom_blank to create a plot with no data + "text", + x=0, + y=0, + label=single_names[w], + size=32, + angle=90, + ha="center", + va="center", + ) + + theme( # Center the plot area and make backgrounds transparent axis_title_x=element_blank(), axis_title_y=element_blank(), axis_ticks=element_blank(), axis_text=element_blank(), - plot_background=element_rect(fill='none'), # Transparent plot background - panel_background=element_rect(fill='none'), # Transparent panel background + plot_background=element_rect( + fill="none" + ), # Transparent plot background + panel_background=element_rect( + fill="none" + ), # Transparent panel background panel_grid=element_blank(), - aspect_ratio=0.5 # Adjust the aspect ratio for the desired width/height - ) + - coord_flip() # Rotate the plot 90 degrees counterclockwise + aspect_ratio=0.5, # Adjust the aspect ratio for the desired width/height + ) + + coord_flip() # Rotate the plot 90 degrees counterclockwise ) p2 = ( - ggplot() + - geom_blank() + # Use geom_blank to create a plot with no data - annotate('text', x=0, y=0, label=single_names[w], size=32, ha='center', va='center') + - theme( + ggplot() + + geom_blank() + + annotate( # Use geom_blank to create a plot with no data + "text", + x=0, + y=0, + label=single_names[w], + size=32, + ha="center", + va="center", + ) + + theme( # Center the plot area and make backgrounds transparent axis_title_x=element_blank(), axis_title_y=element_blank(), axis_ticks=element_blank(), axis_text=element_blank(), - plot_background=element_rect(fill='none'), # Transparent plot background - panel_background=element_rect(fill='none'), # Transparent panel background + plot_background=element_rect( + fill="none" + ), # Transparent plot background + panel_background=element_rect( + fill="none" + ), # Transparent panel background panel_grid=element_blank(), - aspect_ratio=0.5 # Adjust the aspect ratio for the desired width/height + aspect_ratio=0.5, # Adjust the aspect ratio for the desired width/height ) ) - g1 = pw.load_ggplot(p1, figsize=(2,9)) - g2 = pw.load_ggplot(p2, figsize=(9,2)) - + g1 = pw.load_ggplot(p1, figsize=(2, 9)) + g2 = pw.load_ggplot(p2, figsize=(9, 2)) + if w == 0: col_names = g1 row_names = g2 else: - col_names = g1/col_names - row_names = row_names|g2 - #Create a ghost 2x2 + col_names = g1 / col_names + row_names = row_names | g2 + # Create a ghost 2x2 pghost = ( - ggplot() + - geom_blank() + # Use geom_blank to create a plot with no data - annotate('text', x=0, y=0, label="", size=32, ha='center', va='center') + - theme( - # Center the plot area and make backgrounds transparent - axis_title_x=element_blank(), - axis_title_y=element_blank(), - axis_ticks=element_blank(), - axis_text=element_blank(), - plot_background=element_rect(fill='none'), # Transparent plot background - panel_background=element_rect(fill='none'), # Transparent panel background - panel_grid=element_blank(), - aspect_ratio=0.5 # Adjust the aspect ratio for the desired width/height - ) + ggplot() + + geom_blank() + + annotate( # Use geom_blank to create a plot with no data + "text", x=0, y=0, label="", size=32, ha="center", va="center" + ) + + theme( + # Center the plot area and make backgrounds transparent + axis_title_x=element_blank(), + axis_title_y=element_blank(), + axis_ticks=element_blank(), + axis_text=element_blank(), + plot_background=element_rect(fill="none"), # Transparent plot background + panel_background=element_rect(fill="none"), # Transparent panel background + panel_grid=element_blank(), + aspect_ratio=0.5, # Adjust the aspect ratio for the desired width/height ) - ghosty = pw.load_ggplot(pghost, figsize=(2,2)) - col_names = ghosty/col_names - start_grid = (col_names|(row_names/start_grid)) + ) + ghosty = pw.load_ggplot(pghost, figsize=(2, 2)) + col_names = ghosty / col_names + start_grid = col_names | (row_names / start_grid) gridname = f"{single_length}x{single_length}_GRID" print(f"\nGrid complete! Saving to {directory}/{gridname}...\n") start_grid.savefig(f"{directory}/{gridname}.png") From f5f10d9c0714a4bcbc670b0ad4529c7a22e78205 Mon Sep 17 00:00:00 2001 From: alexsweeten Date: Sun, 6 Oct 2024 13:31:32 -0500 Subject: [PATCH 4/5] bump version --- pyproject.toml | 2 +- src/moddotplot/const.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2ce477f..97433d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "ModDotPlot" -version = "0.8.8" +version = "0.9.0" requires-python = ">= 3.7" dependencies = [ "pysam", diff --git a/src/moddotplot/const.py b/src/moddotplot/const.py index 7e0bf70..9d3202c 100644 --- a/src/moddotplot/const.py +++ b/src/moddotplot/const.py @@ -1,4 +1,4 @@ -VERSION = "0.8.8" +VERSION = "0.9.0" COLS = [ "#query_name", "query_start", From 50ecda4eff91acd00584090afd380d4a355be7aa Mon Sep 17 00:00:00 2001 From: alexsweeten Date: Sun, 6 Oct 2024 13:33:42 -0500 Subject: [PATCH 5/5] remove debug artifact --- src/moddotplot/moddotplot.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/moddotplot/moddotplot.py b/src/moddotplot/moddotplot.py index 54e1fdc..a3ef32f 100644 --- a/src/moddotplot/moddotplot.py +++ b/src/moddotplot/moddotplot.py @@ -886,7 +886,6 @@ def main(): else: win = math.ceil(seq_length / args.resolution) - print(win) if win < args.modimizer: args.modimizer = win if win < 10: