diff --git a/HISTORY.md b/HISTORY.md index 9d6eec0f..8b89f001 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -8,6 +8,7 @@ ### Models * Add `User.annotated_observations_count` field * Add `root_id` filter to `taxon.make_tree()` to explicitly set the root taxon instead of determining it automatically +* Fix `taxon.make_tree()` rank filtering to allow skipping any number of rank levels ### Rate limits, timeouts, and error handling * Increase default request timeout from 10 to 20 seconds diff --git a/pyinaturalist/models/taxon.py b/pyinaturalist/models/taxon.py index 6952c8a0..48ea33f4 100644 --- a/pyinaturalist/models/taxon.py +++ b/pyinaturalist/models/taxon.py @@ -438,16 +438,25 @@ def add_descendants(taxon, ancestors=None) -> Taxon: """Recursively add children and ancestors to a taxon""" taxon.children = [] taxon.ancestors = ancestors or [] - for child in taxa_by_parent.get(taxon.id, []): + for child in get_included_children(taxon): child = add_descendants(child, taxon.ancestors + [taxon]) - if include_ranks and child.rank not in include_ranks: - taxon.children.extend(child.children) - else: - taxon.children.append(child) + taxon.children.append(child) taxon.children = sorted(taxon.children, key=sort_key) return taxon + def included(taxon: Taxon) -> bool: + return not include_ranks or taxon.rank in include_ranks + + def get_included_children(taxon: Taxon) -> List[Taxon]: + """Get taxon children. 
Children with excluded ranks are replaced by their nearest included + descendants, however many rank levels that skips.""" + immediate_children = taxa_by_parent.get(taxon.id, []) + children = [c for c in immediate_children if included(c)] + for c in [c for c in immediate_children if not included(c)]: + children.extend(get_included_children(c)) + return children + return add_descendants(root) diff --git a/test/test_models.py b/test/test_models.py index eddd315b..4890b289 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -1189,11 +1189,10 @@ def test_make_tree__flattened_without_root(): def test_make_tree__flattened_filtered(): flat_list = make_tree( Taxon.from_json_list(j_life_list_2), - include_ranks=['kingdom', 'phylum', 'family', 'genus', 'subgenus'], + include_ranks=['kingdom', 'family', 'genus', 'subgenus'], ).flatten() assert [t.id for t in flat_list] == [ 1, - 47120, 47221, 52775, 538903, @@ -1202,7 +1201,7 @@ def test_make_tree__flattened_filtered(): 415027, 538902, ] - assert [t.indent_level for t in flat_list] == [0, 1, 2, 3, 4, 4, 4, 4, 4] + assert [t.indent_level for t in flat_list] == [0, 1, 2, 3, 3, 3, 3, 3] assert flat_list[0].ancestors == [] assert [t.id for t in flat_list[1].ancestors] == [1]