Skip to content

Commit

Permalink
Cleanup.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Nov 22, 2024
1 parent 46b5d23 commit 6e9d788
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 15 deletions.
6 changes: 1 addition & 5 deletions python-package/xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3201,11 +3201,7 @@ def trees_to_dataframe(self, fmap: PathLike = "") -> DataFrame:
}
)

if callable(getattr(df, "sort_values", None)):
# pylint: disable=no-member
return df.sort_values(["Tree", "Node"]).reset_index(drop=True)
# pylint: disable=no-member
return df.sort(["Tree", "Node"]).reset_index(drop=True)
return df.sort_values(["Tree", "Node"]).reset_index(drop=True)

def _assign_dmatrix_features(self, data: DMatrix) -> None:
if data.num_row() == 0:
Expand Down
19 changes: 9 additions & 10 deletions python-package/xgboost/dask/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,26 +192,25 @@ def sort_data_by_qid(**kwargs: List[Any]) -> Dict[str, List[Any]]:
else:
from pandas import DataFrame

def get_dict(i: int) -> dict:
def _get(attr: Optional[List[Any]]) -> Optional[Any]:
def get_dict(i: int) -> Dict[str, list]:
"""Return a dictionary containing all the meta info and all partitions."""

def _get(attr: Optional[List[Any]]) -> Optional[list]:
if attr is not None:
return attr[i]
return None

data = {k: _get(kwargs.get(k, None)) for k in meta}
data = {k: v for k, v in data.items() if v is not None}
data_opt = {name: _get(kwargs.get(name, None)) for name in meta}
# Filter out None values.
data = {k: v for k, v in data_opt.items() if v is not None}
return data

# This function was created for the `dd.from_map` constructor for sorting with a
# Dask DF. We did not proceed with that route but kept some of the utilities. It
# might be necessary to try again in the future since concatenating and sorting is
# extremely expensive in terms of memory usage.
def map_fn(i: int) -> pd.DataFrame:
data = get_dict(i)
return DataFrame(data)

qid_parts = [map_fn(i) for i in range(n_parts)]
dfq = concat(qid_parts)
meta_parts = [map_fn(i) for i in range(n_parts)]
dfq = concat(meta_parts)
if dfq.qid.is_monotonic_increasing:
return kwargs

Expand Down

0 comments on commit 6e9d788

Please sign in to comment.