-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
clean up gsea script and add comments
- Loading branch information
1 parent
16abf0a
commit e1dea62
Showing
1 changed file
with
114 additions
and
112 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,129 +1,131 @@ | ||
# cat ~/sjpp/test.txt | python gsea.py | ||
import blitzgsea as blitz | ||
import json | ||
import time | ||
import sys | ||
import sqlite3 | ||
import os | ||
import numpy as np | ||
import pandas as pd | ||
|
||
import blitzgsea as blitz | ||
import json | ||
import time | ||
import sys | ||
import sqlite3 | ||
import os | ||
import numpy as np | ||
import pandas as pd | ||
|
||
# Helper function to extract gene symbols from a dictionary | ||
def extract_symbols(x): | ||
return x['symbol'] | ||
|
||
def extract_plot_data(signature, geneset, library, result, center=True): | ||
signature = signature.copy() | ||
signature.columns = ["i","v"] | ||
signature = signature.sort_values("v", ascending=False).set_index("i") | ||
signature = signature[~signature.index.duplicated(keep='first')] | ||
if center: | ||
signature.loc[:,"v"] -= np.mean(signature.loc[:,"v"]) | ||
signature_map = {} | ||
for i,h in enumerate(signature.index): | ||
signature_map[h] = i | ||
|
||
gs = set(library[geneset]) | ||
hits = [i for i,x in enumerate(signature.index) if x in gs] | ||
|
||
running_sum, es = blitz.enrichment_score(np.array(np.abs(signature.iloc[:,0])), signature_map, gs) | ||
running_sum = list(running_sum) | ||
nn = np.where(np.abs(running_sum)==np.max(np.abs(running_sum)))[0][0] | ||
#print ("nn:",nn) | ||
#print ("running_sum:",running_sum) | ||
#print ("es:",es) | ||
running_sum_str=[str(elem) for elem in running_sum] | ||
print ('result: {"nn":'+str(nn)+',"running_sum":"'+",".join(running_sum_str)+'","es":'+str(es)+'}') | ||
|
||
return x['symbol'] # Return the 'symbol' field from the dictionary | ||
|
||
# Main function | ||
# Main function | ||
try: | ||
# Try to read a single character from stdin without blocking | ||
# Check if there is input from stdin | ||
if sys.stdin.read(1): | ||
# Read from stdin | ||
# Read each line from stdin | ||
for line in sys.stdin: | ||
# Process each line | ||
# Parse the JSON input | ||
json_object = json.loads(line) | ||
cachedir=json_object['cachedir'] | ||
genes=json_object['genes'] | ||
fold_change=json_object['fold_change'] | ||
table_name=json_object['geneset_group'] | ||
filter_non_coding_genes=json_object['filter_non_coding_genes'] | ||
df = {'Genes': genes, 'fold_change': fold_change} | ||
signature=pd.DataFrame(df) | ||
db=json_object['db'] | ||
cachedir = json_object['cachedir'] # Get the cache directory from the JSON object | ||
genes = json_object['genes'] # Get the genes from the JSON object | ||
fold_change = json_object['fold_change'] # Get the fold change values from the JSON object | ||
table_name = json_object['geneset_group'] # Get the gene set group from the JSON object | ||
filter_non_coding_genes = json_object['filter_non_coding_genes'] # Get the filter_non_coding_genes flag from the JSON object | ||
db = json_object['db'] # Get the database path from the JSON object | ||
|
||
# Create a DataFrame for the signature | ||
df = {'Genes': genes, 'fold_change': fold_change} # Create a dictionary with genes and fold change | ||
signature = pd.DataFrame(df) # Convert the dictionary to a DataFrame | ||
|
||
# Connect to the SQLite database | ||
conn = sqlite3.connect(db) | ||
conn = sqlite3.connect(db) # Connect to the SQLite database | ||
cursor = conn.cursor() # Create a cursor object | ||
|
||
# Create a cursor object using the cursor() method | ||
cursor = conn.cursor() | ||
# Query to get gene set IDs | ||
query = f"SELECT id FROM terms WHERE parent_id='{table_name}'" # SQL query to get gene set IDs | ||
cursor.execute(query) # Execute the query | ||
|
||
# SQL query to select all data from the table | ||
query = f"select id from terms where parent_id='" + table_name + "'" | ||
# Execute the SQL query | ||
cursor.execute(query) | ||
if filter_non_coding_genes == True: | ||
# SQL query to code all the protein coding genes | ||
coding_genes_query = f"select * from codingGenes" | ||
genedb = json_object['genedb'] | ||
gene_conn = sqlite3.connect(genedb) | ||
gene_cursor = gene_conn.cursor() | ||
gene_cursor.execute(coding_genes_query) | ||
coding_genes_list=gene_cursor.fetchall() | ||
coding_genes_list=list(map(lambda x: x[0],coding_genes_list)) | ||
signature=signature[signature['Genes'].isin(coding_genes_list)] | ||
|
||
# Fetch all rows from the executed SQL query | ||
rows = cursor.fetchall() | ||
# Filter out non-coding genes if specified | ||
if filter_non_coding_genes: | ||
coding_genes_query = "SELECT * FROM codingGenes" # SQL query to get coding genes | ||
genedb = json_object['genedb'] # Get the gene database path from the JSON object | ||
gene_conn = sqlite3.connect(genedb) # Connect to the gene database | ||
gene_cursor = gene_conn.cursor() # Create a cursor object for the gene database | ||
gene_cursor.execute(coding_genes_query) # Execute the query to get coding genes | ||
coding_genes_list = gene_cursor.fetchall() # Fetch all coding genes | ||
coding_genes_list = list(map(lambda x: x[0], coding_genes_list)) # Extract the gene symbols | ||
signature = signature[signature['Genes'].isin(coding_genes_list)] # Filter the signature to include only coding genes | ||
|
||
start_loop_time = time.time() | ||
msigdb_library={} | ||
# Iterate over the rows and print them | ||
# Fetch all gene set IDs | ||
rows = cursor.fetchall() # Fetch all rows from the executed query | ||
|
||
start_loop_time = time.time() # Record the start time of the loop | ||
msigdb_library = {} # Initialize an empty dictionary for the gene set library | ||
|
||
# Iterate over gene set IDs and fetch corresponding genes | ||
for row in rows: | ||
#print(row[0]) | ||
query2=f"select genes from term2genes where id='" + row[0] + "'" | ||
cursor.execute(query2) | ||
rows2 = cursor.fetchall() | ||
row3=json.loads(rows2[0][0]) | ||
msigdb_library[row[0]] = list(map(extract_symbols,row3)) | ||
query2 = f"SELECT genes FROM term2genes WHERE id='{row[0]}'" # SQL query to get genes for a gene set ID | ||
cursor.execute(query2) # Execute the query | ||
rows2 = cursor.fetchall() # Fetch all rows from the executed query | ||
row3 = json.loads(rows2[0][0]) # Parse the JSON data | ||
msigdb_library[row[0]] = list(map(extract_symbols, row3)) # Extract gene symbols and add to the library | ||
|
||
#print ("msigdb_library:",msigdb_library) | ||
# Close the cursor and connection to the database | ||
cursor.close() | ||
conn.close() | ||
stop_loop_time = time.time() | ||
execution_time = stop_loop_time - start_loop_time | ||
print(f"Execution time: {execution_time} seconds") | ||
try: # Extract ES data to be plotted on client side | ||
geneset_name=json_object['geneset_name'] # Checks if geneset_name is present, if yes it indicates the server request is for generating the image. It retrieves the result.pkl file and generates the image without having to recompute gsea again. | ||
pickle_file=json_object['pickle_file'] | ||
result = pd.read_pickle(os.path.join(cachedir,pickle_file)) | ||
fig = blitz.plot.running_sum(signature, geneset_name, msigdb_library, result=result.T, compact=True) | ||
random_num = np.random.rand() | ||
png_filename = "gsea_plot_" + str(random_num) + ".png" | ||
fig.savefig(os.path.join(cachedir,png_filename), bbox_inches='tight') | ||
#extract_plot_data(signature, geneset_name, msigdb_library, result) # This returns raw data to client side, not currently used | ||
print ('image: {"image_file":"' + png_filename + '"}') | ||
except KeyError: #Initial GSEA calculation, result saved to a result.pkl pickle file | ||
# run enrichment analysis | ||
start_gsea_time = time.time() | ||
if __name__ == "__main__": | ||
result = blitz.gsea(signature, msigdb_library).T | ||
random_num = np.random.rand() | ||
pickle_filename="gsea_result_"+ str(random_num) +".pkl" | ||
result.to_pickle(os.path.join(cachedir,pickle_filename)) | ||
gsea_str='{"data":' + result.to_json() + '}' | ||
pickle_str='{"pickle_file":"' + pickle_filename + '"}' | ||
#print ("pickle_file:",pickle_str) | ||
gsea_dict = json.loads(gsea_str) | ||
pickle_dict = json.loads(pickle_str) | ||
result_dict = {**gsea_dict, **pickle_dict} | ||
print ("result:",json.dumps(result_dict)) | ||
stop_gsea_time = time.time() | ||
gsea_time = stop_gsea_time - start_gsea_time | ||
print (f"GSEA time: {gsea_time} seconds") | ||
cursor.close() # Close the cursor | ||
conn.close() # Close the connection | ||
|
||
stop_loop_time = time.time() # Record the stop time of the loop | ||
execution_time = stop_loop_time - start_loop_time # Calculate the execution time | ||
print(f"Execution time: {execution_time} seconds") # Print the execution time | ||
|
||
try: | ||
# Check if geneset_name and pickle_file are present for generating the plot | ||
geneset_name = json_object['geneset_name'] # Get the gene set name from the JSON object | ||
pickle_file = json_object['pickle_file'] # Get the pickle file name from the JSON object | ||
result = pd.read_pickle(os.path.join(cachedir, pickle_file)) # Load the result from the pickle file | ||
fig = blitz.plot.running_sum(signature, geneset_name, msigdb_library, result=result.T, compact=True) # Generate the running sum plot | ||
random_num = np.random.rand() # Generate a random number | ||
png_filename = f"gsea_plot_{random_num}.png" # Create a filename for the plot | ||
fig.savefig(os.path.join(cachedir, png_filename), bbox_inches='tight') # Save the plot as a PNG file | ||
print(f'image: {{"image_file": "{png_filename}"}}') # Print the image file path in JSON format | ||
except KeyError: | ||
# Initial GSEA calculation and save the result to a pickle file | ||
start_gsea_time = time.time() # Record the start time of GSEA | ||
if __name__ == "__main__": | ||
result = blitz.gsea(signature, msigdb_library, permutations=1000).T # Perform GSEA and transpose the result | ||
random_num = np.random.rand() # Generate a random number | ||
pickle_filename = f"gsea_result_{random_num}.pkl" # Create a filename for the pickle file | ||
result.to_pickle(os.path.join(cachedir, pickle_filename)) # Save the result to the pickle file | ||
gsea_str = f'{{"data": {result.to_json()}}}' # Convert the result to JSON format | ||
pickle_str = f'{{"pickle_file": "{pickle_filename}"}}' # Create a JSON string for the pickle file | ||
gsea_dict = json.loads(gsea_str) # Parse the JSON string | ||
pickle_dict = json.loads(pickle_str) # Parse the JSON string | ||
result_dict = {**gsea_dict, **pickle_dict} # Merge the dictionaries | ||
print(f"result: {json.dumps(result_dict)}") # Print the result in JSON format | ||
stop_gsea_time = time.time() # Record the stop time of GSEA | ||
gsea_time = stop_gsea_time - start_gsea_time # Calculate the GSEA execution time | ||
print(f"GSEA time: {gsea_time} seconds") # Print the GSEA execution time | ||
else: | ||
pass | ||
pass # Do nothing if there is no input from stdin | ||
except (EOFError, IOError): | ||
pass | ||
pass # Handle EOFError and IOError exceptions gracefully | ||
|
||
# Function to extract plot data for GSEA visualization | ||
# def extract_plot_data(signature, geneset, library, result, center=True): | ||
|
||
# print("signature", signature) | ||
# print("result", result) | ||
# print("geneset", geneset) | ||
# print("library", library) | ||
# signature = signature.copy() # Create a copy of the signature DataFrame | ||
# signature.columns = ["i", "v"] # Rename columns to 'i' and 'v' | ||
# signature = signature.sort_values("v", ascending=False).set_index("i") # Sort by 'v' in descending order and set 'i' as index | ||
# signature = signature[~signature.index.duplicated(keep='first')] # Remove duplicate indices, keeping the first occurrence | ||
|
||
# if center: | ||
# signature.loc[:, "v"] -= np.mean(signature.loc[:, "v"]) # Center the signature values by subtracting the mean | ||
|
||
# signature_map = {h: i for i, h in enumerate(signature.index)} # Create a mapping of signature indices | ||
|
||
# gs = set(library[geneset]) # Get the gene set from the library | ||
# hits = [i for i, x in enumerate(signature.index) if x in gs] # Find the indices of hits in the signature | ||
|
||
# running_sum, es = blitz.enrichment_score(np.array(np.abs(signature.iloc[:, 0])), signature_map, gs) # Compute running sum and enrichment score | ||
# running_sum = list(running_sum) # Convert running sum to a list | ||
# nn = np.where(np.abs(running_sum) == np.max(np.abs(running_sum)))[0][0] # Find the index of the maximum absolute running sum | ||
|
||
# running_sum_str = [str(elem) for elem in running_sum] # Convert running sum elements to strings | ||
# print(f'result: {{"nn": {nn}, "running_sum": "{",".join(running_sum_str)}", "es": {es}}}') # Print the result in JSON format |