From b5b2dd5dc4c3a7a850a6652531e901a6186f6464 Mon Sep 17 00:00:00 2001 From: Emanuele Bezzi Date: Mon, 22 Jan 2024 11:23:43 -0800 Subject: [PATCH 01/31] init scripts --- tools/models/metrics/ontology_mapper.py | 533 ++++++++++++++++++++++++ tools/models/metrics/run-scib.py | 229 ++++++++++ 2 files changed, 762 insertions(+) create mode 100644 tools/models/metrics/ontology_mapper.py create mode 100644 tools/models/metrics/run-scib.py diff --git a/tools/models/metrics/ontology_mapper.py b/tools/models/metrics/ontology_mapper.py new file mode 100644 index 000000000..3f266a7e5 --- /dev/null +++ b/tools/models/metrics/ontology_mapper.py @@ -0,0 +1,533 @@ +""" +Provides classes to recreate cell type and tissue mappings as used in CELLxGENE Discover + +- OntologyMapper abstract class to create other mappers +- SystemMapper to map any tissue to a System +- OrganMapper to map any tissue to an Organ +- TissueGeneralMapper to map any tissue to another tissue as shown in Gene Expression and Census +- CellClassMapper to map any cell type to a Cell Class +- CellSubclassMapper to map any cell type to a Cell Subclass + +""" + +import os +from abc import ABC, abstractmethod +from typing import List, Union + +import owlready2 + + +class OntologyMapper(ABC): + # Terms to ignore when mapping + BLOCK_LIST = [ + "BFO_0000004", + "CARO_0000000", + "CARO_0030000", + "CARO_0000003", + "NCBITaxon_6072", + "Thing", + ] + + def __init__( + self, + high_level_ontology_term_ids: List[str], + ontology_owl_path: Union[str, os.PathLike], + root_ontology_term_id: str, + ): + self._cached_high_level_terms = {} + self._cached_labels = {} + self.high_level_terms = high_level_ontology_term_ids + self.root_ontology_term_id = root_ontology_term_id + + # TODO improve this. First time it loads it raises a TypeError for CL. 
But redoing it loads it correctly + # The type error is + # 'http://purl.obolibrary.org/obo/IAO_0000028' belongs to more than one entity + # types (cannot be both a property and a class/an individual)! + # So we retry only once + try: + self._ontology = owlready2.get_ontology(ontology_owl_path).load() + except TypeError: + self._ontology = owlready2.get_ontology(ontology_owl_path).load() + + def get_high_level_terms(self, ontology_term_id: str) -> List[str]: + """ + Returns the associated high-level ontology term IDs from any other ID + """ + + ontology_term_id = self.reformat_ontology_term_id(ontology_term_id, to_writable=False) + + if ontology_term_id in self._cached_high_level_terms: + return self._cached_high_level_terms[ontology_term_id] + + owl_entity = self._get_entity_from_id(ontology_term_id) + + # If not found as an ontology ID raise + if not owl_entity: + raise ValueError("ID not found in the ontology.") + + # List ancestors for this entity, including itself if it is in the list of high level terms + ancestors = [owl_entity.name] if ontology_term_id in self.high_level_terms else [] + + branch_ancestors = self._get_branch_ancestors(owl_entity) + # Ignore branch ancestors if they are not under the root node + if branch_ancestors: + if self.root_ontology_term_id in branch_ancestors: + ancestors.extend(branch_ancestors) + + # Check if there's at least one top-level entity in the list of ancestors, and add them to + # the return list of high level term. 
Always include itself
+        resulting_high_level_terms = []
+        for high_level_term in self.high_level_terms:
+            if high_level_term in ancestors:
+                resulting_high_level_terms.append(high_level_term)
+
+        # If no valid high level terms return None
+        if len(resulting_high_level_terms) == 0:
+            resulting_high_level_terms.append(None)
+
+        resulting_high_level_terms = [
+            self.reformat_ontology_term_id(i, to_writable=True) for i in resulting_high_level_terms
+        ]
+        self._cached_high_level_terms[ontology_term_id] = resulting_high_level_terms
+
+        return resulting_high_level_terms
+
+    def get_top_high_level_term(self, ontology_term_id: str) -> str:
+        """
+        Return the top high level term
+        """
+
+        return self.get_high_level_terms(ontology_term_id)[0]
+
+    @abstractmethod
+    def _get_branch_ancestors(self, owl_entity):
+        """
+        Gets ALL ancestors from an owl entity. What's defined as an ancestor depends on the mapper type, for
+        example CL ancestors are likely to just include is_a relationship
+        """
+
+    def get_label_from_id(self, ontology_term_id: str):
+        """
+        Returns the label from an ontology term id that is in writable form
+        Example: "UBERON:0002048" returns "lung"
+        Example: "UBERON_0002048" raises ValueError because the ID is not in writable form
+        """
+
+        if ontology_term_id in self._cached_labels:
+            return self._cached_labels[ontology_term_id]
+
+        if ontology_term_id is None:
+            return None
+
+        entity = self._get_entity_from_id(self.reformat_ontology_term_id(ontology_term_id, to_writable=False))
+        if entity:
+            result = entity.label[0]
+        else:
+            result = ontology_term_id
+
+        self._cached_labels[ontology_term_id] = result
+        return result
+
+    @staticmethod
+    def reformat_ontology_term_id(ontology_term_id: str, to_writable: bool = True):
+        """
+        Converts ontology term id string between two formats:
+        - `to_writable == True`: from "UBERON_0002048" to "UBERON:0002048"
+        - `to_writable == False`: from "UBERON:0002048" to "UBERON_0002048"
+        """
+
+        if ontology_term_id is None:
+            return
None + + if to_writable: + if ontology_term_id.count("_") != 1: + raise ValueError(f"{ontology_term_id} is an invalid ontology term id, it must contain exactly one '_'") + return ontology_term_id.replace("_", ":") + else: + if ontology_term_id.count(":") != 1: + raise ValueError(f"{ontology_term_id} is an invalid ontology term id, it must contain exactly one ':'") + return ontology_term_id.replace(":", "_") + + def _list_ancestors(self, entity: owlready2.entity.ThingClass, ancestors: List[str] = []) -> List[str]: + """ + Recursive function that given an entity of an ontology, it traverses the ontology and returns + a list of all ancestors associated with the entity. + """ + + if self._is_restriction(entity): + # Entity is a restriction, check for part_of relationship + + prop = entity.property.name + if prop != "BFO_0000050": + # BFO_0000050 is "part of" + return ancestors + ancestors.append(entity.value.name.replace("obo.", "")) + + # Check for ancestors of restriction + self._list_ancestors(entity.value, ancestors) + return ancestors + + elif self._is_entity(entity) and not self._is_and_object(entity): + # Entity is a superclass, check for is_a relationships + + if entity.name in self.BLOCK_LIST: + return ancestors + ancestors.append(entity.name) + + # Check for ancestors of superclass + for super_entity in entity.is_a: + self._list_ancestors(super_entity, ancestors) + return ancestors + + def _get_entity_from_id(self, ontology_term_id: str) -> owlready2.entity.ThingClass: + """ + Given a readable ontology term id (e.g. 
"UBERON_0002048"), it returns the associated ontology entity + """ + return self._ontology.search_one(iri=f"http://purl.obolibrary.org/obo/{ontology_term_id}") + + @staticmethod + def _is_restriction(entity: owlready2.entity.ThingClass) -> bool: + return hasattr(entity, "value") + + @staticmethod + def _is_entity(entity: owlready2.entity.ThingClass) -> bool: + return hasattr(entity, "name") + + @staticmethod + def _is_and_object(entity: owlready2.entity.ThingClass) -> bool: + return hasattr(entity, "Classes") + + +class CellMapper(OntologyMapper): + # From schema 3.1.0 https://github.com/chanzuckerberg/single-cell-curation/blob/main/schema/3.1.0/schema.md + CXG_CL_ONTOLOGY_URL = "https://github.com/obophenotype/cell-ontology/releases/download/v2023-07-20/cl.owl" + # Only look up ancestors under Cell + ROOT_NODE = "CL_0000000" + + def __init__(self, cell_type_high_level_ontology_term_ids: List[str]): + super(CellMapper, self).__init__( + high_level_ontology_term_ids=cell_type_high_level_ontology_term_ids, + ontology_owl_path=self.CXG_CL_ONTOLOGY_URL, + root_ontology_term_id=self.ROOT_NODE, + ) + + def _get_branch_ancestors(self, owl_entity): + branch_ancestors = [] + for is_a in self._get_is_a_for_cl(owl_entity): + branch_ancestors = self._list_ancestors(is_a, branch_ancestors) + + return set(branch_ancestors) + + @staticmethod + def _get_is_a_for_cl(owl_entity): + # TODO make this a recurrent function instead of 2-level for nested loop + result = [] + for is_a in owl_entity.is_a: + if CellMapper._is_entity(is_a): + result.append(is_a) + elif CellMapper._is_and_object(is_a): + for is_a_2 in is_a.get_Classes(): + if CellMapper._is_entity(is_a_2): + result.append(is_a_2) + + return result + + +class TissueMapper(OntologyMapper): + # From schema 3.1.0 https://github.com/chanzuckerberg/single-cell-curation/blob/main/schema/3.1.0/schema.md + CXG_UBERON_ONTOLOGY_URL = "https://github.com/obophenotype/uberon/releases/download/v2023-06-28/uberon.owl" + + # Only look up 
ancestors under anatomical entity
+    ROOT_NODE = "UBERON_0001062"
+
+    def __init__(self, tissue_high_level_ontology_term_ids: List[str]):
+        self.cell_type_high_level_ontology_term_ids = tissue_high_level_ontology_term_ids
+        super(TissueMapper, self).__init__(
+            high_level_ontology_term_ids=tissue_high_level_ontology_term_ids,
+            ontology_owl_path=self.CXG_UBERON_ONTOLOGY_URL,
+            root_ontology_term_id=self.ROOT_NODE,
+        )
+
+    def _get_branch_ancestors(self, owl_entity):
+        branch_ancestors = []
+        for is_a in owl_entity.is_a:
+            branch_ancestors = self._list_ancestors(is_a, branch_ancestors)
+
+        return set(branch_ancestors)
+
+
+class OrganMapper(TissueMapper):
+    # List of tissue classes, ORDER MATTERS. If for a given tissue there are multiple tissue classes associated
+    # then `self.get_top_high_level_term()` returns the one that appears first in this list
+    ORGANS = [
+        "UBERON_0000992",  # ovary
+        "UBERON_0000029",  # lymph node
+        "UBERON_0002048",  # lung
+        "UBERON_0002110",  # gallbladder
+        "UBERON_0001043",  # esophagus
+        "UBERON_0003889",  # fallopian tube
+        "UBERON_0018707",  # bladder organ
+        "UBERON_0000178",  # blood
+        "UBERON_0002371",  # bone marrow
+        "UBERON_0000955",  # brain
+        "UBERON_0000310",  # breast
+        "UBERON_0000970",  # eye
+        "UBERON_0000948",  # heart
+        "UBERON_0000160",  # intestine
+        "UBERON_0002113",  # kidney
+        "UBERON_0002107",  # liver
+        "UBERON_0000004",  # nose
+        "UBERON_0001264",  # pancreas
+        "UBERON_0001987",  # placenta
+        "UBERON_0002097",  # skin of body
+        "UBERON_0002240",  # spinal cord
+        "UBERON_0002106",  # spleen
+        "UBERON_0000945",  # stomach
+        "UBERON_0002370",  # thymus
+        "UBERON_0002046",  # thyroid gland
+        "UBERON_0001723",  # tongue
+        "UBERON_0000995",  # uterus
+        "UBERON_0001013",  # adipose tissue
+    ]
+
+    def __init__(self):
+        super().__init__(tissue_high_level_ontology_term_ids=self.ORGANS)
+
+
+class SystemMapper(TissueMapper):
+    # List of tissue classes, ORDER MATTERS. 
If for a given tissue there are multiple tissue classes associated
+    # then `self.get_top_high_level_term()` returns the one that appears first in this list
+    SYSTEMS = [
+        "UBERON_0001017",  # central nervous system
+        "UBERON_0000010",  # peripheral nervous system
+        "UBERON_0001016",  # nervous system
+        "UBERON_0001009",  # circulatory system
+        "UBERON_0002390",  # hematopoietic system
+        "UBERON_0004535",  # cardiovascular system
+        "UBERON_0001004",  # respiratory system
+        "UBERON_0001007",  # digestive system
+        "UBERON_0000922",  # embryo
+        "UBERON_0000949",  # endocrine system
+        "UBERON_0002330",  # exocrine system
+        "UBERON_0002405",  # immune system
+        "UBERON_0001434",  # skeletal system
+        "UBERON_0000383",  # musculature of body
+        "UBERON_0001008",  # renal system
+        "UBERON_0000990",  # reproductive system
+        "UBERON_0001032",  # sensory system
+    ]
+
+    def __init__(self):
+        super().__init__(tissue_high_level_ontology_term_ids=self.SYSTEMS)
+
+
+class TissueGeneralMapper(TissueMapper):
+    # List of tissue classes, ORDER MATTERS. 
If for a given cell type there are multiple cell classes associated + # then `self.get_top_high_level_term()` returns the one that appears first in th this list + TISSUE_GENERAL = [ + "UBERON_0000178", # blood + "UBERON_0002048", # lung + "UBERON_0002106", # spleen + "UBERON_0002371", # bone marrow + "UBERON_0002107", # liver + "UBERON_0002113", # kidney + "UBERON_0000955", # brain + "UBERON_0002240", # spinal cord + "UBERON_0000310", # breast + "UBERON_0000948", # heart + "UBERON_0002097", # skin of body + "UBERON_0000970", # eye + "UBERON_0001264", # pancreas + "UBERON_0001043", # esophagus + "UBERON_0001155", # colon + "UBERON_0000059", # large intestine + "UBERON_0002108", # small intestine + "UBERON_0000160", # intestine + "UBERON_0000945", # stomach + "UBERON_0001836", # saliva + "UBERON_0001723", # tongue + "UBERON_0001013", # adipose tissue + "UBERON_0000473", # testis + "UBERON_0002367", # prostate gland + "UBERON_0000057", # urethra + "UBERON_0000056", # ureter + "UBERON_0003889", # fallopian tube + "UBERON_0000995", # uterus + "UBERON_0000992", # ovary + "UBERON_0002110", # gall bladder + "UBERON_0001255", # urinary bladder + "UBERON_0018707", # bladder organ + "UBERON_0000922", # embryo + "UBERON_0004023", # ganglionic eminence --> this a part of the embryo, remove in case generality is desired + "UBERON_0001987", # placenta + "UBERON_0007106", # chorionic villus + "UBERON_0002369", # adrenal gland + "UBERON_0002368", # endocrine gland + "UBERON_0002365", # exocrine gland + "UBERON_0000030", # lamina propria + "UBERON_0000029", # lymph node + "UBERON_0004536", # lymph vasculature + "UBERON_0001015", # musculature + "UBERON_0000004", # nose + "UBERON_0003688", # omentum + "UBERON_0000977", # pleura + "UBERON_0002370", # thymus + "UBERON_0002049", # vasculature + "UBERON_0009472", # axilla + "UBERON_0001087", # pleural fluid + "UBERON_0000344", # mucosa + "UBERON_0001434", # skeletal system + "UBERON_0002228", # rib + "UBERON_0003129", # skull + 
"UBERON_0004537", # blood vasculature + "UBERON_0002405", # immune system + "UBERON_0001009", # circulatory system + "UBERON_0001007", # digestive system + "UBERON_0001017", # central nervous system + "UBERON_0001008", # renal system + "UBERON_0000990", # reproductive system + "UBERON_0001004", # respiratory system + "UBERON_0000010", # peripheral nervous system + "UBERON_0001032", # sensory system + "UBERON_0002046", # thyroid gland + "UBERON_0004535", # cardiovascular system + "UBERON_0000949", # endocrine system + "UBERON_0002330", # exocrine system + "UBERON_0002390", # hematopoietic system + "UBERON_0000383", # musculature of body + "UBERON_0001465", # knee + "UBERON_0001016", # nervous system + "UBERON_0001348", # brown adipose tissue + "UBERON_0015143", # mesenteric fat pad + "UBERON_0000175", # pleural effusion + "UBERON_0001416", # skin of abdomen + "UBERON_0001868", # skin of chest + "UBERON_0001511", # skin of leg + "UBERON_0002190", # subcutaneous adipose tissue + "UBERON_0000014", # zone of skin + "UBERON_0000916", # abdomen + ] + + def __init__(self): + super().__init__(tissue_high_level_ontology_term_ids=self.TISSUE_GENERAL) + + +class CellClassMapper(CellMapper): + # List of cell classes, ORDER MATTERS. 
If for a given cell type there are multiple cell classes associated
+    # then `self.get_top_high_level_term()` returns the one that appears first in this list
+    CELL_CLASS = [
+        "CL_0002494",  # cardiocyte
+        "CL_0002320",  # connective tissue cell
+        "CL_0000473",  # defensive cell
+        "CL_0000066",  # epithelial cell
+        "CL_0000988",  # hematopoietic cell
+        "CL_0002319",  # neural cell
+        "CL_0011115",  # precursor cell
+        "CL_0000151",  # secretory cell
+        "CL_0000039",  # NEW germ cell line
+        "CL_0000064",  # NEW ciliated cell
+        "CL_0000183",  # NEW contractile cell
+        "CL_0000188",  # NEW cell of skeletal muscle
+        "CL_0000219",  # NEW motile cell
+        "CL_0000325",  # NEW stuff accumulating cell
+        "CL_0000349",  # NEW extraembryonic cell
+        "CL_0000586",  # NEW germ cell
+        "CL_0000630",  # NEW supporting cell
+        "CL_0001035",  # NEW bone cell
+        "CL_0001061",  # NEW abnormal cell
+        "CL_0002321",  # NEW embryonic cell (metazoa)
+        "CL_0009010",  # NEW transit amplifying cell
+        "CL_1000600",  # NEW lower urinary tract cell
+        "CL_4033054",  # NEW perivascular cell
+    ]
+
+    def __init__(self):
+        super().__init__(cell_type_high_level_ontology_term_ids=self.CELL_CLASS)
+
+
+class CellSubclassMapper(CellMapper):
+    # List of cell classes, ORDER MATTERS. 
If for a given cell type there are multiple cell classes associated + # then `self.get_top_high_level_term()` returns the one that appears first in th this list + CELL_SUB_CLASS = [ + "CL_0002494", # cardiocyte + "CL_0000624", # CD4-positive, alpha-beta T cell + "CL_0000625", # CD8-positive, alpha-beta T cell + "CL_0000084", # T cell + "CL_0000236", # B cell + "CL_0000451", # dendritic cell + "CL_0000576", # monocyte + "CL_0000235", # macrophage + "CL_0000542", # lymphocyte + "CL_0000738", # leukocyte + "CL_0000763", # myeloid cell + "CL_0008001", # hematopoietic precursor cell + "CL_0000234", # phagocyte + "CL_0000679", # glutamatergic neuron + "CL_0000617", # GABAergic neuron + "CL_0000099", # interneuron + "CL_0000125", # glial cell + "CL_0000101", # sensory neuron + "CL_0000100", # motor neuron + "CL_0000117", # CNS neuron (sensu Vertebrata) + "CL_0000540", # neuron + "CL_0000669", # pericyte + "CL_0000499", # stromal cell + "CL_0000057", # fibroblast + "CL_0000152", # exocrine cell + "CL_0000163", # endocrine cell + "CL_0000115", # endothelial cell + "CL_0002076", # endo-epithelial cell + "CL_0002078", # meso-epithelial cell + "CL_0011026", # progenitor cell + "CL_0000015", # NEW male germ cell + "CL_0000021", # NEW female germ cell + "CL_0000034", # NEW stem cell + "CL_0000055", # NEW non-terminally differentiated cell + "CL_0000068", # NEW duct epithelial cell + "CL_0000075", # NEW columnar/cuboidal epithelial cell + "CL_0000076", # NEW squamous epithelial cell + "CL_0000079", # NEW stratified epithelial cell + "CL_0000082", # NEW epithelial cell of lung + "CL_0000083", # NEW epithelial cell of pancreas + "CL_0000095", # NEW neuron associated cell + "CL_0000098", # NEW sensory epithelial cell + "CL_0000136", # NEW fat cell + "CL_0000147", # NEW pigment cell + "CL_0000150", # NEW glandular epithelial cell + "CL_0000159", # NEW seromucus secreting cell + "CL_0000182", # NEW hepatocyte + "CL_0000186", # NEW myofibroblast cell + "CL_0000187", # NEW muscle cell + 
"CL_0000221", # NEW ectodermal cell + "CL_0000222", # NEW mesodermal cell + "CL_0000244", # NEW urothelial cell + "CL_0000351", # NEW trophoblast cell + "CL_0000584", # NEW enterocyte + "CL_0000586", # NEW germ cell + "CL_0000670", # NEW primordial germ cell + "CL_0000680", # NEW muscle precursor cell + "CL_0001063", # NEW neoplastic cell + "CL_0002077", # NEW ecto-epithelial cell + "CL_0002222", # NEW vertebrate lens cell + "CL_0002327", # NEW mammary gland epithelial cell + "CL_0002503", # NEW adventitial cell + "CL_0002518", # NEW kidney epithelial cell + "CL_0002535", # NEW epithelial cell of cervix + "CL_0002536", # NEW epithelial cell of amnion + "CL_0005006", # NEW ionocyte + "CL_0008019", # NEW mesenchymal cell + "CL_0008034", # NEW mural cell + "CL_0009010", # NEW transit amplifying cell + "CL_1000296", # NEW epithelial cell of urethra + "CL_1000497", # NEW kidney cell + "CL_2000004", # NEW pituitary gland cell + "CL_2000064", # NEW ovarian surface epithelial cell + "CL_4030031", # NEW interstitial cell + ] + + def __init__(self, map_orphans_to_class: bool = False): + if map_orphans_to_class: + cell_type_high_level = self.CELL_SUB_CLASS + CellClassMapper.CELL_CLASS + else: + cell_type_high_level = self.CELL_SUB_CLASS + super().__init__(cell_type_high_level_ontology_term_ids=cell_type_high_level) diff --git a/tools/models/metrics/run-scib.py b/tools/models/metrics/run-scib.py new file mode 100644 index 000000000..da42c50af --- /dev/null +++ b/tools/models/metrics/run-scib.py @@ -0,0 +1,229 @@ +import datetime +import itertools +import warnings +from typing import List + +import cellxgene_census +import numpy as np +import ontology_mapper +import scanpy as sc +import scib_metrics +from cellxgene_census.experimental import get_embedding + +warnings.filterwarnings("ignore") + + +# human embeddings +CENSUS_VERSION = "2023-12-15" +# EXPERIMENT_NAME = "homo_sapiens" + +# These are embeddings contributed by the community hosted in S3 +embedding_uris_community = { 
+ "scgpt": f"s3://cellxgene-contrib-public/contrib/cell-census/soma/{CENSUS_VERSION}/CxG-contrib-1/", + "uce": f"s3://cellxgene-contrib-public/contrib/cell-census/soma/{CENSUS_VERSION}/CxG-contrib-2/", +} + +# These are embeddings included in the Census data +embedding_names_census = ["geneformer", "scvi"] + +# All embedding names +embs = list(embedding_uris_community.keys()) + embedding_names_census + +census = cellxgene_census.open_soma(census_version=CENSUS_VERSION) + + +def subclass_mapper(): + mapper = ontology_mapper.CellSubclassMapper(map_orphans_to_class=True) + cell_types = ( + census["census_data"]["homo_sapiens"] + .obs.read(column_names=["cell_type_ontology_term_id"], value_filter="is_primary_data == True") + .concat() + .to_pandas() + ) + cell_types = cell_types["cell_type_ontology_term_id"].drop_duplicates() + subclass_dict = {i: mapper.get_label_from_id(mapper.get_top_high_level_term(i)) for i in cell_types} + return subclass_dict + + +def class_mapper(): + mapper = ontology_mapper.CellClassMapper() + cell_types = ( + census["census_data"]["homo_sapiens"] + .obs.read(column_names=["cell_type_ontology_term_id"], value_filter="is_primary_data == True") + .concat() + .to_pandas() + ) + cell_types = cell_types["cell_type_ontology_term_id"].drop_duplicates() + class_dict = {i: mapper.get_label_from_id(mapper.get_top_high_level_term(i)) for i in cell_types} + return class_dict + + +def build_anndata_with_embeddings( + embedding_uris: dict, + embedding_names: List[str], + coords: List[int] = None, + obs_value_filter: str = None, + column_names=dict, + census_version: str = None, + experiment_name: str = None, +): + """ + For a given set of Census cell coordinates (soma_joinids) + fetch embeddings with TileDBSoma and return the corresponding + AnnData with embeddings slotted in. + + `embedding_uris` is a dict with community embedding names as the keys and S3 URI as the values. + `embedding_names` is a list with embedding names included in Census. 
+ + + Assume that all embeddings provided are coming from the same experiment. + """ + + with cellxgene_census.open_soma(census_version=census_version) as census: + print("Getting anndata with Census embeddings: ", embedding_names) + + ad = cellxgene_census.get_anndata( + census, + organism=experiment_name, + measurement_name="RNA", + obs_value_filter=obs_value_filter, + obs_coords=coords, + obsm_layers=embedding_names, + column_names=column_names, + ) + + for key in embedding_uris: + print("Getting community embedding:", key) + embedding_uri = embedding_uris[key] + ad.obsm[key] = get_embedding(census_version, embedding_uri, ad.obs["soma_joinid"].to_numpy()) + + # Embeddings with missing data contain all NaN, + # so we must find the intersection of non-NaN rows in the fetched embeddings + # and subset the AnnData accordingly. + filt = np.ones(ad.shape[0], dtype="bool") + for key in ad.obsm.keys(): + nan_row_sums = np.sum(np.isnan(ad.obsm[key]), axis=1) + total_columns = ad.obsm[key].shape[1] + filt = filt & (nan_row_sums != total_columns) + ad = ad[filt].copy() + + return ad + + +tissues = ["adipose tissue", "spinal cord", "skin of body", "spleen", "liver"] +tissues = ["adipose tissue", "spinal cord"] +column_names = { + "obs": ["cell_type_ontology_term_id", "cell_type", "assay", "suspension_type", "dataset_id", "soma_joinid"] +} +umap_plot_labels = ["cell_subclass", "cell_class", "cell_type", "dataset_id"] + +block_cell_types = ["native cell", "animal cell", "eukaryotic cell"] + +all_bio = {} +all_batch = {} + +for tissue in tissues: + print("Tissue", tissue, " getting Anndata") + + # Getting anddata + adata_metrics = build_anndata_with_embeddings( + embedding_uris=embedding_uris_community, + embedding_names=embedding_names_census, + obs_value_filter=f"tissue_general == '{tissue}' and is_primary_data == True", + census_version="2023-12-15", + experiment_name="homo_sapiens", + column_names=column_names, + ) + + # Create batch variable + adata_metrics.obs["batch"] = 
( + adata_metrics.obs["assay"] + adata_metrics.obs["dataset_id"] + adata_metrics.obs["suspension_type"] + ) + + # Get cell subclass + adata_metrics.obs["cell_subclass"] = adata_metrics.obs["cell_type_ontology_term_id"].replace(subclass_dict) + adata_metrics = adata_metrics[~adata_metrics.obs["cell_subclass"].isna(),] + + # Get cell class + adata_metrics.obs["cell_class"] = adata_metrics.obs["cell_type_ontology_term_id"].replace(class_dict) + adata_metrics = adata_metrics[~adata_metrics.obs["cell_class"].isna(),] + + # Remove cells in block list of cell types + adata_metrics[~adata_metrics.obs["cell_type"].isin(block_cell_types),] + + print("Tissue", tissue, "cells", adata_metrics.n_obs) + + # Calculate neighbors + for emb_name in embs: + print(datetime.datetime.now(), "Getting neighbors", emb_name) + sc.pp.neighbors(adata_metrics, use_rep=emb_name, key_added=emb_name) + sc.tl.umap(adata_metrics, neighbors_key=emb_name) + adata_metrics.obsm["X_umap_" + emb_name] = adata_metrics.obsm["X_umap"].copy() + del adata_metrics.obsm["X_umap"] + + # Save a few UMAPS + print(datetime.datetime.now(), "Saving UMAP plots") + for emb_name in embs: + for label in umap_plot_labels: + title = "_".join(["UMAP", tissue, emb_name, label]) + sc.pl.embedding(adata_metrics, basis="X_umap_" + emb_name, color=label, title=title, save=title + ".png") + + bio_labels = ["cell_subclass", "cell_class"] + metric_bio_results = { + "embedding": [], + "bio_label": [], + "leiden_nmi": [], + "leiden_ari": [], + "silhouette_label": [], + } + + batch_labels = ["batch", "assay", "dataset_id", "suspension_type"] + metric_batch_results = { + "embedding": [], + "batch_label": [], + "silhouette_batch": [], + } + + for bio_label, emb in itertools.product(bio_labels, embs): + print("\n\nSTART", bio_label, emb) + + metric_bio_results["embedding"].append(emb) + metric_bio_results["bio_label"].append(bio_label) + + print(datetime.datetime.now(), "Calculating ARI Leiden") + this_metric = 
scib_metrics.nmi_ari_cluster_labels_leiden( + X=adata_metrics.obsp[emb + "_connectivities"], + labels=adata_metrics.obs[bio_label], + optimize_resolution=True, + resolution=1.0, + n_jobs=64, + ) + metric_bio_results["leiden_nmi"].append(this_metric["nmi"]) + metric_bio_results["leiden_ari"].append(this_metric["ari"]) + + print(datetime.datetime.now(), "Calculating silhouette labels") + + this_metric = scib_metrics.silhouette_label( + X=adata_metrics.obsm[emb], labels=adata_metrics.obs[bio_label], rescale=True, chunk_size=512 + ) + metric_bio_results["silhouette_label"].append(this_metric) + + for batch_label, emb in itertools.product(batch_labels, embs): + print("\n\nSTART", batch_label, emb) + + metric_batch_results["embedding"].append(emb) + metric_batch_results["batch_label"].append(batch_label) + + print(datetime.datetime.now(), "Calculating silhouette batch") + + this_metric = scib_metrics.silhouette_batch( + X=adata_metrics.obsm[emb], + labels=adata_metrics.obs[bio_label], + batch=adata_metrics.obs[batch_label], + rescale=True, + chunk_size=512, + ) + metric_batch_results["silhouette_batch"].append(this_metric) + + all_bio[tissue] = metric_bio_results + all_batch[tissue] = metric_batch_results From 66f0761d10b599481cefdd45b8dcf914be12b195 Mon Sep 17 00:00:00 2001 From: Emanuele Bezzi Date: Mon, 22 Jan 2024 11:44:59 -0800 Subject: [PATCH 02/31] add vars --- tools/models/metrics/run-scib.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tools/models/metrics/run-scib.py b/tools/models/metrics/run-scib.py index da42c50af..cd3942b32 100644 --- a/tools/models/metrics/run-scib.py +++ b/tools/models/metrics/run-scib.py @@ -57,6 +57,9 @@ def class_mapper(): class_dict = {i: mapper.get_label_from_id(mapper.get_top_high_level_term(i)) for i in cell_types} return class_dict +class_dict = class_mapper() +subclass_dict = subclass_mapper() + def build_anndata_with_embeddings( embedding_uris: dict, From fa17de415da823207142ba9114d4ff2ec5d981f2 Mon Sep 17 00:00:00 2001 
From: Emanuele Bezzi Date: Mon, 22 Jan 2024 13:14:48 -0800 Subject: [PATCH 03/31] add pickle --- tools/models/metrics/run-scib.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tools/models/metrics/run-scib.py b/tools/models/metrics/run-scib.py index cd3942b32..508e14334 100644 --- a/tools/models/metrics/run-scib.py +++ b/tools/models/metrics/run-scib.py @@ -9,6 +9,7 @@ import scanpy as sc import scib_metrics from cellxgene_census.experimental import get_embedding +import pickle warnings.filterwarnings("ignore") @@ -230,3 +231,9 @@ def build_anndata_with_embeddings( all_bio[tissue] = metric_bio_results all_batch[tissue] = metric_batch_results + +with open('metrics_bio.pickle', 'wb') as fp: + pickle.dump(all_bio, fp, protocol=pickle.HIGHEST_PROTOCOL) + +with open('metrics_batch.pickle', 'wb') as fp: + pickle.dump(all_batch, fp, protocol=pickle.HIGHEST_PROTOCOL) \ No newline at end of file From ef43141a5ab966142ac9722e0d07e5923ce8cc0c Mon Sep 17 00:00:00 2001 From: Emanuele Bezzi Date: Mon, 22 Jan 2024 15:54:17 -0800 Subject: [PATCH 04/31] Add requirements.txt --- tools/models/metrics/requirements.txt | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 tools/models/metrics/requirements.txt diff --git a/tools/models/metrics/requirements.txt b/tools/models/metrics/requirements.txt new file mode 100644 index 000000000..c2ac6a403 --- /dev/null +++ b/tools/models/metrics/requirements.txt @@ -0,0 +1,2 @@ +owlready2 +scib-metrics==0.4.0 From 6a08ec1d268775c592fa9ff16ca2e90fc70912d7 Mon Sep 17 00:00:00 2001 From: Emanuele Bezzi Date: Tue, 23 Jan 2024 10:00:11 -0800 Subject: [PATCH 05/31] Add yaml file --- tools/models/metrics/run-scib.py | 433 +++++++++--------- tools/models/metrics/scib-metrics-config.yaml | 17 + 2 files changed, 241 insertions(+), 209 deletions(-) create mode 100644 tools/models/metrics/scib-metrics-config.yaml diff --git a/tools/models/metrics/run-scib.py b/tools/models/metrics/run-scib.py index 508e14334..022f6b53a 100644 --- 
a/tools/models/metrics/run-scib.py +++ b/tools/models/metrics/run-scib.py @@ -1,5 +1,6 @@ import datetime import itertools +import pickle import warnings from typing import List @@ -8,232 +9,246 @@ import ontology_mapper import scanpy as sc import scib_metrics +import yaml from cellxgene_census.experimental import get_embedding -import pickle warnings.filterwarnings("ignore") +file = "scib-metrics-config.yaml" -# human embeddings -CENSUS_VERSION = "2023-12-15" -# EXPERIMENT_NAME = "homo_sapiens" - -# These are embeddings contributed by the community hosted in S3 -embedding_uris_community = { - "scgpt": f"s3://cellxgene-contrib-public/contrib/cell-census/soma/{CENSUS_VERSION}/CxG-contrib-1/", - "uce": f"s3://cellxgene-contrib-public/contrib/cell-census/soma/{CENSUS_VERSION}/CxG-contrib-2/", -} - -# These are embeddings included in the Census data -embedding_names_census = ["geneformer", "scvi"] - -# All embedding names -embs = list(embedding_uris_community.keys()) + embedding_names_census - -census = cellxgene_census.open_soma(census_version=CENSUS_VERSION) - - -def subclass_mapper(): - mapper = ontology_mapper.CellSubclassMapper(map_orphans_to_class=True) - cell_types = ( - census["census_data"]["homo_sapiens"] - .obs.read(column_names=["cell_type_ontology_term_id"], value_filter="is_primary_data == True") - .concat() - .to_pandas() - ) - cell_types = cell_types["cell_type_ontology_term_id"].drop_duplicates() - subclass_dict = {i: mapper.get_label_from_id(mapper.get_top_high_level_term(i)) for i in cell_types} - return subclass_dict - - -def class_mapper(): - mapper = ontology_mapper.CellClassMapper() - cell_types = ( - census["census_data"]["homo_sapiens"] - .obs.read(column_names=["cell_type_ontology_term_id"], value_filter="is_primary_data == True") - .concat() - .to_pandas() - ) - cell_types = cell_types["cell_type_ontology_term_id"].drop_duplicates() - class_dict = {i: mapper.get_label_from_id(mapper.get_top_high_level_term(i)) for i in cell_types} - return 
class_dict - -class_dict = class_mapper() -subclass_dict = subclass_mapper() - - -def build_anndata_with_embeddings( - embedding_uris: dict, - embedding_names: List[str], - coords: List[int] = None, - obs_value_filter: str = None, - column_names=dict, - census_version: str = None, - experiment_name: str = None, -): - """ - For a given set of Census cell coordinates (soma_joinids) - fetch embeddings with TileDBSoma and return the corresponding - AnnData with embeddings slotted in. - - `embedding_uris` is a dict with community embedding names as the keys and S3 URI as the values. - `embedding_names` is a list with embedding names included in Census. - - - Assume that all embeddings provided are coming from the same experiment. - """ - - with cellxgene_census.open_soma(census_version=census_version) as census: - print("Getting anndata with Census embeddings: ", embedding_names) - - ad = cellxgene_census.get_anndata( - census, - organism=experiment_name, - measurement_name="RNA", - obs_value_filter=obs_value_filter, - obs_coords=coords, - obsm_layers=embedding_names, - column_names=column_names, - ) +if __name__ == "__main__": + with open(file) as f: + config = yaml.safe_load(f) - for key in embedding_uris: - print("Getting community embedding:", key) - embedding_uri = embedding_uris[key] - ad.obsm[key] = get_embedding(census_version, embedding_uri, ad.obs["soma_joinid"].to_numpy()) - - # Embeddings with missing data contain all NaN, - # so we must find the intersection of non-NaN rows in the fetched embeddings - # and subset the AnnData accordingly. 
- filt = np.ones(ad.shape[0], dtype="bool") - for key in ad.obsm.keys(): - nan_row_sums = np.sum(np.isnan(ad.obsm[key]), axis=1) - total_columns = ad.obsm[key].shape[1] - filt = filt & (nan_row_sums != total_columns) - ad = ad[filt].copy() - - return ad - - -tissues = ["adipose tissue", "spinal cord", "skin of body", "spleen", "liver"] -tissues = ["adipose tissue", "spinal cord"] -column_names = { - "obs": ["cell_type_ontology_term_id", "cell_type", "assay", "suspension_type", "dataset_id", "soma_joinid"] -} -umap_plot_labels = ["cell_subclass", "cell_class", "cell_type", "dataset_id"] - -block_cell_types = ["native cell", "animal cell", "eukaryotic cell"] - -all_bio = {} -all_batch = {} - -for tissue in tissues: - print("Tissue", tissue, " getting Anndata") - - # Getting anddata - adata_metrics = build_anndata_with_embeddings( - embedding_uris=embedding_uris_community, - embedding_names=embedding_names_census, - obs_value_filter=f"tissue_general == '{tissue}' and is_primary_data == True", - census_version="2023-12-15", - experiment_name="homo_sapiens", - column_names=column_names, - ) - - # Create batch variable - adata_metrics.obs["batch"] = ( - adata_metrics.obs["assay"] + adata_metrics.obs["dataset_id"] + adata_metrics.obs["suspension_type"] - ) - - # Get cell subclass - adata_metrics.obs["cell_subclass"] = adata_metrics.obs["cell_type_ontology_term_id"].replace(subclass_dict) - adata_metrics = adata_metrics[~adata_metrics.obs["cell_subclass"].isna(),] - - # Get cell class - adata_metrics.obs["cell_class"] = adata_metrics.obs["cell_type_ontology_term_id"].replace(class_dict) - adata_metrics = adata_metrics[~adata_metrics.obs["cell_class"].isna(),] - - # Remove cells in block list of cell types - adata_metrics[~adata_metrics.obs["cell_type"].isin(block_cell_types),] - - print("Tissue", tissue, "cells", adata_metrics.n_obs) - - # Calculate neighbors - for emb_name in embs: - print(datetime.datetime.now(), "Getting neighbors", emb_name) - 
sc.pp.neighbors(adata_metrics, use_rep=emb_name, key_added=emb_name) - sc.tl.umap(adata_metrics, neighbors_key=emb_name) - adata_metrics.obsm["X_umap_" + emb_name] = adata_metrics.obsm["X_umap"].copy() - del adata_metrics.obsm["X_umap"] - - # Save a few UMAPS - print(datetime.datetime.now(), "Saving UMAP plots") - for emb_name in embs: - for label in umap_plot_labels: - title = "_".join(["UMAP", tissue, emb_name, label]) - sc.pl.embedding(adata_metrics, basis="X_umap_" + emb_name, color=label, title=title, save=title + ".png") - - bio_labels = ["cell_subclass", "cell_class"] - metric_bio_results = { - "embedding": [], - "bio_label": [], - "leiden_nmi": [], - "leiden_ari": [], - "silhouette_label": [], - } + census_config = config.get("census") + embedding_config = config.get("embedding") + metrics_config = config.get("metrics") - batch_labels = ["batch", "assay", "dataset_id", "suspension_type"] - metric_batch_results = { - "embedding": [], - "batch_label": [], - "silhouette_batch": [], - } + census_version = census_config.get("version") + experiment_name = census_config.get("organism") - for bio_label, emb in itertools.product(bio_labels, embs): - print("\n\nSTART", bio_label, emb) + embedding_uris_community = embedding_config.get("hosted") - metric_bio_results["embedding"].append(emb) - metric_bio_results["bio_label"].append(bio_label) + # These are embeddings contributed by the community hosted in S3 + # embedding_uris_community = { + # "scgpt": f"s3://cellxgene-contrib-public/contrib/cell-census/soma/{CENSUS_VERSION}/CxG-contrib-1/", + # "uce": f"s3://cellxgene-contrib-public/contrib/cell-census/soma/{CENSUS_VERSION}/CxG-contrib-2/", + # } - print(datetime.datetime.now(), "Calculating ARI Leiden") - this_metric = scib_metrics.nmi_ari_cluster_labels_leiden( - X=adata_metrics.obsp[emb + "_connectivities"], - labels=adata_metrics.obs[bio_label], - optimize_resolution=True, - resolution=1.0, - n_jobs=64, - ) - 
metric_bio_results["leiden_nmi"].append(this_metric["nmi"]) - metric_bio_results["leiden_ari"].append(this_metric["ari"]) + # These are embeddings included in the Census data + embedding_names_census = embedding_config.get("collaboration") - print(datetime.datetime.now(), "Calculating silhouette labels") + # All embedding names + embs = list(embedding_uris_community.keys()) + embedding_names_census - this_metric = scib_metrics.silhouette_label( - X=adata_metrics.obsm[emb], labels=adata_metrics.obs[bio_label], rescale=True, chunk_size=512 + census = cellxgene_census.open_soma(census_version=census_version) + + def subclass_mapper(): + mapper = ontology_mapper.CellSubclassMapper(map_orphans_to_class=True) + cell_types = ( + census["census_data"]["homo_sapiens"] + .obs.read(column_names=["cell_type_ontology_term_id"], value_filter="is_primary_data == True") + .concat() + .to_pandas() + ) + cell_types = cell_types["cell_type_ontology_term_id"].drop_duplicates() + subclass_dict = {i: mapper.get_label_from_id(mapper.get_top_high_level_term(i)) for i in cell_types} + return subclass_dict + + def class_mapper(): + mapper = ontology_mapper.CellClassMapper() + cell_types = ( + census["census_data"]["homo_sapiens"] + .obs.read(column_names=["cell_type_ontology_term_id"], value_filter="is_primary_data == True") + .concat() + .to_pandas() ) - metric_bio_results["silhouette_label"].append(this_metric) + cell_types = cell_types["cell_type_ontology_term_id"].drop_duplicates() + class_dict = {i: mapper.get_label_from_id(mapper.get_top_high_level_term(i)) for i in cell_types} + return class_dict + + class_dict = class_mapper() + subclass_dict = subclass_mapper() + + def build_anndata_with_embeddings( + embedding_uris: dict, + embedding_names: List[str], + embeddings_raw: dict[str, str], + coords: List[int] = None, + obs_value_filter: str = None, + column_names=dict, + census_version: str = None, + experiment_name: str = None, + ): + """ + For a given set of Census cell coordinates 
(soma_joinids) + fetch embeddings with TileDBSoma and return the corresponding + AnnData with embeddings slotted in. + + `embedding_uris` is a dict with community embedding names as the keys and S3 URI as the values. + `embedding_names` is a list with embedding names included in Census. + `embeddings_raw` are embeddings provided in raw format (npy) on a local drive + + + Assume that all embeddings provided are coming from the same experiment. + """ + + with cellxgene_census.open_soma(census_version=census_version) as census: + print("Getting anndata with Census embeddings: ", embedding_names) + + ad = cellxgene_census.get_anndata( + census, + organism=experiment_name, + measurement_name="RNA", + obs_value_filter=obs_value_filter, + obs_coords=coords, + obsm_layers=embedding_names, + column_names=column_names, + ) + + for key, val in embeddings_raw.items(): + print("Getting community embedding:", key) + embedding_uri = val["uri"] + ad.obsm[key] = get_embedding(census_version, embedding_uri, ad.obs["soma_joinid"].to_numpy()) + + for key, val in embeddings_raw.items(): + print("Getting raw embedding:", key) + ad.obsm[key] = np.load(val["uri"]) + + # Embeddings with missing data contain all NaN, + # so we must find the intersection of non-NaN rows in the fetched embeddings + # and subset the AnnData accordingly. 
+ filt = np.ones(ad.shape[0], dtype="bool") + for key in ad.obsm.keys(): + nan_row_sums = np.sum(np.isnan(ad.obsm[key]), axis=1) + total_columns = ad.obsm[key].shape[1] + filt = filt & (nan_row_sums != total_columns) + ad = ad[filt].copy() + + return ad + + column_names = { + "obs": ["cell_type_ontology_term_id", "cell_type", "assay", "suspension_type", "dataset_id", "soma_joinid"] + } + umap_plot_labels = ["cell_subclass", "cell_class", "cell_type", "dataset_id"] - for batch_label, emb in itertools.product(batch_labels, embs): - print("\n\nSTART", batch_label, emb) + block_cell_types = ["native cell", "animal cell", "eukaryotic cell"] - metric_batch_results["embedding"].append(emb) - metric_batch_results["batch_label"].append(batch_label) + all_bio = {} + all_batch = {} - print(datetime.datetime.now(), "Calculating silhouette batch") + tissues = metrics_config.get("tissues") - this_metric = scib_metrics.silhouette_batch( - X=adata_metrics.obsm[emb], - labels=adata_metrics.obs[bio_label], - batch=adata_metrics.obs[batch_label], - rescale=True, - chunk_size=512, - ) - metric_batch_results["silhouette_batch"].append(this_metric) + for tissue in tissues: + print("Tissue", tissue, " getting Anndata") - all_bio[tissue] = metric_bio_results - all_batch[tissue] = metric_batch_results + # Getting anddata + adata_metrics = build_anndata_with_embeddings( + embedding_uris=embedding_uris_community, + embedding_names=embedding_names_census, + obs_value_filter=f"tissue_general == '{tissue}' and is_primary_data == True", + census_version=census_version, + experiment_name="homo_sapiens", + column_names=column_names, + ) -with open('metrics_bio.pickle', 'wb') as fp: - pickle.dump(all_bio, fp, protocol=pickle.HIGHEST_PROTOCOL) + # Create batch variable + adata_metrics.obs["batch"] = ( + adata_metrics.obs["assay"] + adata_metrics.obs["dataset_id"] + adata_metrics.obs["suspension_type"] + ) -with open('metrics_batch.pickle', 'wb') as fp: - pickle.dump(all_batch, fp, 
protocol=pickle.HIGHEST_PROTOCOL) \ No newline at end of file + # Get cell subclass + adata_metrics.obs["cell_subclass"] = adata_metrics.obs["cell_type_ontology_term_id"].replace(subclass_dict) + adata_metrics = adata_metrics[~adata_metrics.obs["cell_subclass"].isna(),] + + # Get cell class + adata_metrics.obs["cell_class"] = adata_metrics.obs["cell_type_ontology_term_id"].replace(class_dict) + adata_metrics = adata_metrics[~adata_metrics.obs["cell_class"].isna(),] + + # Remove cells in block list of cell types + adata_metrics[~adata_metrics.obs["cell_type"].isin(block_cell_types),] + + print("Tissue", tissue, "cells", adata_metrics.n_obs) + + # Calculate neighbors + for emb_name in embs: + print(datetime.datetime.now(), "Getting neighbors", emb_name) + sc.pp.neighbors(adata_metrics, use_rep=emb_name, key_added=emb_name) + sc.tl.umap(adata_metrics, neighbors_key=emb_name) + adata_metrics.obsm["X_umap_" + emb_name] = adata_metrics.obsm["X_umap"].copy() + del adata_metrics.obsm["X_umap"] + + # Save a few UMAPS + print(datetime.datetime.now(), "Saving UMAP plots") + for emb_name in embs: + for label in umap_plot_labels: + title = "_".join(["UMAP", tissue, emb_name, label]) + sc.pl.embedding( + adata_metrics, basis="X_umap_" + emb_name, color=label, title=title, save=title + ".png" + ) + + bio_labels = ["cell_subclass", "cell_class"] + metric_bio_results = { + "embedding": [], + "bio_label": [], + "leiden_nmi": [], + "leiden_ari": [], + "silhouette_label": [], + } + + batch_labels = ["batch", "assay", "dataset_id", "suspension_type"] + metric_batch_results = { + "embedding": [], + "batch_label": [], + "silhouette_batch": [], + } + + for bio_label, emb in itertools.product(bio_labels, embs): + print("\n\nSTART", bio_label, emb) + + metric_bio_results["embedding"].append(emb) + metric_bio_results["bio_label"].append(bio_label) + + print(datetime.datetime.now(), "Calculating ARI Leiden") + this_metric = scib_metrics.nmi_ari_cluster_labels_leiden( + 
X=adata_metrics.obsp[emb + "_connectivities"], + labels=adata_metrics.obs[bio_label], + optimize_resolution=True, + resolution=1.0, + n_jobs=64, + ) + metric_bio_results["leiden_nmi"].append(this_metric["nmi"]) + metric_bio_results["leiden_ari"].append(this_metric["ari"]) + + print(datetime.datetime.now(), "Calculating silhouette labels") + + this_metric = scib_metrics.silhouette_label( + X=adata_metrics.obsm[emb], labels=adata_metrics.obs[bio_label], rescale=True, chunk_size=512 + ) + metric_bio_results["silhouette_label"].append(this_metric) + + for batch_label, emb in itertools.product(batch_labels, embs): + print("\n\nSTART", batch_label, emb) + + metric_batch_results["embedding"].append(emb) + metric_batch_results["batch_label"].append(batch_label) + + print(datetime.datetime.now(), "Calculating silhouette batch") + + this_metric = scib_metrics.silhouette_batch( + X=adata_metrics.obsm[emb], + labels=adata_metrics.obs[bio_label], + batch=adata_metrics.obs[batch_label], + rescale=True, + chunk_size=512, + ) + metric_batch_results["silhouette_batch"].append(this_metric) + + all_bio[tissue] = metric_bio_results + all_batch[tissue] = metric_batch_results + + with open("metrics_bio.pickle", "wb") as fp: + pickle.dump(all_bio, fp, protocol=pickle.HIGHEST_PROTOCOL) + + with open("metrics_batch.pickle", "wb") as fp: + pickle.dump(all_batch, fp, protocol=pickle.HIGHEST_PROTOCOL) diff --git a/tools/models/metrics/scib-metrics-config.yaml b/tools/models/metrics/scib-metrics-config.yaml new file mode 100644 index 000000000..ab2f068b0 --- /dev/null +++ b/tools/models/metrics/scib-metrics-config.yaml @@ -0,0 +1,17 @@ +census: + version: + 2023-12-15 + organism: + "homo_sapiens" +embeddings: + hosted: + scgpt: + uri: "s3://cellxgene-contrib-public/contrib/cell-census/soma/2023-12-15/CxG-contrib-1/" + uce: + uri: "s3://cellxgene-contrib-public/contrib/cell-census/soma/2023-12-15/CxG-contrib-2/" + collaboration: + [scvi, geneformer] +metrics: + tissues: + ["adipose tissue", 
"spinal cord"] + From 78d40296de9a2aeaf36e6ffabc40af19bb8bd560 Mon Sep 17 00:00:00 2001 From: Emanuele Bezzi Date: Tue, 23 Jan 2024 10:02:39 -0800 Subject: [PATCH 06/31] typo --- tools/models/metrics/run-scib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/models/metrics/run-scib.py b/tools/models/metrics/run-scib.py index 022f6b53a..350392a1f 100644 --- a/tools/models/metrics/run-scib.py +++ b/tools/models/metrics/run-scib.py @@ -21,7 +21,7 @@ config = yaml.safe_load(f) census_config = config.get("census") - embedding_config = config.get("embedding") + embedding_config = config.get("embeddings") metrics_config = config.get("metrics") census_version = census_config.get("version") From 5254ab85fe4e5b2f3ae4c031b444db26ad5ccb14 Mon Sep 17 00:00:00 2001 From: Emanuele Bezzi Date: Tue, 23 Jan 2024 10:12:15 -0800 Subject: [PATCH 07/31] changes --- tools/models/metrics/run-scib.py | 7 +++++-- tools/models/metrics/scib-metrics-config.yaml | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tools/models/metrics/run-scib.py b/tools/models/metrics/run-scib.py index 350392a1f..a47b9b50b 100644 --- a/tools/models/metrics/run-scib.py +++ b/tools/models/metrics/run-scib.py @@ -38,6 +38,8 @@ # These are embeddings included in the Census data embedding_names_census = embedding_config.get("collaboration") + embeddings_raw = embedding_config.get("raw") + # All embedding names embs = list(embedding_uris_community.keys()) + embedding_names_census @@ -73,7 +75,7 @@ def class_mapper(): def build_anndata_with_embeddings( embedding_uris: dict, embedding_names: List[str], - embeddings_raw: dict[str, str], + embeddings_raw: dict, coords: List[int] = None, obs_value_filter: str = None, column_names=dict, @@ -106,7 +108,7 @@ def build_anndata_with_embeddings( column_names=column_names, ) - for key, val in embeddings_raw.items(): + for key, val in embedding_uris.items(): print("Getting community embedding:", key) embedding_uri = val["uri"] 
ad.obsm[key] = get_embedding(census_version, embedding_uri, ad.obs["soma_joinid"].to_numpy()) @@ -146,6 +148,7 @@ def build_anndata_with_embeddings( adata_metrics = build_anndata_with_embeddings( embedding_uris=embedding_uris_community, embedding_names=embedding_names_census, + embeddings_raw=embeddings_raw, obs_value_filter=f"tissue_general == '{tissue}' and is_primary_data == True", census_version=census_version, experiment_name="homo_sapiens", diff --git a/tools/models/metrics/scib-metrics-config.yaml b/tools/models/metrics/scib-metrics-config.yaml index ab2f068b0..921e43e6b 100644 --- a/tools/models/metrics/scib-metrics-config.yaml +++ b/tools/models/metrics/scib-metrics-config.yaml @@ -1,6 +1,6 @@ census: version: - 2023-12-15 + "2023-12-15" organism: "homo_sapiens" embeddings: From 9abf1ff69fb61dd6f6df98e5644fe0a4f42f8548 Mon Sep 17 00:00:00 2001 From: Emanuele Bezzi Date: Tue, 23 Jan 2024 10:19:17 -0800 Subject: [PATCH 08/31] changes --- tools/models/metrics/run-scib.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tools/models/metrics/run-scib.py b/tools/models/metrics/run-scib.py index a47b9b50b..9a33fb204 100644 --- a/tools/models/metrics/run-scib.py +++ b/tools/models/metrics/run-scib.py @@ -10,13 +10,15 @@ import scanpy as sc import scib_metrics import yaml +import sys from cellxgene_census.experimental import get_embedding warnings.filterwarnings("ignore") -file = "scib-metrics-config.yaml" - if __name__ == "__main__": + + file = sys.argv[1] or "scib-metrics-config.yaml" + with open(file) as f: config = yaml.safe_load(f) @@ -27,7 +29,7 @@ census_version = census_config.get("version") experiment_name = census_config.get("organism") - embedding_uris_community = embedding_config.get("hosted") + embedding_uris_community = embedding_config.get("hosted") or dict() # These are embeddings contributed by the community hosted in S3 # embedding_uris_community = { @@ -36,9 +38,9 @@ # } # These are embeddings included in the Census 
data - embedding_names_census = embedding_config.get("collaboration") + embedding_names_census = embedding_config.get("collaboration") or dict() - embeddings_raw = embedding_config.get("raw") + embeddings_raw = embedding_config.get("raw") or dict() # All embedding names embs = list(embedding_uris_community.keys()) + embedding_names_census From 8e708613bdb60168e0ec69e9580f21e078d229b1 Mon Sep 17 00:00:00 2001 From: Emanuele Bezzi Date: Tue, 23 Jan 2024 11:32:01 -0800 Subject: [PATCH 09/31] add subsetting --- tools/models/metrics/run-scib.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tools/models/metrics/run-scib.py b/tools/models/metrics/run-scib.py index 9a33fb204..ea74e63d1 100644 --- a/tools/models/metrics/run-scib.py +++ b/tools/models/metrics/run-scib.py @@ -110,14 +110,18 @@ def build_anndata_with_embeddings( column_names=column_names, ) + obs_idx = ad.obs["soma_joinid"].to_numpy() + for key, val in embedding_uris.items(): print("Getting community embedding:", key) embedding_uri = val["uri"] - ad.obsm[key] = get_embedding(census_version, embedding_uri, ad.obs["soma_joinid"].to_numpy()) + ad.obsm[key] = get_embedding(census_version, embedding_uri, obs_idx) + # For these, we need to extract the right cells via soma_joinid for key, val in embeddings_raw.items(): print("Getting raw embedding:", key) - ad.obsm[key] = np.load(val["uri"]) + emb = np.load(val["uri"]) + ad.obsm[key] = emb[obs_idx] # Embeddings with missing data contain all NaN, # so we must find the intersection of non-NaN rows in the fetched embeddings From e50ce123bdd5e2d369120e48ecfaec8b01250bfd Mon Sep 17 00:00:00 2001 From: Emanuele Bezzi Date: Tue, 23 Jan 2024 11:51:22 -0800 Subject: [PATCH 10/31] add missing embs --- tools/models/metrics/run-scib.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tools/models/metrics/run-scib.py b/tools/models/metrics/run-scib.py index ea74e63d1..fa136e5b3 100644 --- a/tools/models/metrics/run-scib.py +++ 
b/tools/models/metrics/run-scib.py @@ -17,7 +17,10 @@ if __name__ == "__main__": - file = sys.argv[1] or "scib-metrics-config.yaml" + try: + file = sys.argv[1] + except IndexError: + file = "scib-metrics-config.yaml" with open(file) as f: config = yaml.safe_load(f) @@ -43,7 +46,7 @@ embeddings_raw = embedding_config.get("raw") or dict() # All embedding names - embs = list(embedding_uris_community.keys()) + embedding_names_census + embs = list(embedding_uris_community.keys()) + embedding_names_census + embeddings_raw.keys() census = cellxgene_census.open_soma(census_version=census_version) From ae6e97dd278374e99c9be03f9db1f2e47de1f491 Mon Sep 17 00:00:00 2001 From: Emanuele Bezzi Date: Tue, 23 Jan 2024 11:52:53 -0800 Subject: [PATCH 11/31] fix --- tools/models/metrics/run-scib.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tools/models/metrics/run-scib.py b/tools/models/metrics/run-scib.py index fa136e5b3..721339e6d 100644 --- a/tools/models/metrics/run-scib.py +++ b/tools/models/metrics/run-scib.py @@ -46,7 +46,9 @@ embeddings_raw = embedding_config.get("raw") or dict() # All embedding names - embs = list(embedding_uris_community.keys()) + embedding_names_census + embeddings_raw.keys() + embs = list(embedding_uris_community.keys()) + embedding_names_census + list(embeddings_raw.keys()) + + print("Embeddings to use: ", embs) census = cellxgene_census.open_soma(census_version=census_version) From 9d47c944b560640fa4b1c33ca1cbd54dd20a7f08 Mon Sep 17 00:00:00 2001 From: Emanuele Bezzi Date: Wed, 24 Jan 2024 09:15:04 -0800 Subject: [PATCH 12/31] Add tiledbsoma support --- tools/models/metrics/run-scib.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tools/models/metrics/run-scib.py b/tools/models/metrics/run-scib.py index 721339e6d..1b8d62485 100644 --- a/tools/models/metrics/run-scib.py +++ b/tools/models/metrics/run-scib.py @@ -12,6 +12,7 @@ import yaml import sys from cellxgene_census.experimental import 
get_embedding +import tiledbsoma as soma warnings.filterwarnings("ignore") @@ -125,8 +126,15 @@ def build_anndata_with_embeddings( # For these, we need to extract the right cells via soma_joinid for key, val in embeddings_raw.items(): print("Getting raw embedding:", key) - emb = np.load(val["uri"]) - ad.obsm[key] = emb[obs_idx] + # Alternative approach: set type in the config file + try: + # Assume it's a numpy ndarray + emb = np.load(val["uri"]) + ad.obsm[key] = emb[obs_idx] + except Exception: + # Assume it's a TileDBSoma URI + with soma.open(val["uri"]) as emb: + ad.obsm[key] = emb.read(coords=(obs_idx,)).coos((len(obs_idx), emb.shape[1])).concat().to_scipy().todense() # Embeddings with missing data contain all NaN, # so we must find the intersection of non-NaN rows in the fetched embeddings From 3387c4f42fe440600160f0f5061ca267e0a1bbc4 Mon Sep 17 00:00:00 2001 From: Emanuele Bezzi Date: Wed, 24 Jan 2024 09:38:26 -0800 Subject: [PATCH 13/31] fix indexing --- tools/models/metrics/run-scib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/models/metrics/run-scib.py b/tools/models/metrics/run-scib.py index 1b8d62485..9668bae18 100644 --- a/tools/models/metrics/run-scib.py +++ b/tools/models/metrics/run-scib.py @@ -134,7 +134,7 @@ def build_anndata_with_embeddings( except Exception: # Assume it's a TileDBSoma URI with soma.open(val["uri"]) as emb: - ad.obsm[key] = emb.read(coords=(obs_idx,)).coos((len(obs_idx), emb.shape[1])).concat().to_scipy().todense() + ad.obsm[key] = emb.read(coords=(obs_idx,)).coos().concat().to_scipy().tocsr()[obs_idx, :].todense() # Embeddings with missing data contain all NaN, # so we must find the intersection of non-NaN rows in the fetched embeddings From 29acd78bc8c874a9789c8317da59fe6e611e0b16 Mon Sep 17 00:00:00 2001 From: Emanuele Bezzi Date: Wed, 24 Jan 2024 14:57:21 -0800 Subject: [PATCH 14/31] Add asarray --- tools/models/metrics/run-scib.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff 
--git a/tools/models/metrics/run-scib.py b/tools/models/metrics/run-scib.py index 9668bae18..79bdc6d35 100644 --- a/tools/models/metrics/run-scib.py +++ b/tools/models/metrics/run-scib.py @@ -1,6 +1,7 @@ import datetime import itertools import pickle +import sys import warnings from typing import List @@ -9,15 +10,13 @@ import ontology_mapper import scanpy as sc import scib_metrics +import tiledbsoma as soma import yaml -import sys from cellxgene_census.experimental import get_embedding -import tiledbsoma as soma warnings.filterwarnings("ignore") if __name__ == "__main__": - try: file = sys.argv[1] except IndexError: @@ -134,7 +133,9 @@ def build_anndata_with_embeddings( except Exception: # Assume it's a TileDBSoma URI with soma.open(val["uri"]) as emb: - ad.obsm[key] = emb.read(coords=(obs_idx,)).coos().concat().to_scipy().tocsr()[obs_idx, :].todense() + ad.obsm[key] = np.asarray( + emb.read(coords=(obs_idx,)).coos().concat().to_scipy().tocsr()[obs_idx, :].todense() + ) # Embeddings with missing data contain all NaN, # so we must find the intersection of non-NaN rows in the fetched embeddings From 00175731aedbf7ba9cefbebafb8176fbe92671ad Mon Sep 17 00:00:00 2001 From: Emanuele Bezzi Date: Wed, 24 Jan 2024 15:12:54 -0800 Subject: [PATCH 15/31] Proper usage of soma_joinid --- tools/models/metrics/run-scib.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/tools/models/metrics/run-scib.py b/tools/models/metrics/run-scib.py index 79bdc6d35..33c4eb45e 100644 --- a/tools/models/metrics/run-scib.py +++ b/tools/models/metrics/run-scib.py @@ -8,6 +8,7 @@ import cellxgene_census import numpy as np import ontology_mapper +import pandas as pd import scanpy as sc import scib_metrics import tiledbsoma as soma @@ -115,12 +116,12 @@ def build_anndata_with_embeddings( column_names=column_names, ) - obs_idx = ad.obs["soma_joinid"].to_numpy() + obs_soma_joinids = ad.obs["soma_joinid"].to_numpy() for key, val in embedding_uris.items(): 
print("Getting community embedding:", key) embedding_uri = val["uri"] - ad.obsm[key] = get_embedding(census_version, embedding_uri, obs_idx) + ad.obsm[key] = get_embedding(census_version, embedding_uri, obs_soma_joinids) # For these, we need to extract the right cells via soma_joinid for key, val in embeddings_raw.items(): @@ -129,13 +130,23 @@ def build_anndata_with_embeddings( try: # Assume it's a numpy ndarray emb = np.load(val["uri"]) - ad.obsm[key] = emb[obs_idx] + ad.obsm[key] = emb[obs_soma_joinids] except Exception: # Assume it's a TileDBSoma URI - with soma.open(val["uri"]) as emb: - ad.obsm[key] = np.asarray( - emb.read(coords=(obs_idx,)).coos().concat().to_scipy().tocsr()[obs_idx, :].todense() - ) + with soma.open(val["uri"]) as E: + embedding_shape = (len(obs_soma_joinids), E.shape[1]) + embedding = np.full(embedding_shape, np.NaN, dtype=np.float32, order="C") + + obs_indexer = pd.Index(obs_soma_joinids) + for tbl in E.read(coords=(obs_soma_joinids,)).tables(): + obs_idx = obs_indexer.get_indexer(tbl.column("soma_dim_0").to_numpy()) # type: ignore[no-untyped-call] + feat_idx = tbl.column("soma_dim_1").to_numpy() + emb = tbl.column("soma_data") + + indices = obs_idx * E.shape[1] + feat_idx + np.put(embedding.reshape(-1), indices, emb) + + ad.obsm[key] = embedding # Embeddings with missing data contain all NaN, # so we must find the intersection of non-NaN rows in the fetched embeddings From b282b6f783176c97e85979d6a85409cae5f6fe7b Mon Sep 17 00:00:00 2001 From: Emanuele Bezzi Date: Wed, 24 Jan 2024 15:50:16 -0800 Subject: [PATCH 16/31] Proper usage of soma_joinid, pass 2 --- tools/models/metrics/run-scib.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tools/models/metrics/run-scib.py b/tools/models/metrics/run-scib.py index 33c4eb45e..50ffe8aa4 100644 --- a/tools/models/metrics/run-scib.py +++ b/tools/models/metrics/run-scib.py @@ -130,7 +130,10 @@ def build_anndata_with_embeddings( try: # Assume it's a numpy ndarray emb = 
np.load(val["uri"]) - ad.obsm[key] = emb[obs_soma_joinids] + emb_idx = np.load(val["idx"]) + obs_indexer = pd.Index(emb_idx) + idx = obs_indexer.get_indexer(obs_soma_joinids) + ad.obsm[key] = emb[idx] except Exception: # Assume it's a TileDBSoma URI with soma.open(val["uri"]) as E: From 39b2e4f69d50981856f02cd0c567d348dc4eb74a Mon Sep 17 00:00:00 2001 From: Emanuele Bezzi Date: Tue, 20 Feb 2024 13:44:15 -0800 Subject: [PATCH 17/31] try: ilisi_knn metric --- tools/models/metrics/run-scib.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tools/models/metrics/run-scib.py b/tools/models/metrics/run-scib.py index 50ffe8aa4..444fdf656 100644 --- a/tools/models/metrics/run-scib.py +++ b/tools/models/metrics/run-scib.py @@ -238,6 +238,7 @@ def build_anndata_with_embeddings( "embedding": [], "batch_label": [], "silhouette_batch": [], + "ilisi_knn_batch": [], } for bio_label, emb in itertools.product(bio_labels, embs): @@ -281,6 +282,14 @@ def build_anndata_with_embeddings( ) metric_batch_results["silhouette_batch"].append(this_metric) + ilisi_metric = scib_metrics.ilisi_knn( + X=adata_metrics.obsp[emb + "_connectivities"], + batches=adata_metrics.obs[batch_label], + scale=True, + ) + + metric_batch_results["ilisi_knn_batch"].append(ilisi_metric) + all_bio[tissue] = metric_bio_results all_batch[tissue] = metric_batch_results From 7050c12e6e78690a03d70d6faee00e1c7b735ae1 Mon Sep 17 00:00:00 2001 From: Emanuele Bezzi Date: Tue, 20 Feb 2024 16:53:31 -0800 Subject: [PATCH 18/31] fix matrix --- tools/models/metrics/run-scib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/models/metrics/run-scib.py b/tools/models/metrics/run-scib.py index 444fdf656..c96178f14 100644 --- a/tools/models/metrics/run-scib.py +++ b/tools/models/metrics/run-scib.py @@ -283,7 +283,7 @@ def build_anndata_with_embeddings( metric_batch_results["silhouette_batch"].append(this_metric) ilisi_metric = scib_metrics.ilisi_knn( - X=adata_metrics.obsp[emb + "_connectivities"], 
+ X=adata_metrics.obsp["distances"], batches=adata_metrics.obs[batch_label], scale=True, ) From ebb91c6d1947538fd0751e991e72ef2afbd81b80 Mon Sep 17 00:00:00 2001 From: Emanuele Bezzi Date: Wed, 21 Feb 2024 15:51:24 -0800 Subject: [PATCH 19/31] metric fix --- tools/models/metrics/run-scib.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tools/models/metrics/run-scib.py b/tools/models/metrics/run-scib.py index c96178f14..d51f4cbdd 100644 --- a/tools/models/metrics/run-scib.py +++ b/tools/models/metrics/run-scib.py @@ -211,6 +211,7 @@ def build_anndata_with_embeddings( for emb_name in embs: print(datetime.datetime.now(), "Getting neighbors", emb_name) sc.pp.neighbors(adata_metrics, use_rep=emb_name, key_added=emb_name) + sc.pp.neighbors(adata_metrics, n_neighbors=90, use_rep=emb_name, key_added=emb_name + "_90") sc.tl.umap(adata_metrics, neighbors_key=emb_name) adata_metrics.obsm["X_umap_" + emb_name] = adata_metrics.obsm["X_umap"].copy() del adata_metrics.obsm["X_umap"] @@ -283,7 +284,7 @@ def build_anndata_with_embeddings( metric_batch_results["silhouette_batch"].append(this_metric) ilisi_metric = scib_metrics.ilisi_knn( - X=adata_metrics.obsp["distances"], + X=adata_metrics.obsp[f"{emb}_90_distances"], batches=adata_metrics.obs[batch_label], scale=True, ) From 01235934e16986c44858d2ccbc93e4fa3812cd0f Mon Sep 17 00:00:00 2001 From: Emanuele Bezzi Date: Sun, 30 Jun 2024 17:15:16 -0700 Subject: [PATCH 20/31] Update to scib 0.5 --- tools/models/metrics/ontology_mapper.py | 1 + tools/models/metrics/requirements.txt | 2 +- tools/models/metrics/run-scib.py | 121 +++++++++--------- tools/models/metrics/scib-metrics-config.yaml | 11 +- 4 files changed, 71 insertions(+), 64 deletions(-) diff --git a/tools/models/metrics/ontology_mapper.py b/tools/models/metrics/ontology_mapper.py index 3f266a7e5..532818fa5 100644 --- a/tools/models/metrics/ontology_mapper.py +++ b/tools/models/metrics/ontology_mapper.py @@ -26,6 +26,7 @@ class OntologyMapper(ABC): 
"CARO_0000003", "NCBITaxon_6072", "Thing", + "unknown", ] def __init__( diff --git a/tools/models/metrics/requirements.txt b/tools/models/metrics/requirements.txt index c2ac6a403..2757927fa 100644 --- a/tools/models/metrics/requirements.txt +++ b/tools/models/metrics/requirements.txt @@ -1,2 +1,2 @@ owlready2 -scib-metrics==0.4.0 +scib-metrics==0.5.1 diff --git a/tools/models/metrics/run-scib.py b/tools/models/metrics/run-scib.py index d51f4cbdd..80e232bec 100644 --- a/tools/models/metrics/run-scib.py +++ b/tools/models/metrics/run-scib.py @@ -13,7 +13,6 @@ import scib_metrics import tiledbsoma as soma import yaml -from cellxgene_census.experimental import get_embedding warnings.filterwarnings("ignore") @@ -33,21 +32,14 @@ census_version = census_config.get("version") experiment_name = census_config.get("organism") - embedding_uris_community = embedding_config.get("hosted") or dict() - - # These are embeddings contributed by the community hosted in S3 - # embedding_uris_community = { - # "scgpt": f"s3://cellxgene-contrib-public/contrib/cell-census/soma/{CENSUS_VERSION}/CxG-contrib-1/", - # "uce": f"s3://cellxgene-contrib-public/contrib/cell-census/soma/{CENSUS_VERSION}/CxG-contrib-2/", - # } - - # These are embeddings included in the Census data - embedding_names_census = embedding_config.get("collaboration") or dict() + # These are embeddings hosted in the Census + embeddings_census = embedding_config.get("census") or dict() + # Raw embeddings (external) embeddings_raw = embedding_config.get("raw") or dict() # All embedding names - embs = list(embedding_uris_community.keys()) + embedding_names_census + list(embeddings_raw.keys()) + embs = list(embeddings_census.keys()) + list(embeddings_raw.keys()) print("Embeddings to use: ", embs) @@ -112,17 +104,12 @@ def build_anndata_with_embeddings( measurement_name="RNA", obs_value_filter=obs_value_filter, obs_coords=coords, - obsm_layers=embedding_names, + obs_embeddings=embedding_names, column_names=column_names, ) 
obs_soma_joinids = ad.obs["soma_joinid"].to_numpy() - for key, val in embedding_uris.items(): - print("Getting community embedding:", key) - embedding_uri = val["uri"] - ad.obsm[key] = get_embedding(census_version, embedding_uri, obs_soma_joinids) - # For these, we need to extract the right cells via soma_joinid for key, val in embeddings_raw.items(): print("Getting raw embedding:", key) @@ -175,13 +162,15 @@ def build_anndata_with_embeddings( tissues = metrics_config.get("tissues") + bio_metrics = metrics_config["bio"] + batch_metrics = metrics_config["batch"] + for tissue in tissues: print("Tissue", tissue, " getting Anndata") # Getting anddata adata_metrics = build_anndata_with_embeddings( - embedding_uris=embedding_uris_community, - embedding_names=embedding_names_census, + embedding_names=embeddings_census, embeddings_raw=embeddings_raw, obs_value_filter=f"tissue_general == '{tissue}' and is_primary_data == True", census_version=census_version, @@ -211,7 +200,9 @@ def build_anndata_with_embeddings( for emb_name in embs: print(datetime.datetime.now(), "Getting neighbors", emb_name) sc.pp.neighbors(adata_metrics, use_rep=emb_name, key_added=emb_name) - sc.pp.neighbors(adata_metrics, n_neighbors=90, use_rep=emb_name, key_added=emb_name + "_90") + # Only necessary + if "ilisi_knn_batch" in metrics_config["batch"]: + sc.pp.neighbors(adata_metrics, n_neighbors=90, use_rep=emb_name, key_added=emb_name + "_90") sc.tl.umap(adata_metrics, neighbors_key=emb_name) adata_metrics.obsm["X_umap_" + emb_name] = adata_metrics.obsm["X_umap"].copy() del adata_metrics.obsm["X_umap"] @@ -226,22 +217,25 @@ def build_anndata_with_embeddings( ) bio_labels = ["cell_subclass", "cell_class"] + batch_labels = ["batch", "assay", "dataset_id", "suspension_type"] + + # Initialize results metric_bio_results = { "embedding": [], "bio_label": [], - "leiden_nmi": [], - "leiden_ari": [], - "silhouette_label": [], } - - batch_labels = ["batch", "assay", "dataset_id", "suspension_type"] 
metric_batch_results = { "embedding": [], "batch_label": [], - "silhouette_batch": [], - "ilisi_knn_batch": [], } + for metric in bio_metrics: + metric_bio_results[metric] = [] + + for metric in batch_metrics: + metric_batch_results[metric] = [] + + # Calculate metrics for bio_label, emb in itertools.product(bio_labels, embs): print("\n\nSTART", bio_label, emb) @@ -249,22 +243,31 @@ def build_anndata_with_embeddings( metric_bio_results["bio_label"].append(bio_label) print(datetime.datetime.now(), "Calculating ARI Leiden") - this_metric = scib_metrics.nmi_ari_cluster_labels_leiden( - X=adata_metrics.obsp[emb + "_connectivities"], - labels=adata_metrics.obs[bio_label], - optimize_resolution=True, - resolution=1.0, - n_jobs=64, - ) - metric_bio_results["leiden_nmi"].append(this_metric["nmi"]) - metric_bio_results["leiden_ari"].append(this_metric["ari"]) - print(datetime.datetime.now(), "Calculating silhouette labels") + class NN: + def __init__(self, conn): + self.knn_graph_connectivities = conn - this_metric = scib_metrics.silhouette_label( - X=adata_metrics.obsm[emb], labels=adata_metrics.obs[bio_label], rescale=True, chunk_size=512 - ) - metric_bio_results["silhouette_label"].append(this_metric) + X = NN(adata_metrics.obsp[emb + "_connectivities"]) + + if "leiden_nmi" in bio_metrics and "leiden_ari" in bio_metrics: + this_metric = scib_metrics.nmi_ari_cluster_labels_leiden( + X=X, + labels=adata_metrics.obs[bio_label], + optimize_resolution=True, + resolution=1.0, + n_jobs=64, + ) + metric_bio_results["leiden_nmi"].append(this_metric["nmi"]) + metric_bio_results["leiden_ari"].append(this_metric["ari"]) + + if "silhouette_label" in bio_metrics: + print(datetime.datetime.now(), "Calculating silhouette labels") + + this_metric = scib_metrics.silhouette_label( + X=adata_metrics.obsm[emb], labels=adata_metrics.obs[bio_label], rescale=True, chunk_size=512 + ) + metric_bio_results["silhouette_label"].append(this_metric) for batch_label, emb in 
itertools.product(batch_labels, embs): print("\n\nSTART", batch_label, emb) @@ -272,24 +275,28 @@ def build_anndata_with_embeddings( metric_batch_results["embedding"].append(emb) metric_batch_results["batch_label"].append(batch_label) - print(datetime.datetime.now(), "Calculating silhouette batch") + if "silhouette_batch" in batch_metrics: + print(datetime.datetime.now(), "Calculating silhouette batch") - this_metric = scib_metrics.silhouette_batch( - X=adata_metrics.obsm[emb], - labels=adata_metrics.obs[bio_label], - batch=adata_metrics.obs[batch_label], - rescale=True, - chunk_size=512, - ) - metric_batch_results["silhouette_batch"].append(this_metric) + this_metric = scib_metrics.silhouette_batch( + X=adata_metrics.obsm[emb], + labels=adata_metrics.obs[bio_label], + batch=adata_metrics.obs[batch_label], + rescale=True, + chunk_size=512, + ) + metric_batch_results["silhouette_batch"].append(this_metric) - ilisi_metric = scib_metrics.ilisi_knn( - X=adata_metrics.obsp[f"{emb}_90_distances"], - batches=adata_metrics.obs[batch_label], - scale=True, - ) + if "ilisi_knn_batch" in batch_metrics: + print(datetime.datetime.now(), "Calculating ilisi knn batch") + + ilisi_metric = scib_metrics.ilisi_knn( + X=adata_metrics.obsp[f"{emb}_90_distances"], + batches=adata_metrics.obs[batch_label], + scale=True, + ) - metric_batch_results["ilisi_knn_batch"].append(ilisi_metric) + metric_batch_results["ilisi_knn_batch"].append(ilisi_metric) all_bio[tissue] = metric_bio_results all_batch[tissue] = metric_batch_results diff --git a/tools/models/metrics/scib-metrics-config.yaml b/tools/models/metrics/scib-metrics-config.yaml index 921e43e6b..690a82012 100644 --- a/tools/models/metrics/scib-metrics-config.yaml +++ b/tools/models/metrics/scib-metrics-config.yaml @@ -4,14 +4,13 @@ census: organism: "homo_sapiens" embeddings: - hosted: - scgpt: - uri: "s3://cellxgene-contrib-public/contrib/cell-census/soma/2023-12-15/CxG-contrib-1/" - uce: - uri: 
"s3://cellxgene-contrib-public/contrib/cell-census/soma/2023-12-15/CxG-contrib-2/" - collaboration: + census: [scvi, geneformer] metrics: tissues: ["adipose tissue", "spinal cord"] + bio: + ["leiden_nmi", "leiden_ari", "silhouette_label"] + batch: + ["silhouette_batch", "ilisi_knn_batch"] From 5dbde3e6a04351a9398a7e146a98c815c0ed586e Mon Sep 17 00:00:00 2001 From: Emanuele Bezzi Date: Mon, 1 Jul 2024 11:31:59 -0700 Subject: [PATCH 21/31] Add classifier metrics --- tools/models/metrics/requirements.txt | 1 + tools/models/metrics/run-scib.py | 93 ++++++++++++++++--- tools/models/metrics/scib-metrics-config.yaml | 9 +- 3 files changed, 89 insertions(+), 14 deletions(-) diff --git a/tools/models/metrics/requirements.txt b/tools/models/metrics/requirements.txt index 2757927fa..a3dd411e6 100644 --- a/tools/models/metrics/requirements.txt +++ b/tools/models/metrics/requirements.txt @@ -1,2 +1,3 @@ owlready2 scib-metrics==0.5.1 +pyyaml \ No newline at end of file diff --git a/tools/models/metrics/run-scib.py b/tools/models/metrics/run-scib.py index 80e232bec..6a0977ed8 100644 --- a/tools/models/metrics/run-scib.py +++ b/tools/models/metrics/run-scib.py @@ -14,8 +14,61 @@ import tiledbsoma as soma import yaml +from sklearn.preprocessing import LabelEncoder +from sklearn.model_selection import train_test_split +from sklearn.linear_model import LogisticRegression +from sklearn import svm +from sklearn.ensemble import RandomForestClassifier + +from sklearn.metrics import accuracy_score, roc_auc_score + warnings.filterwarnings("ignore") +class CensusClassifierMetrics: + + def __init__(self): + self._default_metric = "accuracy" + + def lr_labels(self, X, labels, metric = None): + return self._base_accuracy(X, labels, LogisticRegression, metric=metric) + + def svm_svc_labels(self, X, labels, metric = None): + return self._base_accuracy(X, labels, svm.SVC, metric=metric) + + def random_forest_labels(self, X, labels, metric = None, n_jobs=8): + return self._base_accuracy(X, labels, 
RandomForestClassifier, metric=metric, n_jobs=n_jobs) + + def lr_batch(self, X, batch, metric = None): + return 1-self._base_accuracy(X, batch, LogisticRegression, metric=metric) + + def svm_svc_batch(self, X, batch, metric = None): + return 1-self._base_accuracy(X, batch, svm.SVC, metric=metric) + + def random_forest_batch(self, X, batch, metric = None, n_jobs=8): + return 1-self._base_accuracy(X, batch, RandomForestClassifier, metric=metric, n_jobs=n_jobs) + + def _base_accuracy(self, X, y, model, metric, test_size=0.4, **kwargs): + """ + Train LogisticRegression on X with labels y and return classifier accuracy score + """ + y_encoded = LabelEncoder().fit_transform(y) + X_train, X_test, y_train, y_test = train_test_split( + X, y_encoded, test_size=test_size, random_state=42 + ) + model = model(**kwargs).fit(X_train, y_train) + + if metric == None: + metric = self._default_metric + + if metric == "roc_auc": + #return y_test + #return model.predict_proba(X_test) + return roc_auc_score(y_test, model.predict_proba(X_test), multi_class="ovo", average="macro") + elif metric == "accuracy": + return accuracy_score(y_test, model.predict(X_test)) + else: + raise ValueError("Only {'accuracy', 'roc_auc'} are supported as a metric") + if __name__ == "__main__": try: file = sys.argv[1] @@ -33,13 +86,13 @@ experiment_name = census_config.get("organism") # These are embeddings hosted in the Census - embeddings_census = embedding_config.get("census") or dict() + embeddings_census = embedding_config.get("census") or [] # Raw embeddings (external) embeddings_raw = embedding_config.get("raw") or dict() # All embedding names - embs = list(embeddings_census.keys()) + list(embeddings_raw.keys()) + embs = list(embeddings_census) + list(embeddings_raw.keys()) print("Embeddings to use: ", embs) @@ -73,7 +126,6 @@ def class_mapper(): subclass_dict = subclass_mapper() def build_anndata_with_embeddings( - embedding_uris: dict, embedding_names: List[str], embeddings_raw: dict, coords: 
List[int] = None, @@ -87,7 +139,6 @@ def build_anndata_with_embeddings( fetch embeddings with TileDBSoma and return the corresponding AnnData with embeddings slotted in. - `embedding_uris` is a dict with community embedding names as the keys and S3 URI as the values. `embedding_names` is a list with embedding names included in Census. `embeddings_raw` are embeddings provided in raw format (npy) on a local drive @@ -165,14 +216,18 @@ def build_anndata_with_embeddings( bio_metrics = metrics_config["bio"] batch_metrics = metrics_config["batch"] - for tissue in tissues: + for tissue_node in tissues: + + tissue = tissue_node["name"] + query = tissue_node.get("query") or f"tissue_general == '{tissue}' and is_primary_data == True" + print("Tissue", tissue, " getting Anndata") # Getting anddata adata_metrics = build_anndata_with_embeddings( embedding_names=embeddings_census, embeddings_raw=embeddings_raw, - obs_value_filter=f"tissue_general == '{tissue}' and is_primary_data == True", + obs_value_filter=query, census_version=census_version, experiment_name="homo_sapiens", column_names=column_names, @@ -269,6 +324,16 @@ def __init__(self, conn): ) metric_bio_results["silhouette_label"].append(this_metric) + if "classifier" in bio_metrics: + metrics = CensusClassifierMetrics() + + m1 = metrics.lr_labels(X=adata_metrics.obsm[emb], labels = adata_metrics.obs["cell_type"]) + m2 = metrics.svm_svc_labels(X=adata_metrics.obsm[emb], labels = adata_metrics.obs["cell_type"]) + m3 = metrics.random_forest_labels(X=adata_metrics.obsm[emb], labels = adata_metrics.obs["cell_type"]) + + metric_bio_results["classifier"].append({"lr": m1, "svm": m2, "random_forest": m3}) + + for batch_label, emb in itertools.product(batch_labels, embs): print("\n\nSTART", batch_label, emb) @@ -298,11 +363,17 @@ def __init__(self, conn): metric_batch_results["ilisi_knn_batch"].append(ilisi_metric) + if "classifier" in batch_metrics: + metrics = CensusClassifierMetrics() + + m4 = 
metrics.lr_batch(X=adata_metrics.obsm[emb], batch = adata_metrics.obs["batch"]) + m5 = metrics.random_forest_batch(X=adata_metrics.obsm[emb], batch = adata_metrics.obs["batch"]) + m6 = metrics.svm_svc_batch(X=adata_metrics.obsm[emb], batch = adata_metrics.obs["batch"]) + metric_batch_results["classifier"].append({"lr": m4, "random_forest": m5, "svm": m6}) + + all_bio[tissue] = metric_bio_results all_batch[tissue] = metric_batch_results - with open("metrics_bio.pickle", "wb") as fp: - pickle.dump(all_bio, fp, protocol=pickle.HIGHEST_PROTOCOL) - - with open("metrics_batch.pickle", "wb") as fp: - pickle.dump(all_batch, fp, protocol=pickle.HIGHEST_PROTOCOL) + with open("metrics.pickle", "wb") as fp: + pickle.dump({"all_bio": all_bio, "all_batch": all_batch}, fp, protocol=pickle.HIGHEST_PROTOCOL) \ No newline at end of file diff --git a/tools/models/metrics/scib-metrics-config.yaml b/tools/models/metrics/scib-metrics-config.yaml index 690a82012..6d1a677d4 100644 --- a/tools/models/metrics/scib-metrics-config.yaml +++ b/tools/models/metrics/scib-metrics-config.yaml @@ -8,9 +8,12 @@ embeddings: [scvi, geneformer] metrics: tissues: - ["adipose tissue", "spinal cord"] + - name: "adipose tissue" + - name: "spinal cord" + - name: "heart" + query: 'tissue in ["cardiac ventricle", "heart left ventricle", "heart right ventricle"] and datasets in ["53d208b0-2cfd-4366-9866-c3c6114081bc", "d567b692-c374-4628-a508-8008f6778f22", "f15e263b-6544-46cb-a46e-e33ab7ce8347", "d4e69e01-3ba2-4d6b-a15d-e7048f78f22e"] and is_primary_data==True' bio: - ["leiden_nmi", "leiden_ari", "silhouette_label"] + ["leiden_nmi", "leiden_ari", "silhouette_label", "classifier"] batch: - ["silhouette_batch", "ilisi_knn_batch"] + ["silhouette_batch", "ilisi_knn_batch", "classifier"] From 5d9d305bbe4d40b0e699d3a8e2ebc431575b7023 Mon Sep 17 00:00:00 2001 From: Emanuele Bezzi Date: Mon, 1 Jul 2024 16:38:23 -0700 Subject: [PATCH 22/31] All metrics --- tools/models/metrics/run-scib.py | 48 
++++++++++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 3 deletions(-) diff --git a/tools/models/metrics/run-scib.py b/tools/models/metrics/run-scib.py index 6a0977ed8..7d6495eb9 100644 --- a/tools/models/metrics/run-scib.py +++ b/tools/models/metrics/run-scib.py @@ -14,6 +14,11 @@ import tiledbsoma as soma import yaml +import numpy as np +import scipy as sp +import cellxgene_census +import functools + from sklearn.preprocessing import LabelEncoder from sklearn.model_selection import train_test_split from sklearn.linear_model import LogisticRegression @@ -68,6 +73,37 @@ def _base_accuracy(self, X, y, model, metric, test_size=0.4, **kwargs): return accuracy_score(y_test, model.predict(X_test)) else: raise ValueError("Only {'accuracy', 'roc_auc'} are supported as a metric") + +def safelog(a): + return np.log(a, out=np.zeros_like(a), where=(a!=0)) + +def nearest_neighbors_hnsw(x, ef=200, M=48, n_neighbors = 100): + import hnswlib + labels = np.arange(x.shape[0]) + p = hnswlib.Index(space = 'l2', dim = x.shape[1]) + p.init_index(max_elements = x.shape[0], ef_construction = ef, M = M) + p.add_items(x, labels) + p.set_ef(ef) + idx, dist = p.knn_query(x, k = n_neighbors) + return idx,dist + +def compute_entropy_per_cell(adata, obsm_key): + + batch_keys = ["donor_id", "dataset_id", "assay", "suspension_type"] + adata.obs["batch"] = functools.reduce(lambda a, b: a+b, [adata.obs[c].astype(str) for c in batch_keys]) + + indices, dist = nearest_neighbors_hnsw(adata.obsm[obsm_key], n_neighbors = 200) + + BATCH_KEY = 'batch' + + batch_labels = np.array(list(adata.obs[BATCH_KEY])) + unique_batch_labels = np.unique(batch_labels) + + indices_batch = batch_labels[indices] + + label_counts_per_cell = np.vstack([(indices_batch == label).sum(1) for label in unique_batch_labels]).T + label_counts_per_cell_normed = label_counts_per_cell / label_counts_per_cell.sum(1)[:,None] + return (-label_counts_per_cell_normed*safelog(label_counts_per_cell_normed)).sum(1) if __name__ == 
"__main__": try: @@ -366,11 +402,17 @@ def __init__(self, conn): if "classifier" in batch_metrics: metrics = CensusClassifierMetrics() - m4 = metrics.lr_batch(X=adata_metrics.obsm[emb], batch = adata_metrics.obs["batch"]) - m5 = metrics.random_forest_batch(X=adata_metrics.obsm[emb], batch = adata_metrics.obs["batch"]) - m6 = metrics.svm_svc_batch(X=adata_metrics.obsm[emb], batch = adata_metrics.obs["batch"]) + m4 = metrics.lr_batch(X=adata_metrics.obsm[emb], batch = adata_metrics.obs[batch_label]) + m5 = metrics.random_forest_batch(X=adata_metrics.obsm[emb], batch = adata_metrics.obs[batch_label]) + m6 = metrics.svm_svc_batch(X=adata_metrics.obsm[emb], batch = adata_metrics.obs[batch_label]) metric_batch_results["classifier"].append({"lr": m4, "random_forest": m5, "svm": m6}) + if "entropy" in batch_metrics: + print(datetime.datetime.now(), "Calculating entropy") + + entropy = compute_entropy_per_cell(adata_metrics, emb) + e_mean = entropy.mean() + metric_batch_results["entropy"].append(e_mean) all_bio[tissue] = metric_bio_results all_batch[tissue] = metric_batch_results From 386880a7432b68e7749719f89250b13a45070593 Mon Sep 17 00:00:00 2001 From: Emanuele Bezzi Date: Mon, 1 Jul 2024 20:35:35 -0700 Subject: [PATCH 23/31] fix missing key --- tools/models/metrics/run-scib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/models/metrics/run-scib.py b/tools/models/metrics/run-scib.py index 7d6495eb9..5d35a2f70 100644 --- a/tools/models/metrics/run-scib.py +++ b/tools/models/metrics/run-scib.py @@ -89,7 +89,7 @@ def nearest_neighbors_hnsw(x, ef=200, M=48, n_neighbors = 100): def compute_entropy_per_cell(adata, obsm_key): - batch_keys = ["donor_id", "dataset_id", "assay", "suspension_type"] + batch_keys = ["dataset_id", "assay", "suspension_type"] adata.obs["batch"] = functools.reduce(lambda a, b: a+b, [adata.obs[c].astype(str) for c in batch_keys]) indices, dist = nearest_neighbors_hnsw(adata.obsm[obsm_key], n_neighbors = 200) From 
6798b0ce0fce31567216620e2842d5350ba1d155 Mon Sep 17 00:00:00 2001 From: Emanuele Bezzi Date: Tue, 2 Jul 2024 16:39:51 -0700 Subject: [PATCH 24/31] Change query --- tools/models/metrics/scib-metrics-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/models/metrics/scib-metrics-config.yaml b/tools/models/metrics/scib-metrics-config.yaml index 6d1a677d4..94eef9085 100644 --- a/tools/models/metrics/scib-metrics-config.yaml +++ b/tools/models/metrics/scib-metrics-config.yaml @@ -11,7 +11,7 @@ metrics: - name: "adipose tissue" - name: "spinal cord" - name: "heart" - query: 'tissue in ["cardiac ventricle", "heart left ventricle", "heart right ventricle"] and datasets in ["53d208b0-2cfd-4366-9866-c3c6114081bc", "d567b692-c374-4628-a508-8008f6778f22", "f15e263b-6544-46cb-a46e-e33ab7ce8347", "d4e69e01-3ba2-4d6b-a15d-e7048f78f22e"] and is_primary_data==True' + query: 'tissue in ["cardiac ventricle", "heart left ventricle", "heart right ventricle"] and dataset_id in ["53d208b0-2cfd-4366-9866-c3c6114081bc", "d567b692-c374-4628-a508-8008f6778f22", "f15e263b-6544-46cb-a46e-e33ab7ce8347", "d4e69e01-3ba2-4d6b-a15d-e7048f78f22e"] and is_primary_data==True' bio: ["leiden_nmi", "leiden_ari", "silhouette_label", "classifier"] batch: From 88936dcef50bdf7434d6d0f53ced6abe98e20b0c Mon Sep 17 00:00:00 2001 From: Emanuele Bezzi Date: Fri, 5 Jul 2024 13:59:34 -0700 Subject: [PATCH 25/31] Fixes --- tools/models/metrics/run-scib.py | 112 ++++++++---------- tools/models/metrics/scib-metrics-config.yaml | 3 +- 2 files changed, 52 insertions(+), 63 deletions(-) diff --git a/tools/models/metrics/run-scib.py b/tools/models/metrics/run-scib.py index 5d35a2f70..5627599d5 100644 --- a/tools/models/metrics/run-scib.py +++ b/tools/models/metrics/run-scib.py @@ -1,9 +1,9 @@ import datetime +import functools import itertools import pickle import sys import warnings -from typing import List import cellxgene_census import numpy as np @@ -13,88 +13,80 @@ import scib_metrics 
import tiledbsoma as soma import yaml - -import numpy as np -import scipy as sp -import cellxgene_census -import functools - -from sklearn.preprocessing import LabelEncoder -from sklearn.model_selection import train_test_split -from sklearn.linear_model import LogisticRegression from sklearn import svm from sklearn.ensemble import RandomForestClassifier - +from sklearn.linear_model import LogisticRegression from sklearn.metrics import accuracy_score, roc_auc_score +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import LabelEncoder warnings.filterwarnings("ignore") -class CensusClassifierMetrics: +class CensusClassifierMetrics: def __init__(self): self._default_metric = "accuracy" - def lr_labels(self, X, labels, metric = None): + def lr_labels(self, X, labels, metric=None): return self._base_accuracy(X, labels, LogisticRegression, metric=metric) - def svm_svc_labels(self, X, labels, metric = None): + def svm_svc_labels(self, X, labels, metric=None): return self._base_accuracy(X, labels, svm.SVC, metric=metric) - def random_forest_labels(self, X, labels, metric = None, n_jobs=8): + def random_forest_labels(self, X, labels, metric=None, n_jobs=8): return self._base_accuracy(X, labels, RandomForestClassifier, metric=metric, n_jobs=n_jobs) - def lr_batch(self, X, batch, metric = None): - return 1-self._base_accuracy(X, batch, LogisticRegression, metric=metric) + def lr_batch(self, X, batch, metric=None): + return 1 - self._base_accuracy(X, batch, LogisticRegression, metric=metric) - def svm_svc_batch(self, X, batch, metric = None): - return 1-self._base_accuracy(X, batch, svm.SVC, metric=metric) + def svm_svc_batch(self, X, batch, metric=None): + return 1 - self._base_accuracy(X, batch, svm.SVC, metric=metric) - def random_forest_batch(self, X, batch, metric = None, n_jobs=8): - return 1-self._base_accuracy(X, batch, RandomForestClassifier, metric=metric, n_jobs=n_jobs) + def random_forest_batch(self, X, batch, metric=None, n_jobs=8): 
+ return 1 - self._base_accuracy(X, batch, RandomForestClassifier, metric=metric, n_jobs=n_jobs) def _base_accuracy(self, X, y, model, metric, test_size=0.4, **kwargs): - """ - Train LogisticRegression on X with labels y and return classifier accuracy score - """ + """Train LogisticRegression on X with labels y and return classifier accuracy score""" y_encoded = LabelEncoder().fit_transform(y) - X_train, X_test, y_train, y_test = train_test_split( - X, y_encoded, test_size=test_size, random_state=42 - ) + X_train, X_test, y_train, y_test = train_test_split(X, y_encoded, test_size=test_size, random_state=42) model = model(**kwargs).fit(X_train, y_train) if metric == None: metric = self._default_metric - - if metric == "roc_auc": - #return y_test - #return model.predict_proba(X_test) + + if metric == "roc_auc": + # return y_test + # return model.predict_proba(X_test) return roc_auc_score(y_test, model.predict_proba(X_test), multi_class="ovo", average="macro") elif metric == "accuracy": return accuracy_score(y_test, model.predict(X_test)) else: raise ValueError("Only {'accuracy', 'roc_auc'} are supported as a metric") - + + def safelog(a): - return np.log(a, out=np.zeros_like(a), where=(a!=0)) + return np.log(a, out=np.zeros_like(a), where=(a != 0)) + -def nearest_neighbors_hnsw(x, ef=200, M=48, n_neighbors = 100): +def nearest_neighbors_hnsw(x, ef=200, M=48, n_neighbors=100): import hnswlib + labels = np.arange(x.shape[0]) - p = hnswlib.Index(space = 'l2', dim = x.shape[1]) - p.init_index(max_elements = x.shape[0], ef_construction = ef, M = M) + p = hnswlib.Index(space="l2", dim=x.shape[1]) + p.init_index(max_elements=x.shape[0], ef_construction=ef, M=M) p.add_items(x, labels) p.set_ef(ef) - idx, dist = p.knn_query(x, k = n_neighbors) - return idx,dist + idx, dist = p.knn_query(x, k=n_neighbors) + return idx, dist -def compute_entropy_per_cell(adata, obsm_key): +def compute_entropy_per_cell(adata, obsm_key): batch_keys = ["dataset_id", "assay", "suspension_type"] - 
adata.obs["batch"] = functools.reduce(lambda a, b: a+b, [adata.obs[c].astype(str) for c in batch_keys]) + adata.obs["batch"] = functools.reduce(lambda a, b: a + b, [adata.obs[c].astype(str) for c in batch_keys]) - indices, dist = nearest_neighbors_hnsw(adata.obsm[obsm_key], n_neighbors = 200) + indices, dist = nearest_neighbors_hnsw(adata.obsm[obsm_key], n_neighbors=200) - BATCH_KEY = 'batch' + BATCH_KEY = "batch" batch_labels = np.array(list(adata.obs[BATCH_KEY])) unique_batch_labels = np.unique(batch_labels) @@ -102,8 +94,9 @@ def compute_entropy_per_cell(adata, obsm_key): indices_batch = batch_labels[indices] label_counts_per_cell = np.vstack([(indices_batch == label).sum(1) for label in unique_batch_labels]).T - label_counts_per_cell_normed = label_counts_per_cell / label_counts_per_cell.sum(1)[:,None] - return (-label_counts_per_cell_normed*safelog(label_counts_per_cell_normed)).sum(1) + label_counts_per_cell_normed = label_counts_per_cell / label_counts_per_cell.sum(1)[:, None] + return (-label_counts_per_cell_normed * safelog(label_counts_per_cell_normed)).sum(1) + if __name__ == "__main__": try: @@ -162,16 +155,15 @@ def class_mapper(): subclass_dict = subclass_mapper() def build_anndata_with_embeddings( - embedding_names: List[str], + embedding_names: list[str], embeddings_raw: dict, - coords: List[int] = None, + coords: list[int] = None, obs_value_filter: str = None, column_names=dict, census_version: str = None, experiment_name: str = None, ): - """ - For a given set of Census cell coordinates (soma_joinids) + """For a given set of Census cell coordinates (soma_joinids) fetch embeddings with TileDBSoma and return the corresponding AnnData with embeddings slotted in. @@ -181,7 +173,6 @@ def build_anndata_with_embeddings( Assume that all embeddings provided are coming from the same experiment. 
""" - with cellxgene_census.open_soma(census_version=census_version) as census: print("Getting anndata with Census embeddings: ", embedding_names) @@ -253,7 +244,6 @@ def build_anndata_with_embeddings( batch_metrics = metrics_config["batch"] for tissue_node in tissues: - tissue = tissue_node["name"] query = tissue_node.get("query") or f"tissue_general == '{tissue}' and is_primary_data == True" @@ -363,13 +353,12 @@ def __init__(self, conn): if "classifier" in bio_metrics: metrics = CensusClassifierMetrics() - m1 = metrics.lr_labels(X=adata_metrics.obsm[emb], labels = adata_metrics.obs["cell_type"]) - m2 = metrics.svm_svc_labels(X=adata_metrics.obsm[emb], labels = adata_metrics.obs["cell_type"]) - m3 = metrics.random_forest_labels(X=adata_metrics.obsm[emb], labels = adata_metrics.obs["cell_type"]) + m1 = metrics.lr_labels(X=adata_metrics.obsm[emb], labels=adata_metrics.obs[bio_label]) + m2 = metrics.svm_svc_labels(X=adata_metrics.obsm[emb], labels=adata_metrics.obs[bio_label]) + m3 = metrics.random_forest_labels(X=adata_metrics.obsm[emb], labels=adata_metrics.obs[bio_label]) metric_bio_results["classifier"].append({"lr": m1, "svm": m2, "random_forest": m3}) - for batch_label, emb in itertools.product(batch_labels, embs): print("\n\nSTART", batch_label, emb) @@ -402,9 +391,9 @@ def __init__(self, conn): if "classifier" in batch_metrics: metrics = CensusClassifierMetrics() - m4 = metrics.lr_batch(X=adata_metrics.obsm[emb], batch = adata_metrics.obs[batch_label]) - m5 = metrics.random_forest_batch(X=adata_metrics.obsm[emb], batch = adata_metrics.obs[batch_label]) - m6 = metrics.svm_svc_batch(X=adata_metrics.obsm[emb], batch = adata_metrics.obs[batch_label]) + m4 = metrics.lr_batch(X=adata_metrics.obsm[emb], batch=adata_metrics.obs[batch_label]) + m5 = metrics.random_forest_batch(X=adata_metrics.obsm[emb], batch=adata_metrics.obs[batch_label]) + m6 = metrics.svm_svc_batch(X=adata_metrics.obsm[emb], batch=adata_metrics.obs[batch_label]) 
metric_batch_results["classifier"].append({"lr": m4, "random_forest": m5, "svm": m6}) if "entropy" in batch_metrics: @@ -414,8 +403,9 @@ def __init__(self, conn): e_mean = entropy.mean() metric_batch_results["entropy"].append(e_mean) - all_bio[tissue] = metric_bio_results - all_batch[tissue] = metric_batch_results + filename = f"metrics.{tissue}.pickle".replace(" ", "-").lower() - with open("metrics.pickle", "wb") as fp: - pickle.dump({"all_bio": all_bio, "all_batch": all_batch}, fp, protocol=pickle.HIGHEST_PROTOCOL) \ No newline at end of file + with open(filename, "wb") as fp: + pickle.dump( + {"bio": metric_bio_results, "batch": metric_batch_results}, fp, protocol=pickle.HIGHEST_PROTOCOL + ) diff --git a/tools/models/metrics/scib-metrics-config.yaml b/tools/models/metrics/scib-metrics-config.yaml index 94eef9085..2ba7d30eb 100644 --- a/tools/models/metrics/scib-metrics-config.yaml +++ b/tools/models/metrics/scib-metrics-config.yaml @@ -15,5 +15,4 @@ metrics: bio: ["leiden_nmi", "leiden_ari", "silhouette_label", "classifier"] batch: - ["silhouette_batch", "ilisi_knn_batch", "classifier"] - + ["silhouette_batch", "ilisi_knn_batch", "classifier", "entropy"] \ No newline at end of file From 1e059cae04d826d5adeba6cbc8e7fba91bae01d0 Mon Sep 17 00:00:00 2001 From: Emanuele Bezzi Date: Tue, 9 Jul 2024 09:52:25 -0700 Subject: [PATCH 26/31] Add ignore directives --- tools/models/metrics/ontology_mapper.py | 3 +++ tools/models/metrics/run-scib.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/tools/models/metrics/ontology_mapper.py b/tools/models/metrics/ontology_mapper.py index 532818fa5..8868e553a 100644 --- a/tools/models/metrics/ontology_mapper.py +++ b/tools/models/metrics/ontology_mapper.py @@ -1,3 +1,6 @@ +# ruff: noqa +# type: ignore + """ Provides classes to recreate cell type and tissue mappings as used in CELLxGENE Discover diff --git a/tools/models/metrics/run-scib.py b/tools/models/metrics/run-scib.py index 5627599d5..74226b31d 100644 --- 
a/tools/models/metrics/run-scib.py +++ b/tools/models/metrics/run-scib.py @@ -1,3 +1,6 @@ +# ruff: noqa +# type: ignore + import datetime import functools import itertools From 95967944de1c279252eb41cb7b8c827217ba31a5 Mon Sep 17 00:00:00 2001 From: pablo-gar Date: Tue, 9 Jul 2024 10:02:10 -0700 Subject: [PATCH 27/31] [misc] update metrics script for 2024 LTS --- tools/models/metrics/ontology_mapper.py | 14 ++++++++++---- tools/models/metrics/run-scib.py | 25 +++++++++++-------------- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/tools/models/metrics/ontology_mapper.py b/tools/models/metrics/ontology_mapper.py index 8868e553a..76a323d32 100644 --- a/tools/models/metrics/ontology_mapper.py +++ b/tools/models/metrics/ontology_mapper.py @@ -58,6 +58,9 @@ def get_high_level_terms(self, ontology_term_id: str) -> List[str]: Returns the associated high-level ontology term IDs from any other ID """ + if ontology_term_id == "unknown": + return ["unknown"] + ontology_term_id = self.reformat_ontology_term_id(ontology_term_id, to_writable=False) if ontology_term_id in self._cached_high_level_terms: @@ -117,6 +120,9 @@ def get_label_from_id(self, ontology_term_id: str): Example: "UBERON_0002048" raises ValueError because the ID is not in writable form """ + if ontology_term_id == "unknown": + return "unknown" + if ontology_term_id in self._cached_labels: return self._cached_labels[ontology_term_id] @@ -203,8 +209,8 @@ def _is_and_object(entity: owlready2.entity.ThingClass) -> bool: class CellMapper(OntologyMapper): - # From schema 3.1.0 https://github.com/chanzuckerberg/single-cell-curation/blob/main/schema/3.1.0/schema.md - CXG_CL_ONTOLOGY_URL = "https://github.com/obophenotype/cell-ontology/releases/download/v2023-07-20/cl.owl" + # From schema 5.0.0 https://github.com/chanzuckerberg/single-cell-curation/blob/main/schema/5.0.0/schema.md + CXG_CL_ONTOLOGY_URL = "https://github.com/obophenotype/cell-ontology/releases/download/v2024-01-04/cl.owl" # Only look up 
ancestors under Cell ROOT_NODE = "CL_0000000" @@ -238,8 +244,8 @@ def _get_is_a_for_cl(owl_entity): class TissueMapper(OntologyMapper): - # From schema 3.1.0 https://github.com/chanzuckerberg/single-cell-curation/blob/main/schema/3.1.0/schema.md - CXG_UBERON_ONTOLOGY_URL = "https://github.com/obophenotype/uberon/releases/download/v2023-06-28/uberon.owl" + # From schema 5.0.0 https://github.com/chanzuckerberg/single-cell-curation/blob/main/schema/5.0.0/schema.md + CXG_UBERON_ONTOLOGY_URL = "https://github.com/obophenotype/uberon/releases/download/v2024-01-18/uberon.owl" # Only look up ancestors under anatomical entity ROOT_NODE = "UBERON_0001062" diff --git a/tools/models/metrics/run-scib.py b/tools/models/metrics/run-scib.py index 74226b31d..9ee51a84b 100644 --- a/tools/models/metrics/run-scib.py +++ b/tools/models/metrics/run-scib.py @@ -203,21 +203,14 @@ def build_anndata_with_embeddings( idx = obs_indexer.get_indexer(obs_soma_joinids) ad.obsm[key] = emb[idx] except Exception: + from scipy.sparse import vstack # Assume it's a TileDBSoma URI + all_embs = [] with soma.open(val["uri"]) as E: - embedding_shape = (len(obs_soma_joinids), E.shape[1]) - embedding = np.full(embedding_shape, np.NaN, dtype=np.float32, order="C") - - obs_indexer = pd.Index(obs_soma_joinids) - for tbl in E.read(coords=(obs_soma_joinids,)).tables(): - obs_idx = obs_indexer.get_indexer(tbl.column("soma_dim_0").to_numpy()) # type: ignore[no-untyped-call] - feat_idx = tbl.column("soma_dim_1").to_numpy() - emb = tbl.column("soma_data") - - indices = obs_idx * E.shape[1] + feat_idx - np.put(embedding.reshape(-1), indices, emb) - - ad.obsm[key] = embedding + for mat in E.read(coords=(obs_soma_joinids,)).blockwise(axis=0).scipy(): + all_embs.append(mat[0]) + ad.obsm[key] = vstack(all_embs).toarray() + print("DIM:", ad.obsm[key].shape) # Embeddings with missing data contain all NaN, # so we must find the intersection of non-NaN rows in the fetched embeddings @@ -236,7 +229,7 @@ def 
build_anndata_with_embeddings( } umap_plot_labels = ["cell_subclass", "cell_class", "cell_type", "dataset_id"] - block_cell_types = ["native cell", "animal cell", "eukaryotic cell"] + block_cell_types = ["native cell", "animal cell", "eukaryotic cell", "unknown"] all_bio = {} all_batch = {} @@ -262,6 +255,10 @@ def build_anndata_with_embeddings( column_names=column_names, ) + for column in adata_metrics.obs.columns: + if adata_metrics.obs[column].dtype.name == "category": + adata_metrics.obs[column] = adata_metrics.obs[column].astype(str) + # Create batch variable adata_metrics.obs["batch"] = ( adata_metrics.obs["assay"] + adata_metrics.obs["dataset_id"] + adata_metrics.obs["suspension_type"] From ef0b96f5ee40b446502f79a3619907231ad4ce0a Mon Sep 17 00:00:00 2001 From: Emanuele Bezzi Date: Tue, 9 Jul 2024 10:15:23 -0700 Subject: [PATCH 28/31] Entropy fix --- tools/models/metrics/run-scib.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/tools/models/metrics/run-scib.py b/tools/models/metrics/run-scib.py index 74226b31d..b6a791171 100644 --- a/tools/models/metrics/run-scib.py +++ b/tools/models/metrics/run-scib.py @@ -83,15 +83,10 @@ def nearest_neighbors_hnsw(x, ef=200, M=48, n_neighbors=100): return idx, dist -def compute_entropy_per_cell(adata, obsm_key): - batch_keys = ["dataset_id", "assay", "suspension_type"] - adata.obs["batch"] = functools.reduce(lambda a, b: a + b, [adata.obs[c].astype(str) for c in batch_keys]) - +def compute_entropy_per_cell(adata, obsm_key, batch_key): indices, dist = nearest_neighbors_hnsw(adata.obsm[obsm_key], n_neighbors=200) - BATCH_KEY = "batch" - - batch_labels = np.array(list(adata.obs[BATCH_KEY])) + batch_labels = np.array(list(adata.obs[batch_key])) unique_batch_labels = np.unique(batch_labels) indices_batch = batch_labels[indices] @@ -402,7 +397,7 @@ def __init__(self, conn): if "entropy" in batch_metrics: print(datetime.datetime.now(), "Calculating entropy") - entropy = 
compute_entropy_per_cell(adata_metrics, emb) + entropy = compute_entropy_per_cell(adata_metrics, emb, batch_label) e_mean = entropy.mean() metric_batch_results["entropy"].append(e_mean) From f44637ba33567400820407f4f7b9984e52966156 Mon Sep 17 00:00:00 2001 From: Emanuele Bezzi Date: Tue, 9 Jul 2024 10:34:58 -0700 Subject: [PATCH 29/31] Lint --- tools/models/metrics/run-scib.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tools/models/metrics/run-scib.py b/tools/models/metrics/run-scib.py index 34f5c150d..be761b962 100644 --- a/tools/models/metrics/run-scib.py +++ b/tools/models/metrics/run-scib.py @@ -199,6 +199,7 @@ def build_anndata_with_embeddings( ad.obsm[key] = emb[idx] except Exception: from scipy.sparse import vstack + # Assume it's a TileDBSoma URI all_embs = [] with soma.open(val["uri"]) as E: From 141eee808faaab2d7938ef6daaf48d3ebb8c6c59 Mon Sep 17 00:00:00 2001 From: Pablo Garcia-Nieto Date: Wed, 10 Jul 2024 14:51:18 -0700 Subject: [PATCH 30/31] add plotting scripts --- tools/models/metrics/run-scib-plots.py | 176 +++++++++++++++++++++++++ 1 file changed, 176 insertions(+) create mode 100644 tools/models/metrics/run-scib-plots.py diff --git a/tools/models/metrics/run-scib-plots.py b/tools/models/metrics/run-scib-plots.py new file mode 100644 index 000000000..6db1a5cd6 --- /dev/null +++ b/tools/models/metrics/run-scib-plots.py @@ -0,0 +1,176 @@ +import sys +import yaml +import pickle +import pandas as pd +import plotnine as pt + +from typing import List, Dict + + +def main(): + + try: + file = sys.argv[1] + except IndexError: + file = "scib-metrics-config-plot.yaml" + + print (file) + + with open(file) as f: + config = yaml.safe_load(f) + + metrics = load_metrics( + metrics_file_dict=config["metrics_files"], + ) + + df_metrics, df_summary = wrangle_metrics( + metrics=metrics, + bio_metrics_to_plot=config["metrics_plot"]["bio"], + batch_metrics_to_plot=config["metrics_plot"]["batch"], + decimal_values=config["decimal_values"], + 
normalize_entropy=config["normalize_entropy"], + ) + + suf = config["output_suffix"] + + p = make_headtmap_summary(df_summary) + p.save(filename=f"{suf}metrics_0_summary.png", dpi=300) + + p = make_headtmap(df_metrics.query("metric_mode == 'Embedding space'"), "Bio-conservation") + p.save(filename=f"{suf}metrics_1_bio_emb.png", dpi=300) + + p = make_headtmap(df_metrics.query("metric_mode == 'Label classifier'"), "Bio-conservation") + p.save(filename=f"{suf}metrics_2_bio_classifier.png", dpi=300) + + p = make_headtmap(df_metrics.query("metric_mode == 'Embedding space'"), "Batch-correction", height= 5) + p.save(filename=f"{suf}metrics_3_batch_emb.png", dpi=300) + + p = make_headtmap(df_metrics.query("metric_mode == 'Label classifier'"), "Batch-correction", height= 5) + p.save(filename=f"{suf}metrics_4_batch_classifier.png", dpi=300) + + +def load_metrics(metrics_file_dict: Dict[str, str]): + + all_metrics = {"bio": {}, "batch":{}} + for tissue in metrics_file_dict: + metrics_file = metrics_file_dict[tissue] + with open(metrics_file, "rb") as op: + metrics = pickle.load(op) + + # Expand classfier metrics into individual keys + for metric_type in metrics: + classifier_types = list(metrics[metric_type]['classifier'][0].keys()) + for classifier_type in classifier_types: + metrics[metric_type][f"classifier_{classifier_type}"] = [] + for i in metrics[metric_type]['classifier']: + metrics[metric_type][f"classifier_{classifier_type}"].append( + i[classifier_type] + ) + del metrics[metric_type]['classifier'] + + all_metrics["bio"][tissue] = metrics["bio"] + all_metrics["batch"][tissue] = metrics["batch"] + + return all_metrics + + +def wrangle_metrics( + metrics: Dict, + bio_metrics_to_plot: List[str], + batch_metrics_to_plot: List[str], + decimal_values: int, + normalize_entropy: bool, +): + + # Wrangle to pandas wide format + df_batch = {} + df_bio = {} + for tissue in metrics["bio"]: + df_batch[tissue] = pd.DataFrame(metrics["batch"][tissue]) + df_batch[tissue]["tissue"] = 
tissue + df_bio[tissue] = pd.DataFrame(metrics["bio"][tissue]) + df_bio[tissue]["tissue"] = tissue + + df_batch = pd.concat(df_batch, axis=0, ignore_index=True) + df_bio = pd.concat(df_bio, axis=0, ignore_index=True) + + df_batch = df_batch.rename(columns={"batch_label":"label"}) + df_bio = df_bio.rename(columns={"bio_label":"label"}) + + if "entropy" in df_batch and normalize_entropy: + df_batch["entropy"] = df_batch["entropy"] / df_batch["entropy"].max() + + # Wrangle Batch to pandas long format + df_batch_long = df_batch.rename(columns = {i: "Value_" + i for i in batch_metrics_to_plot}) + df_batch_long = pd.wide_to_long( + df_batch_long, + stubnames="Value", + i=['embedding', 'label', 'tissue'], + j='metric', + sep="_", + suffix='.*', + ).reset_index() + df_batch_long["Metric Type"] = "Batch-correction" + + # Wrangle Bio to pandas long format + df_bio_long = df_bio.rename(columns = {i: "Value_" + i for i in bio_metrics_to_plot}) + df_bio_long = pd.wide_to_long( + df_bio_long, + stubnames="Value", + i=['embedding', 'label', 'tissue'], + j='metric', + sep="_", + suffix='.*', + ).reset_index() + df_bio_long["Metric Type"] = "Bio-conservation" + + # Join Bio and Batch metrics, rename columns for readability + df_long = pd.concat([df_batch_long, df_bio_long]) + df_long["Value"] = round(df_long["Value"], decimal_values) + df_long = df_long.rename(columns={"embedding": "Embedding"}) + + # Add column to classify "metrics mode", classifier vs embedding + df_long["metric_mode"] = "Embedding space" + df_long.loc[df_long["metric"].str.contains("classi"), "metric_mode"] = "Label classifier" + + df_summary = df_long.groupby(["Embedding", "Metric Type", "metric_mode"], as_index=False).mean("Value") + df_summary["Value"] = round(df_summary["Value"], decimal_values) + + return df_long, df_summary + +metrics_pt_theme = pt.theme( + axis_text_x=pt.element_text(rotation=30, hjust=1), + panel_grid_major=pt.element_blank(), + panel_grid_minor=pt.element_blank(), + 
panel_border=pt.element_blank(), +) + + +def make_headtmap(df, metric_type, width=9, height=4, theme=metrics_pt_theme): + return (pt.ggplot(df.query(f"`Metric Type`=='{metric_type}'")) + + pt.aes(x="Embedding", y="metric", fill = "Value") + + pt.geom_tile() + + pt.facet_grid(f"{'label'}~{'tissue'}") + + pt.geom_text(pt.aes(label='Value')) + + pt.scale_fill_gradient(low="#d9e6f2", high="#2d5986") + + pt.theme_light() + + pt.theme(figure_size=(width, height)) + + theme + ) + + +def make_headtmap_summary(df, width=9, height=2.5, theme=metrics_pt_theme): + return (pt.ggplot(df) + + pt.aes(x="Embedding", y="Metric Type", fill = "Value") + + pt.geom_tile() + + pt.facet_wrap(f"~{'metric_mode'}", scales="free") + + pt.geom_text(pt.aes(label='Value')) + + pt.theme_light() + + pt.scale_fill_gradient(low="#d9e6f2", high="#2d5986") + + pt.theme(figure_size=(width, height)) + + theme + ) + + +if __name__ == "__main__": + main() \ No newline at end of file From 9547add394d9d34a765d9f48c54e5e773ef45779 Mon Sep 17 00:00:00 2001 From: Pablo Garcia-Nieto Date: Thu, 11 Jul 2024 17:30:14 -0700 Subject: [PATCH 31/31] update metrics plotting script --- tools/models/metrics/run-scib-plots.py | 102 ++++++++++++------------- 1 file changed, 50 insertions(+), 52 deletions(-) diff --git a/tools/models/metrics/run-scib-plots.py b/tools/models/metrics/run-scib-plots.py index 6db1a5cd6..85d6f039a 100644 --- a/tools/models/metrics/run-scib-plots.py +++ b/tools/models/metrics/run-scib-plots.py @@ -8,13 +8,12 @@ def main(): - try: file = sys.argv[1] except IndexError: file = "scib-metrics-config-plot.yaml" - print (file) + print(file) with open(file) as f: config = yaml.safe_load(f) @@ -35,23 +34,22 @@ def main(): p = make_headtmap_summary(df_summary) p.save(filename=f"{suf}metrics_0_summary.png", dpi=300) - + p = make_headtmap(df_metrics.query("metric_mode == 'Embedding space'"), "Bio-conservation") p.save(filename=f"{suf}metrics_1_bio_emb.png", dpi=300) - + p = 
make_headtmap(df_metrics.query("metric_mode == 'Label classifier'"), "Bio-conservation") p.save(filename=f"{suf}metrics_2_bio_classifier.png", dpi=300) - - p = make_headtmap(df_metrics.query("metric_mode == 'Embedding space'"), "Batch-correction", height= 5) + + p = make_headtmap(df_metrics.query("metric_mode == 'Embedding space'"), "Batch-correction", height=5) p.save(filename=f"{suf}metrics_3_batch_emb.png", dpi=300) - p = make_headtmap(df_metrics.query("metric_mode == 'Label classifier'"), "Batch-correction", height= 5) + p = make_headtmap(df_metrics.query("metric_mode == 'Label classifier'"), "Batch-correction", height=5) p.save(filename=f"{suf}metrics_4_batch_classifier.png", dpi=300) def load_metrics(metrics_file_dict: Dict[str, str]): - - all_metrics = {"bio": {}, "batch":{}} + all_metrics = {"bio": {}, "batch": {}} for tissue in metrics_file_dict: metrics_file = metrics_file_dict[tissue] with open(metrics_file, "rb") as op: @@ -59,14 +57,12 @@ def load_metrics(metrics_file_dict: Dict[str, str]): # Expand classfier metrics into individual keys for metric_type in metrics: - classifier_types = list(metrics[metric_type]['classifier'][0].keys()) + classifier_types = list(metrics[metric_type]["classifier"][0].keys()) for classifier_type in classifier_types: metrics[metric_type][f"classifier_{classifier_type}"] = [] - for i in metrics[metric_type]['classifier']: - metrics[metric_type][f"classifier_{classifier_type}"].append( - i[classifier_type] - ) - del metrics[metric_type]['classifier'] + for i in metrics[metric_type]["classifier"]: + metrics[metric_type][f"classifier_{classifier_type}"].append(i[classifier_type]) + del metrics[metric_type]["classifier"] all_metrics["bio"][tissue] = metrics["bio"] all_metrics["batch"][tissue] = metrics["batch"] @@ -75,14 +71,13 @@ def load_metrics(metrics_file_dict: Dict[str, str]): def wrangle_metrics( - metrics: Dict, - bio_metrics_to_plot: List[str], + metrics: Dict, + bio_metrics_to_plot: List[str], batch_metrics_to_plot: 
List[str], decimal_values: int, normalize_entropy: bool, ): - - # Wrangle to pandas wide format + # Wrangle to pandas wide format df_batch = {} df_bio = {} for tissue in metrics["bio"]: @@ -90,37 +85,37 @@ def wrangle_metrics( df_batch[tissue]["tissue"] = tissue df_bio[tissue] = pd.DataFrame(metrics["bio"][tissue]) df_bio[tissue]["tissue"] = tissue - + df_batch = pd.concat(df_batch, axis=0, ignore_index=True) df_bio = pd.concat(df_bio, axis=0, ignore_index=True) - - df_batch = df_batch.rename(columns={"batch_label":"label"}) - df_bio = df_bio.rename(columns={"bio_label":"label"}) + + df_batch = df_batch.rename(columns={"batch_label": "label"}) + df_bio = df_bio.rename(columns={"bio_label": "label"}) if "entropy" in df_batch and normalize_entropy: df_batch["entropy"] = df_batch["entropy"] / df_batch["entropy"].max() # Wrangle Batch to pandas long format - df_batch_long = df_batch.rename(columns = {i: "Value_" + i for i in batch_metrics_to_plot}) + df_batch_long = df_batch.rename(columns={i: "Value_" + i for i in batch_metrics_to_plot}) df_batch_long = pd.wide_to_long( df_batch_long, stubnames="Value", - i=['embedding', 'label', 'tissue'], - j='metric', + i=["embedding", "label", "tissue"], + j="metric", sep="_", - suffix='.*', + suffix=".*", ).reset_index() df_batch_long["Metric Type"] = "Batch-correction" # Wrangle Bio to pandas long format - df_bio_long = df_bio.rename(columns = {i: "Value_" + i for i in bio_metrics_to_plot}) + df_bio_long = df_bio.rename(columns={i: "Value_" + i for i in bio_metrics_to_plot}) df_bio_long = pd.wide_to_long( df_bio_long, stubnames="Value", - i=['embedding', 'label', 'tissue'], - j='metric', + i=["embedding", "label", "tissue"], + j="metric", sep="_", - suffix='.*', + suffix=".*", ).reset_index() df_bio_long["Metric Type"] = "Bio-conservation" @@ -129,15 +124,16 @@ def wrangle_metrics( df_long["Value"] = round(df_long["Value"], decimal_values) df_long = df_long.rename(columns={"embedding": "Embedding"}) - # Add column to classify 
"metrics mode", classifier vs embedding + # Add column to classify "metrics mode", classifier vs embedding df_long["metric_mode"] = "Embedding space" df_long.loc[df_long["metric"].str.contains("classi"), "metric_mode"] = "Label classifier" df_summary = df_long.groupby(["Embedding", "Metric Type", "metric_mode"], as_index=False).mean("Value") df_summary["Value"] = round(df_summary["Value"], decimal_values) - + return df_long, df_summary + metrics_pt_theme = pt.theme( axis_text_x=pt.element_text(rotation=30, hjust=1), panel_grid_major=pt.element_blank(), @@ -147,30 +143,32 @@ def wrangle_metrics( def make_headtmap(df, metric_type, width=9, height=4, theme=metrics_pt_theme): - return (pt.ggplot(df.query(f"`Metric Type`=='{metric_type}'")) - + pt.aes(x="Embedding", y="metric", fill = "Value") - + pt.geom_tile() - + pt.facet_grid(f"{'label'}~{'tissue'}") - + pt.geom_text(pt.aes(label='Value')) - + pt.scale_fill_gradient(low="#d9e6f2", high="#2d5986") - + pt.theme_light() - + pt.theme(figure_size=(width, height)) - + theme + return ( + pt.ggplot(df.query(f"`Metric Type`=='{metric_type}'")) + + pt.aes(x="Embedding", y="metric", fill="Value") + + pt.geom_tile() + + pt.facet_grid(f"{'label'}~{'tissue'}") + + pt.geom_text(pt.aes(label="Value")) + + pt.scale_fill_gradient(low="#d9e6f2", high="#2d5986") + + pt.theme_light() + + pt.theme(figure_size=(width, height)) + + theme ) def make_headtmap_summary(df, width=9, height=2.5, theme=metrics_pt_theme): - return (pt.ggplot(df) - + pt.aes(x="Embedding", y="Metric Type", fill = "Value") - + pt.geom_tile() - + pt.facet_wrap(f"~{'metric_mode'}", scales="free") - + pt.geom_text(pt.aes(label='Value')) - + pt.theme_light() - + pt.scale_fill_gradient(low="#d9e6f2", high="#2d5986") - + pt.theme(figure_size=(width, height)) - + theme + return ( + pt.ggplot(df) + + pt.aes(x="Embedding", y="Metric Type", fill="Value") + + pt.geom_tile() + + pt.facet_wrap(f"~{'metric_mode'}", scales="free") + + pt.geom_text(pt.aes(label="Value")) + + 
pt.theme_light() + + pt.scale_fill_gradient(low="#d9e6f2", high="#2d5986") + + pt.theme(figure_size=(width, height)) + + theme ) if __name__ == "__main__": - main() \ No newline at end of file + main()