Skip to content

Commit

Permalink
[dask] use more specific method names on _DaskLGBMModel (#4004)
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb authored Feb 20, 2021
1 parent 7f91dc6 commit 646267d
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions python-package/lightgbm/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ def client_(self) -> Client:

return _get_dask_client(client=self.client)

def _lgb_getstate(self) -> Dict[Any, Any]:
def _lgb_dask_getstate(self) -> Dict[Any, Any]:
"""Remove un-picklable attributes before serialization."""
client = self.__dict__.pop("client", None)
self._other_params.pop("client", None)
Expand All @@ -474,7 +474,7 @@ def _lgb_getstate(self) -> Dict[Any, Any]:
self.client = client
return out

def _fit(
def _lgb_dask_fit(
self,
model_factory: Type[LGBMModel],
X: _DaskMatrixLike,
Expand All @@ -501,20 +501,20 @@ def _fit(
)

self.set_params(**model.get_params())
self._copy_extra_params(model, self)
self._lgb_dask_copy_extra_params(model, self)

return self

def _to_local(self, model_factory: Type[LGBMModel]) -> LGBMModel:
def _lgb_dask_to_local(self, model_factory: Type[LGBMModel]) -> LGBMModel:
params = self.get_params()
params.pop("client", None)
model = model_factory(**params)
self._copy_extra_params(self, model)
self._lgb_dask_copy_extra_params(self, model)
model._other_params.pop("client", None)
return model

@staticmethod
def _copy_extra_params(source: Union["_DaskLGBMModel", LGBMModel], dest: Union["_DaskLGBMModel", LGBMModel]) -> None:
def _lgb_dask_copy_extra_params(source: Union["_DaskLGBMModel", LGBMModel], dest: Union["_DaskLGBMModel", LGBMModel]) -> None:
params = source.get_params()
attributes = source.__dict__
extra_param_names = set(attributes.keys()).difference(params.keys())
Expand Down Expand Up @@ -590,7 +590,7 @@ def __init__(
__init__.__doc__ = _base_doc[:_base_doc.find('Note\n')]

def __getstate__(self) -> Dict[Any, Any]:
return self._lgb_getstate()
return self._lgb_dask_getstate()

def fit(
self,
Expand All @@ -600,7 +600,7 @@ def fit(
**kwargs: Any
) -> "DaskLGBMClassifier":
"""Docstring is inherited from the lightgbm.LGBMClassifier.fit."""
return self._fit(
return self._lgb_dask_fit(
model_factory=LGBMClassifier,
X=X,
y=y,
Expand Down Expand Up @@ -670,7 +670,7 @@ def to_local(self) -> LGBMClassifier:
model : lightgbm.LGBMClassifier
Local underlying model.
"""
return self._to_local(LGBMClassifier)
return self._lgb_dask_to_local(LGBMClassifier)


class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
Expand Down Expand Up @@ -741,7 +741,7 @@ def __init__(
__init__.__doc__ = _base_doc[:_base_doc.find('Note\n')]

def __getstate__(self) -> Dict[Any, Any]:
return self._lgb_getstate()
return self._lgb_dask_getstate()

def fit(
self,
Expand All @@ -751,7 +751,7 @@ def fit(
**kwargs: Any
) -> "DaskLGBMRegressor":
"""Docstring is inherited from the lightgbm.LGBMRegressor.fit."""
return self._fit(
return self._lgb_dask_fit(
model_factory=LGBMRegressor,
X=X,
y=y,
Expand Down Expand Up @@ -802,7 +802,7 @@ def to_local(self) -> LGBMRegressor:
model : lightgbm.LGBMRegressor
Local underlying model.
"""
return self._to_local(LGBMRegressor)
return self._lgb_dask_to_local(LGBMRegressor)


class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
Expand Down Expand Up @@ -873,7 +873,7 @@ def __init__(
__init__.__doc__ = _base_doc[:_base_doc.find('Note\n')]

def __getstate__(self) -> Dict[Any, Any]:
return self._lgb_getstate()
return self._lgb_dask_getstate()

def fit(
self,
Expand All @@ -888,7 +888,7 @@ def fit(
if init_score is not None:
raise RuntimeError('init_score is not currently supported in lightgbm.dask')

return self._fit(
return self._lgb_dask_fit(
model_factory=LGBMRanker,
X=X,
y=y,
Expand Down Expand Up @@ -939,4 +939,4 @@ def to_local(self) -> LGBMRanker:
model : lightgbm.LGBMRanker
Local underlying model.
"""
return self._to_local(LGBMRanker)
return self._lgb_dask_to_local(LGBMRanker)

0 comments on commit 646267d

Please sign in to comment.