From ddb0f857864e217bf83802a7a5f0ec74758b54d3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Edyta=20Fr=C4=85szczak?=
Date: Thu, 4 Jul 2024 12:17:04 +0200
Subject: [PATCH] code improvements

---
 .github/CODE_OF_CONDUCT.md                    |  2 +-
 .github/CONTRIBUTING.md                       | 14 +--
 .github/ISSUE_TEMPLATE/bug_report.md          |  2 +-
 .github/ISSUE_TEMPLATE/feature_request.md     |  2 +-
 .idea/misc.xml                                |  4 +-
 CHANGELOG.md                                  | 14 +--
 README.md                                     | 54 ++++++-----
 src/aa.py                                     | 15 +--
 src/nsdlib/algorithms/__init__.py             | 91 +++++++++++++++++++
 src/nsdlib/algorithms/algorithms_utils.py     | 56 +++++++-----
 .../algorithms/node_evaluation/__init__.py    |  1 +
 .../algorithms/node_evaluation/net_sleuth.py  |  8 +-
 .../algorithms/reconstruction/__init__.py     |  1 -
 src/nsdlib/algorithms/reconstruction/sbrp.py  | 21 +++--
 src/nsdlib/algorithms/reconstruction/utils.py |  3 +-
 src/nsdlib/setup.py                           |  2 +-
 src/nsdlib/source_detection.py                | 70 ++++++------
 17 files changed, 229 insertions(+), 131 deletions(-)

diff --git a/.github/CODE_OF_CONDUCT.md b/.github/CODE_OF_CONDUCT.md
index d2f3186..55b979c 100644
--- a/.github/CODE_OF_CONDUCT.md
+++ b/.github/CODE_OF_CONDUCT.md
@@ -1,4 +1,4 @@
-# Code of Conduct for the Network Centrality Library Project
+# Code of Conduct for the Network Source Detection Library Project
 
 ## Our Pledge
 
diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md
index 391bf8c..6992825 100644
--- a/.github/CONTRIBUTING.md
+++ b/.github/CONTRIBUTING.md
@@ -22,19 +22,19 @@ We warmly welcome contributions to NSDLib! This document provides guidelines for
 
 ### Implementation Requirements
 
-- **Centrality Measures Implementation**:
-  - Each centrality measure must be implemented in a separate file within the `nsdlib/algorithms` directory.
-  - The file name should match the centrality measure's name.
-  - Each file must contain a single function, named after the centrality measure, that calculates this measure. This function should accept a NetworkX graph as input and return a dictionary mapping nodes to their centrality values.
-  - Each centrality measure function must be exposed in the `nsdlib/algorithms` package to be accessible for external use.
-  - Add an entry for the new centrality measure in the `Centrality` enum to ensure it's recognized and accessible through a standardized interface.
+- **Source Detection Method Implementation**:
+  - Each new method must be implemented in a separate file within the `nsdlib/algorithms` directory, in the appropriate package for its intended purpose, e.g. a reconstruction algorithm should be placed in the `reconstruction` package.
+  - The file name should match the method's name.
+  - Each file must contain a single function, named after the method, that implements it (see the sketch below).
+  - Each algorithm function must be exposed in the `nsdlib/algorithms` package to be accessible for external use.
+  - Add an entry for the new algorithm in the appropriate taxonomy class, e.g. for a reconstruction algorithm the new entry should be added to the `PropagationReconstructionAlgorithm` enum, to ensure it's recognized and accessible through a standardized interface.
 
 - **Testing**:
   - Contributions must include tests covering the new functionality. We require at least 80% test coverage for changes.
   - Use the `pytest` framework for writing tests.
 
 - **Documentation**:
-  - Update the project documentation to reflect the addition of new centrality measures or any other significant changes.
+  - Update the project documentation to reflect the addition of new methods or any other significant changes.
   - Ensure that examples, usage guides, and API documentation are clear and updated.
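To make these requirements concrete, here is a minimal sketch of what a contributed file could look like, using a hypothetical node evaluation method (`degree_sum` is an invented name; the string-valued enum pattern is an assumption based on how `identify_outbreaks` resolves functions via `outbreaks_alg.value.lower()`):

```python
# Hypothetical file: src/nsdlib/algorithms/node_evaluation/degree_sum.py
from typing import Dict

from networkx import Graph


def degree_sum(network: Graph) -> Dict[int, float]:
    """Score each node by the summed degree of its neighbours."""
    return {
        node: float(sum(network.degree(nb) for nb in network.neighbors(node)))
        for node in network.nodes
    }
```

The function would then be re-exported from `nsdlib/algorithms/node_evaluation/__init__.py` and registered as a new member of the `NodeEvaluationAlgorithm` enum (e.g. `DEGREE_SUM = "DEGREE_SUM"`, assuming the naming convention above) so the generic `SourceDetector` pipeline can resolve it by name.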
 
 ### Making Changes
 
diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md
index 0bd6daa..7a813e3 100644
--- a/.github/ISSUE_TEMPLATE/bug_report.md
+++ b/.github/ISSUE_TEMPLATE/bug_report.md
@@ -1,6 +1,6 @@
 ---
 name: Bug Report
-about: Create a report to help us improve the network centrality library
+about: Create a report to help us improve NSDLib
 title: "[BUG]"
 labels: bug
 assignees: ''
 
diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md
index 7b8d4d4..33de5a6 100644
--- a/.github/ISSUE_TEMPLATE/feature_request.md
+++ b/.github/ISSUE_TEMPLATE/feature_request.md
@@ -1,6 +1,6 @@
 ---
 name: Feature Request
-about: Suggest an idea for the network centrality library
+about: Suggest an idea for NSDLib
 title: "[FEATURE]"
 labels: enhancement
 assignees: ''
 
diff --git a/.idea/misc.xml b/.idea/misc.xml
index 22ef361..4f19ba1 100644
--- a/.idea/misc.xml
+++ b/.idea/misc.xml
@@ -3,5 +3,5 @@
-
-
\ No newline at end of file
+
+
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 09da0e7..e50f90b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -3,16 +3,6 @@
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
-## [0.2.1] - 2024-02-23
+## [0.1.0] - 2024-07-08
 ### Added
-- All common modules exported in `__init__.py` file
-
-## [0.2.0] - 2024-02-22
-### Added
-- new centrality measure - hubbell centrality has been added
-- updated maintenance related files
-- extended documentation
-
-## [0.1.1] - 2024-01-09
-### Added
-- nsdlib version 0.1.1 release
+- NSDlib version 0.1.0 release
diff --git a/README.md b/README.md
index 0f6e46d..cd2bd53 100644
--- a/README.md
+++ b/README.md
@@ -1,22 +1,21 @@
 # NSDlib
 
-NSDlib (Network source detection library) is a tool to compute a wide range of centrality measures for a given network. The
-library is designed to work with Python Networkx library.
+NSDlib (Network source detection library) is a comprehensive library designed for detecting sources of propagation in networks. This library offers a variety of algorithms that help researchers and developers analyze and identify the origins of information spread (epidemics, etc.) within networks.
 
 ## Overview
 
-The goal of NSDlib is to offer a comprehensive repository for implementing a broad spectrum of centrality measures. Each
-year, new measures are introduced through scientific papers, often with only pseudo-code descriptions, making it
-difficult for researchers to evaluate and compare them with existing methods. While implementations of well-known
-centrality measures exist, recent innovations are frequently absent. NSDlib strives to bridge this gap. It references the
-renowned CentiServer portal for well-known centrality measures and their originating papers, aiming to encompass all
-these measures in the future.
+NSDLib is a library designed for easy integration into existing projects. It aims to be a comprehensive repository
+of source detection methods, outbreak detection techniques, and propagation graph reconstruction tools. Researchers worldwide are encouraged to contribute to and utilize this library,
+facilitating the development of new techniques to combat misinformation and improve propagation analysis.
+Each year, new techniques are introduced through scientific papers, often with only pseudo-code descriptions, making it
+difficult for researchers to evaluate and compare them with existing methods. NSDlib tries to bridge this gap and encourages researchers to contribute their implementations here.
 
 ## Code structure
 
-All custom implementations are provided under `nsdlib/algorithms` package. Each centrality measure is implemented in a separate file, named after the measure itself. Correspondingly, each file contains a function, named identically to the file, which calculates the centrality measure. This function accepts a NetworkX graph as input (and other params if applicable) and returns a dictionary, mapping nodes to their centrality values. Ultimately, every custom implementation is made available through the `nsdlib/algorithms` package.
-## Implemented centrality measures:
+All custom implementations are provided under the `nsdlib/algorithms` package. Each method is implemented in a separate file, named after the method itself, in the appropriate package for its intended purpose, e.g. a reconstruction algorithm should be placed in the `reconstruction` package. Correspondingly, each file contains a function, named identically to the file, which implements the method's logic. Ultimately, every custom implementation is made available through the `nsdlib/algorithms` package.
+## Implemented features:
+### Node evaluation algorithms
 - [Algebraic](https://www.centiserver.org/centrality/Algebraic_Centrality/)
 - [Average Distance](https://www.centiserver.org/centrality/Average_Distance/)
 - [Barycenter](https://www.centiserver.org/centrality/Barycenter_Centrality/)
@@ -58,6 +57,12 @@ All custom implementations are provided under `nsdlib/algorithms` package. Each
 - [Topological](https://www.centiserver.org/centrality/Topological_Coefficient/)
 - [Trophic Levels](https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.centrality.trophic_levels.html)
 
+### Outbreak detection algorithms
+- test
+
+### Graph reconstruction algorithms
+- SbRP
+
 ## How to use
 
 Library can be installed using pip:
 
 ```
 pip install nsdlib
 ```
 
 Provided algorithms can be executed in the following ways:
 
-- by invoking a specific function from `nsdlib.algorithms` package, which computes a given centrality measure for a
-  given graph.
+- by utilizing the `SourceDetector` class and configuring it with a `SourceDetectionConfig` object. This approach allows for seamless source detection and result evaluation.
 
 ```python
 import networkx as nx
-import nsdlib as ncl
+from nsdlib.common.models import SourceDetectionConfig
+from nsdlib.source_detection import SourceDetector
+from nsdlib.taxonomies import NodeEvaluationAlgorithm
+
 
-# Create a graph
 G = nx.karate_club_graph()
 
-# Compute degree centrality
-degree_centrality = ncl.degree_centrality(G)
+config = SourceDetectionConfig(
+    node_evaluation_algorithm=NodeEvaluationAlgorithm.NETSLEUTH,
+)
+
+source_detector = SourceDetector(config)
 
-# Compute betweenness centrality
-betweenness_centrality = ncl.betweenness_centrality(G)
+result, evaluation = source_detector.detect_sources_and_evaluate(G=G,
+    IG=G, real_sources=[0,33])
+print(evaluation)
 
-# Compute closeness centrality
-closeness_centrality = ncl.closeness_centrality(G)
 
-# Compute eigenvector centrality
-eigenvector_centrality = ncl.eigenvector_centrality(G)
 ```
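The example above uses the default configuration, which selects the single best-scored node as the source. The same pipeline can also split the infected graph into outbreaks first and accept every node whose normalized score clears a threshold. A sketch of such a configuration; the `LOUVAIN` member name is an assumption suggested by the `louvain` import alias added to `nsdlib/algorithms/__init__.py` later in this patch:

```python
import networkx as nx

from nsdlib.common.models import SourceDetectionConfig
from nsdlib.source_detection import SourceDetector
from nsdlib.taxonomies import (
    NodeEvaluationAlgorithm,
    OutbreaksDetectionAlgorithm,
)

G = nx.karate_club_graph()

config = SourceDetectionConfig(
    node_evaluation_algorithm=NodeEvaluationAlgorithm.NETSLEUTH,
    # Partition the infected graph into outbreaks before scoring nodes.
    outbreaks_detection_algorithm=OutbreaksDetectionAlgorithm.LOUVAIN,  # assumed member name
    # Keep every node whose normalized score reaches the threshold,
    # instead of only the single top-scored node per outbreak.
    selection_threshold=0.9,
)

# Attribute names mirror the constructor kwargs shown in source_detection.py.
result = SourceDetector(config).detect_sources(G=G, IG=G)
print(result.detected_sources)
```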
-- invoking `compute_centrality` method of `CentralityService` class, which allows to compute centrality for a given
-  centrality measure.
+- by importing and using a specific method:
 
 ```python
 from typing import Any
diff --git a/src/aa.py b/src/aa.py
index c6d8f0c..b860a01 100644
--- a/src/aa.py
+++ b/src/aa.py
@@ -1,23 +1,18 @@
 import networkx as nx
-import netcenlib as ncl
+
 from nsdlib.common.models import SourceDetectionConfig
 from nsdlib.source_detection import SourceDetector
-from nsdlib.taxonomies import OutbreaksDetectionAlgorithm, \
-    NodeEvaluationAlgorithm
+from nsdlib.taxonomies import NodeEvaluationAlgorithm
 
-# Create a graph
 G = nx.karate_club_graph()
 
 config = SourceDetectionConfig(
-    selection_threshold=None,
     node_evaluation_algorithm=NodeEvaluationAlgorithm.NETSLEUTH,
 )
 
 source_detector = SourceDetector(config)
 
-result, evaluation = source_detector.detect_sources_and_evaluate(G=G,
-                                                                 IG=G, real_sources=[0,33])
+result, evaluation = source_detector.detect_sources_and_evaluate(
+    G=G, IG=G, real_sources=[0, 33]
+)
 print(result.global_scores)
-print(ncl.degree_centrality(G))
-
-print(evaluation)
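The direct-usage example in the README hunk above is cut off by the diff context. For that second route (importing a specific method), a minimal sketch using the `net_sleuth` function that this patch re-exports from `nsdlib.algorithms.node_evaluation`:

```python
import networkx as nx

from nsdlib.algorithms.node_evaluation import net_sleuth

G = nx.karate_club_graph()

# net_sleuth returns a node -> score mapping; higher means more likely source.
scores = net_sleuth(G)

top_node = max(scores, key=scores.get)
print(top_node, scores[top_node])
```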
diff --git a/src/nsdlib/algorithms/__init__.py b/src/nsdlib/algorithms/__init__.py
index e69de29..30d4115 100644
--- a/src/nsdlib/algorithms/__init__.py
+++ b/src/nsdlib/algorithms/__init__.py
@@ -0,0 +1,91 @@
+# flake8: noqa
+
+from nsdlib.algorithms.outbreaks_detection import (
+    CPM_Bipartite as outbreaks_detection_CPM_Bipartite,
+    agdl as outbreaks_detection_agdl,
+    angel as outbreaks_detection_angel,
+    aslpaw as outbreaks_detection_aslpaw,
+    async_fluid as outbreaks_detection_async_fluid,
+    attribute_clustering as outbreaks_detection_attribute_clustering,
+    bayan as outbreaks_detection_bayan,
+    belief as outbreaks_detection_belief,
+    bimlpa as outbreaks_detection_bimlpa,
+    bipartite_clustering as outbreaks_detection_bipartite_clustering,
+    coach as outbreaks_detection_coach,
+    condor as outbreaks_detection_condor,
+    conga as outbreaks_detection_conga,
+    congo as outbreaks_detection_congo,
+    core_expansion as outbreaks_detection_core_expansion,
+    cpm as outbreaks_detection_cpm,
+    crisp_partition as outbreaks_detection_crisp_partition,
+    dcs as outbreaks_detection_dcs,
+    demon as outbreaks_detection_demon,
+    der as outbreaks_detection_der,
+    dpclus as outbreaks_detection_dpclus,
+    ebgc as outbreaks_detection_ebgc,
+    edge_clustering as outbreaks_detection_edge_clustering,
+    ego_networks as outbreaks_detection_ego_networks,
+    eigenvector as outbreaks_detection_eigenvector,
+    em as outbreaks_detection_em,
+    endntm as outbreaks_detection_endntm,
+    eva as outbreaks_detection_eva,
+    frc_fgsn as outbreaks_detection_frc_fgsn,
+    ga as outbreaks_detection_ga,
+    gdmp2 as outbreaks_detection_gdmp2,
+    girvan_newman as outbreaks_detection_girvan_newman,
+    graph_entropy as outbreaks_detection_graph_entropy,
+    greedy_modularity as outbreaks_detection_greedy_modularity,
+    head_tail as outbreaks_detection_head_tail,
+    hierarchical_link_community as outbreaks_detection_hierarchical_link_community,
+    ilouvain as outbreaks_detection_ilouvain,
+    infomap as outbreaks_detection_infomap,
+    infomap_bipartite as outbreaks_detection_infomap_bipartite,
+    internal as outbreaks_detection_internal,
+    internal_dcd as outbreaks_detection_internal_dcd,
+    ipca as outbreaks_detection_ipca,
+    kclique as outbreaks_detection_kclique,
+    kcut as outbreaks_detection_kcut,
+    label_propagation as outbreaks_detection_label_propagation,
+    lais2 as outbreaks_detection_lais2,
+    leiden as outbreaks_detection_leiden,
+    lemon as outbreaks_detection_lemon,
+    lfm as outbreaks_detection_lfm,
+    louvain as outbreaks_detection_louvain,
+    lpam as outbreaks_detection_lpam,
+    lpanni as outbreaks_detection_lpanni,
+    lswl as outbreaks_detection_lswl,
+    lswl_plus as outbreaks_detection_lswl_plus,
+    markov_clustering as outbreaks_detection_markov_clustering,
+    mcode as outbreaks_detection_mcode,
+    mod_m as outbreaks_detection_mod_m,
+    mod_r as outbreaks_detection_mod_r,
+    multicom as outbreaks_detection_multicom,
+    node_perception as outbreaks_detection_node_perception,
+    overlapping_partition as outbreaks_detection_overlapping_partition,
+    overlapping_seed_set_expansion as outbreaks_detection_overlapping_seed_set_expansion,
+    paris as outbreaks_detection_paris,
+    percomvc as outbreaks_detection_percomvc,
+    principled_clustering as outbreaks_detection_principled_clustering,
+    pycombo as outbreaks_detection_pycombo,
+    r_spectral_clustering as outbreaks_detection_r_spectral_clustering,
+    rb_pots as outbreaks_detection_rb_pots,
+    rber_pots as outbreaks_detection_rber_pots,
+    ricci_community as outbreaks_detection_ricci_community,
+    sbm_dl as outbreaks_detection_sbm_dl,
+    sbm_dl_nested as outbreaks_detection_sbm_dl_nested,
+    scan as outbreaks_detection_scan,
+    siblinarity_antichain as outbreaks_detection_siblinarity_antichain,
+    significance_communities as outbreaks_detection_significance_communities,
+    slpa as outbreaks_detection_slpa,
+    spectral as outbreaks_detection_spectral,
+    spinglass as outbreaks_detection_spinglass,
+    surprise_communities as outbreaks_detection_surprise_communities,
+    temporal_partition as outbreaks_detection_temporal_partition,
+    threshold_clustering as outbreaks_detection_threshold_clustering,
+    tiles as outbreaks_detection_tiles,
+    umstmo as outbreaks_detection_umstmo,
+    walkscan as outbreaks_detection_walkscan,
+    walktrap as outbreaks_detection_walktrap,
+    wCommunity as outbreaks_detection_wCommunity,
+)
+from nsdlib.algorithms.reconstruction import sbrp as reconstruction_sbrp
diff --git a/src/nsdlib/algorithms/algorithms_utils.py b/src/nsdlib/algorithms/algorithms_utils.py
index d956c90..3c20aca 100644
--- a/src/nsdlib/algorithms/algorithms_utils.py
+++ b/src/nsdlib/algorithms/algorithms_utils.py
@@ -1,15 +1,20 @@
 from functools import lru_cache
+from typing import Dict, List, Set, Union
+
 from netcenlib.common import nx_cached
 from netcenlib.common.nx_cached import MAX_SIZE
 from networkx import Graph
-from typing import Dict, Set, List, Union
 
-from nsdlib.algorithms import node_evaluation, outbreaks_detection, \
-    reconstruction
-from nsdlib.common.models import SourceDetectionEvaluation, NODE_TYPE
+from nsdlib.algorithms import (
+    node_evaluation,
+    outbreaks_detection,
+    reconstruction,
+)
+from nsdlib.common.models import NODE_TYPE, SourceDetectionEvaluation
 from nsdlib.taxonomies import (
     NodeEvaluationAlgorithm,
     OutbreaksDetectionAlgorithm,
+    PropagationReconstructionAlgorithm,
 )
 
@@ -18,10 +23,8 @@ def identify_outbreaks(
 ) -> Dict[int, list]:
     """Identify outbreaks in a given network."""
     function_name = f"{outbreaks_alg.value.lower()}"
-    result = getattr(outbreaks_detection, function_name)(network, *args,
-                                                         **kwargs)
-    return {index: community for index, community in
-            enumerate(result.communities)}
+    result = getattr(outbreaks_detection, function_name)(network, *args, **kwargs)
+    return {index: community for index, community in enumerate(result.communities)}
 
 def evaluate_nodes(
@@ -33,8 +36,11 @@ def evaluate_nodes(
 
 def reconstruct_propagation(
-    G: Graph, IG: Graph,
-    reconstruction_alg: PropagationReconstructionAlgorithm, *args, **kwargs
+    G: Graph,
+    IG: Graph,
+    reconstruction_alg: PropagationReconstructionAlgorithm,
+    *args,
+    **kwargs,
 ):
     """Reconstruct the propagation of a given network."""
     function_name = f"{reconstruction_alg.value.lower()}"
@@ -59,24 +65,27 @@ def evaluate_nodes_cached(
 @lru_cache(maxsize=MAX_SIZE)
 def reconstruct_propagation_cached(
-    G: Graph, IG: Graph,
-    reconstruction_alg: PropagationReconstructionAlgorithm, *args, **kwargs
+    G: Graph,
+    IG: Graph,
+    reconstruction_alg: PropagationReconstructionAlgorithm,
+    *args,
+    **kwargs,
 ):
     """Reconstruct the propagation of a given network."""
     return reconstruct_propagation(G, IG, reconstruction_alg, *args, **kwargs)
 
 def compute_error_distances(
-    G: Graph, not_detected_sources: Set[int],
-    invalid_detected_sources: Set[int]
+    G: Graph, not_detected_sources: Set[int], invalid_detected_sources: Set[int]
 ) -> Dict[NODE_TYPE, float]:
     """Compute the error distances for the source detection evaluation."""
     if not_detected_sources and invalid_detected_sources:
-        return {source:
-                min(
+        return {
+            source: min(
                 [
-                    nx_cached.shortest_path_length(G, source=source,
-                                                   target=invalid_source)
+                    nx_cached.shortest_path_length(
+                        G, source=source, target=invalid_source
+                    )
                     for invalid_source in invalid_detected_sources
                 ]
             )
@@ -90,21 +99,18 @@ def compute_error_distances(
 def compute_source_detection_evaluation(
     G: Graph,
     real_sources: List[NODE_TYPE],
-    detected_sources: Union[NODE_TYPE, List[NODE_TYPE]]
+    detected_sources: Union[NODE_TYPE, List[NODE_TYPE]],
 ) -> SourceDetectionEvaluation:
     """Compute the evaluation of the source detection."""
     detected_sources = (
-        detected_sources if isinstance(detected_sources, list) else [
-            detected_sources]
+        detected_sources if isinstance(detected_sources, list) else [detected_sources]
     )
-    correctly_detected_sources = set(real_sources).intersection(
-        detected_sources)
+    correctly_detected_sources = set(real_sources).intersection(detected_sources)
     invalid_detected_sources = set(detected_sources).difference(
         correctly_detected_sources
     )
-    not_detected_sources = set(real_sources).difference(
-        correctly_detected_sources)
+    not_detected_sources = set(real_sources).difference(correctly_detected_sources)
 
     P = len(real_sources)
     N = len(G.nodes) - P
diff --git a/src/nsdlib/algorithms/node_evaluation/__init__.py b/src/nsdlib/algorithms/node_evaluation/__init__.py
index da32907..caa935b 100644
--- a/src/nsdlib/algorithms/node_evaluation/__init__.py
+++ b/src/nsdlib/algorithms/node_evaluation/__init__.py
@@ -1,6 +1,7 @@
 # flake8: noqa
 from netcenlib.algorithms import *
+
 from nsdlib.algorithms.node_evaluation.dynamic_age import dynamic_age
 from nsdlib.algorithms.node_evaluation.jordan_center import jordan_center
 from nsdlib.algorithms.node_evaluation.net_sleuth import net_sleuth
diff --git a/src/nsdlib/algorithms/node_evaluation/net_sleuth.py b/src/nsdlib/algorithms/node_evaluation/net_sleuth.py
index 177d22f..ab10d67 100644
--- a/src/nsdlib/algorithms/node_evaluation/net_sleuth.py
+++ b/src/nsdlib/algorithms/node_evaluation/net_sleuth.py
@@ -23,11 +23,7 @@ def net_sleuth(network: Graph) -> Dict[int, float]:
     L = nx.laplacian_matrix(network).toarray()
     eigenvalues, eigenvectors = np.linalg.eig(L)
     largest_eigenvalue = max(eigenvalues)
-    largest_eigenvector = eigenvectors[:,
-                          list(eigenvalues).index(largest_eigenvalue)]
+    largest_eigenvector = eigenvectors[:, list(eigenvalues).index(largest_eigenvalue)]
 
-    scores = {
-        v: largest_eigenvector[v]
-        for v in network.nodes
-    }
+    scores = {v: largest_eigenvector[v] for v in network.nodes}
     return scores
diff --git a/src/nsdlib/algorithms/reconstruction/__init__.py b/src/nsdlib/algorithms/reconstruction/__init__.py
index 7456be5..43b68d7 100644
--- a/src/nsdlib/algorithms/reconstruction/__init__.py
+++ b/src/nsdlib/algorithms/reconstruction/__init__.py
@@ -1,2 +1 @@
-
 from nsdlib.algorithms.reconstruction.sbrp import sbrp
diff --git a/src/nsdlib/algorithms/reconstruction/sbrp.py b/src/nsdlib/algorithms/reconstruction/sbrp.py
index ec2c6f0..5303c97 100644
--- a/src/nsdlib/algorithms/reconstruction/sbrp.py
+++ b/src/nsdlib/algorithms/reconstruction/sbrp.py
@@ -1,13 +1,16 @@
 from networkx import Graph
 
-from nsdlib.algorithms.reconstruction.utils import init_extended_network, \
-    compute_neighbors_probability, NODE_INFECTION_PROBABILITY_ATTR, \
-    remove_invalid_nodes
+from nsdlib.algorithms.reconstruction.utils import (
+    NODE_INFECTION_PROBABILITY_ATTR,
+    compute_neighbors_probability,
+    init_extended_network,
+    remove_invalid_nodes,
+)
 
-def sbrp(G: Graph, IG: Graph,
-         reconstruction_threshold=0.5,
-         max_iterations: int = 1) -> Graph:
+def sbrp(
+    G: Graph, IG: Graph, reconstruction_threshold=0.5, max_iterations: int = 1
+) -> Graph:
     """SbRP graph reconstruction algorithm.
 
     @param G: Network
@@ -29,9 +32,9 @@ def sbrp(G: Graph, IG: Graph,
         for neighbour in G.neighbors(node):
             if neighbour in IG:
                 continue
-            EG.nodes[neighbour][
-                NODE_INFECTION_PROBABILITY_ATTR] = compute_neighbors_probability(
-                G=EG, node=neighbour)
+            EG.nodes[neighbour][NODE_INFECTION_PROBABILITY_ATTR] = (
+                compute_neighbors_probability(G=EG, node=neighbour)
+            )
 
     remove_invalid_nodes(EG, reconstruction_threshold)
     return EG
diff --git a/src/nsdlib/algorithms/reconstruction/utils.py b/src/nsdlib/algorithms/reconstruction/utils.py
index c5e76fb..b02b823 100644
--- a/src/nsdlib/algorithms/reconstruction/utils.py
+++ b/src/nsdlib/algorithms/reconstruction/utils.py
@@ -39,8 +39,7 @@ def compute_neighbors_probability(node: NODE_TYPE, G: Graph) -> float:
     @return: Probability of infection for a given node
     """
     neighbors_probability = [
-        G.nodes[node][NODE_INFECTION_PROBABILITY_ATTR] for node in
-        nx.neighbors(G, node)
+        G.nodes[node][NODE_INFECTION_PROBABILITY_ATTR] for node in nx.neighbors(G, node)
     ]
     return reduce(
         operator.mul,
diff --git a/src/nsdlib/setup.py b/src/nsdlib/setup.py
index 1126c2f..dd0b677 100644
--- a/src/nsdlib/setup.py
+++ b/src/nsdlib/setup.py
@@ -52,7 +52,7 @@ def find_version(*path_parts):
         "Programming Language :: Python",
         "Programming Language :: Python :: 3",
     ],
-    keywords="node_importance centrality_measures centrality complex-networks",
+    keywords="source_detection propagation_outbreaks node_importance complex-networks",
     install_requires=requirements,
     long_description=long_description,
     long_description_content_type="text/markdown",
diff --git a/src/nsdlib/source_detection.py b/src/nsdlib/source_detection.py
index 0144a8a..2ed641b 100644
--- a/src/nsdlib/source_detection.py
+++ b/src/nsdlib/source_detection.py
@@ -1,19 +1,27 @@
-"""Centrality measures for networks."""
-from typing import Dict, Union, List, Tuple
+"""Source detection algorithm."""
+
+from typing import Dict, List, Tuple
 
 from networkx import Graph
 
-from nsdlib.algorithms.algorithms_utils import identify_outbreaks_cached, \
-    evaluate_nodes_cached, compute_source_detection_evaluation
-from nsdlib.common.models import SourceDetectionConfig, NODE_TYPE, \
-    SourceDetectionResult, \
-    SourceDetectionEvaluation
+from nsdlib.algorithms.algorithms_utils import (
+    compute_source_detection_evaluation,
+    evaluate_nodes_cached,
+    identify_outbreaks_cached,
+)
+from nsdlib.common.models import (
+    NODE_TYPE,
+    SourceDetectionConfig,
+    SourceDetectionEvaluation,
+    SourceDetectionResult,
+)
 from nsdlib.common.nx_utils import create_subgraphs_based_on_outbreaks
 from nsdlib.commons import normalize_dict_values
 
 class SourceDetector:
     """Source detection generic algorithm."""
+
     def __init__(self, config: SourceDetectionConfig):
         self.config = config
 
@@ -30,46 +38,50 @@ def _detect_outbreaks(self, IG):
             network=IG,
             outbreaks_alg=self.config.outbreaks_detection_algorithm,
         )
-        outbreaks = [subgraph
-                     for subgraph in create_subgraphs_based_on_outbreaks(
+        outbreaks = [
+            subgraph
+            for subgraph in create_subgraphs_based_on_outbreaks(
                 G=IG, outbreaks=outbreaks
             )
-        ]
+        ]
         return outbreaks
 
-    def _get_global_scores(self,
-                           outbreaks_evaluation: List[Dict[NODE_TYPE, float]]):
+    def _get_global_scores(self, outbreaks_evaluation: List[Dict[NODE_TYPE, float]]):
         global_scores = {}
         for outbreak_evaluation in outbreaks_evaluation:
             for node, evaluation in outbreak_evaluation.items():
                 global_scores[node] = evaluation
         return global_scores
 
-    def _evaluate_outbreaks(self, outbreaks: List[Graph]) -> List[
-        Dict[NODE_TYPE, float]]:
+    def _evaluate_outbreaks(
+        self, outbreaks: List[Graph]
+    ) -> List[Dict[NODE_TYPE, float]]:
         scores = []
         for outbreak in outbreaks:
-            scores.append(evaluate_nodes_cached(
-                network=outbreak,
-                evaluation_alg=self.config.node_evaluation_algorithm,
-            ))
+            scores.append(
+                evaluate_nodes_cached(
+                    network=outbreak,
+                    evaluation_alg=self.config.node_evaluation_algorithm,
+                )
+            )
         return scores
 
-    def _select_sources(self,
-                        outbreaks_evaluation: List[Dict[NODE_TYPE, float]]):
+    def _select_sources(self, outbreaks_evaluation: List[Dict[NODE_TYPE, float]]):
         sources = []
         for outbreak_evaluation in outbreaks_evaluation:
             if self.config.selection_threshold is None:
-                sources.append(
-                    max(outbreak_evaluation, key=outbreak_evaluation.get))
+                sources.append(max(outbreak_evaluation, key=outbreak_evaluation.get))
             else:
                 outbreaks_evaluation_normalized = normalize_dict_values(
-                    outbreak_evaluation)
+                    outbreak_evaluation
+                )
                 sources.extend(
-                    [node for node, evaluation in
-                     outbreaks_evaluation_normalized.items()
-                     if evaluation >= self.config.selection_threshold]
+                    [
+                        node
+                        for node, evaluation in outbreaks_evaluation_normalized.items()
+                        if evaluation >= self.config.selection_threshold
+                    ]
                 )
         return sources
 
@@ -88,9 +100,9 @@ def detect_sources(self, IG: Graph, G: Graph) -> SourceDetectionResult:
             detected_sources=detected_sources,
         )
 
-    def detect_sources_and_evaluate(self, IG: Graph, G: Graph,
-                                    real_sources: List[NODE_TYPE]) -> Tuple[
-        SourceDetectionResult, SourceDetectionEvaluation]:
+    def detect_sources_and_evaluate(
+        self, IG: Graph, G: Graph, real_sources: List[NODE_TYPE]
+    ) -> Tuple[SourceDetectionResult, SourceDetectionEvaluation]:
         sd_result = self.detect_sources(IG, G)
         evaluation = compute_source_detection_evaluation(