Skip to content

Commit

Permalink
fix benchmark wf
Browse files Browse the repository at this point in the history
  • Loading branch information
rcannood committed Sep 22, 2024
1 parent 505b458 commit 961e3e6
Show file tree
Hide file tree
Showing 14 changed files with 30 additions and 25 deletions.
10 changes: 1 addition & 9 deletions scripts/create_resources/resources.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,11 @@ REPO_ROOT=$(git rev-parse --show-toplevel)
# ensure that the command below is run from the root of the repository
cd "$REPO_ROOT"

# remove this when you have implemented the script
echo "TODO: once the 'process_datasets' workflow is implemented, update this script to use it."
echo " Step 1: replace 'task_batch_integration' with the name of the task in the following command."
echo " Step 2: replace the rename keys parameters to fit your process_dataset inputs"
echo " Step 3: replace the settings parameter to fit your process_dataset outputs"
echo " Step 4: remove this message"
exit 1

cat > /tmp/params.yaml << 'HERE'
input_states: s3://openproblems-data/resources/datasets/**/state.yaml
rename_keys: 'input:output_dataset'
output_state: '$id/state.yaml'
settings: '{"output_train": "$id/train.h5ad", "output_test": "$id/test.h5ad"}'
settings: '{"output_dataset": "$id/dataset.h5ad", "output_solution": "$id/solution.h5ad"}'
publish_dir: s3://openproblems-data/resources/task_batch_integration/datasets/
HERE

Expand Down
2 changes: 1 addition & 1 deletion scripts/run_benchmark/run_test_local.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ publish_dir="resources/results/${RUN_ID}"

# write the parameters to file
cat > /tmp/params.yaml << HERE
input_states: s3://openproblems-data/resources_test/task_batch_integration/**/state.yaml
input_states: resources_test/task_batch_integration/**/state.yaml
rename_keys: 'input_dataset:output_dataset;input_solution:output_solution'
output_state: "state.yaml"
publish_dir: "$publish_dir"
Expand Down
2 changes: 1 addition & 1 deletion src/control_methods/embed_cell_types/config.vsh.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ name: embed_cell_types
label: Embed cell types
summary: Cells are embedded as a one-hot encoding of celltype labels
description: Cells are embedded as a one-hot encoding of celltype labels

info:
method_types: [embedding]
preferred_normalization: log_cp10k
resources:
- type: python_script
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ summary: Cells are embedded as a one-hot encoding of celltype labels, with a sma
amount of random noise added to the embedding
description: Cells are embedded as a one-hot encoding of celltype labels, with a small
amount of random noise added to the embedding

info:
method_types: [embedding]
preferred_normalization: log_cp10k
arguments:
- name: --jitter
Expand Down
2 changes: 1 addition & 1 deletion src/control_methods/no_integration/config.vsh.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ name: no_integration
label: No integration
summary: Original feature space is not modified
description: Original feature space is not modified

info:
method_types: [embedding]
preferred_normalization: log_cp10k
resources:
- type: python_script
Expand Down
2 changes: 1 addition & 1 deletion src/control_methods/no_integration_batch/config.vsh.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ name: no_integration_batch
label: No integration by Batch
summary: Cells are embedded by computing PCA independently on each batch
description: Cells are embedded by computing PCA independently on each batch

info:
method_types: [embedding]
preferred_normalization: log_cp10k
resources:
- type: python_script
Expand Down
1 change: 1 addition & 0 deletions src/control_methods/shuffle_integration/config.vsh.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ label: Shuffle integration
summary: Integrations are randomly permuted
description: Integrations are randomly permuted
info:
method_types: [feature]
preferred_normalization: log_cp10k
resources:
- type: python_script
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ label: Shuffle integration by batch
summary: Integrations are randomly permuted within each batch
description: Integrations are randomly permuted within each batch
info:
method_types: [feature]
preferred_normalization: log_cp10k
resources:
- type: python_script
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ label: Shuffle integration by cell type
summary: Integrations are randomly permuted within each cell type
description: Integrations are randomly permuted within each cell type
info:
method_types: [feature]
preferred_normalization: log_cp10k
resources:
- type: python_script
Expand Down
13 changes: 7 additions & 6 deletions src/data_processors/transform/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,20 @@

print("Checking shapes", flush=True)
assert integrated.shape[0] == dataset.shape[0], "Number of cells do not match"
assert integrated.shape[1] == dataset.shape[1], "Number of genes do not match"

print("Checking index", flush=True)
if not integrated.obs.index.equals(dataset.obs.index):
assert integrated.obs.index.sort_values().equals(dataset.obs.index.sort_values()), "Cell names do not match"
print("Reordering cells", flush=True)
integrated = integrated[dataset.obs.index]

if "corrected_counts" in integrated.layers.keys() and \
not integrated.var.index.equals(dataset.var.index):
assert integrated.var.index.sort_values().equals(dataset.var.index.sort_values()), "Gene names do not match"
print("Reordering genes", flush=True)
integrated = integrated[:, dataset.var.index]
if "corrected_counts" in integrated.layers.keys():
assert integrated.shape[1] == dataset.shape[1], "Number of genes do not match"

if not integrated.var.index.equals(dataset.var.index):
assert integrated.var.index.sort_values().equals(dataset.var.index.sort_values()), "Gene names do not match"
print("Reordering genes", flush=True)
integrated = integrated[:, dataset.var.index]

print("Checking method output based on type", flush=True)
if "feature" in par["expected_method_types"]:
Expand Down
2 changes: 2 additions & 0 deletions src/methods/fastmnn/script.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ output <- anndata::AnnData(
layers = list(
corrected_counts = t(layer)
),
obs = adata$obs[, c()],
var = adata$var[, c()],
obsm = list(
X_emb = obsm
),
Expand Down
1 change: 1 addition & 0 deletions src/methods/liger/script.R
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ output <- anndata::AnnData(
normalization_id = adata$uns[["normalization_id"]],
method_id = meta$name
),
obs = adata$obs[, c()],
obsm = list(
X_emb = lobj@H.norm[rownames(adata), , drop = FALSE]
),
Expand Down
2 changes: 2 additions & 0 deletions src/methods/mnn_correct/script.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ output <- anndata::AnnData(
normalization_id = adata$uns[["normalization_id"]],
method_id = meta$name
),
obs = adata$obs[, c()],
var = adata$var[, c()],
layers = list(
corrected_counts = as(t(layer), "sparseMatrix")
),
Expand Down
14 changes: 9 additions & 5 deletions src/workflows/run_benchmark/main.nf
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ workflow run_wf {

// use 'fromState' to fetch the arguments the component requires from the overall state
fromState: { id, state, comp ->
def new_args = []
def new_args = [:]
if (comp.config.info.type == "method") {
new_args.input = state.input_dataset
} else if (comp.config.info.type == "control_method") {
Expand All @@ -116,8 +116,12 @@ workflow run_wf {
)

| transform.run(
fromState: [input: "method_output"],
toState: { id, state, output ->
fromState: [
input_integrated: "method_output",
input_dataset: "input_dataset",
expected_method_types: "method_types"
],
toState: { id, output, state ->
def method_types_cleaned = []
if ("feature" in state.method_types) {
method_types_cleaned += ["feature", "embedding", "graph"]
Expand All @@ -132,7 +136,7 @@ workflow run_wf {
method_types_cleaned: method_types_cleaned
]

[id, new_state]
new_state
}
)

Expand All @@ -143,7 +147,7 @@ workflow run_wf {
id + "." + comp.config.name
},
filter: { id, state, comp ->
comp.info.metric_type in state.method_types_cleaned
comp.config.info.metric_type in state.method_types_cleaned
},
// use 'fromState' to fetch the arguments the component requires from the overall state
fromState: [
Expand Down

0 comments on commit 961e3e6

Please sign in to comment.