From 7356527cb36b14854f254179d1c406abcd448b50 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20MZ?= Date: Tue, 27 Apr 2021 22:09:50 -0500 Subject: [PATCH 01/17] include distributed tests --- .github/workflows/distributed.yml | 36 +++++++++ tests/distributed/predict.conf | 4 + tests/distributed/test_distributed.py | 108 ++++++++++++++++++++++++++ 3 files changed, 148 insertions(+) create mode 100644 .github/workflows/distributed.yml create mode 100644 tests/distributed/predict.conf create mode 100644 tests/distributed/test_distributed.py diff --git a/.github/workflows/distributed.yml b/.github/workflows/distributed.yml new file mode 100644 index 000000000000..5f92db0c6856 --- /dev/null +++ b/.github/workflows/distributed.yml @@ -0,0 +1,36 @@ +name: Distributed test + +on: + push: + branches: [tests/distributed] + +defaults: + run: + shell: bash -l {0} + +jobs: + job: + runs-on: ubuntu-latest + steps: + - name: clone repo + uses: actions/checkout@v2 + with: + submodules: 'recursive' + + - name: set up environment + uses: conda-incubator/setup-miniconda@v2 + with: + auto-activate-base: true + activate-environment: "" + + - name: Compile binary + run: | + conda install -c conda-forge cxx-compiler cmake make pytest numpy scikit-learn + mkdir build && cd build + cmake .. && make -j2 + + - name: run tests + run: | + cp lightgbm tests/distributed/ + cd python-package/ && python setup.py install --precompile + cd ../tests/distributed && pytest test_distributed.py diff --git a/tests/distributed/predict.conf b/tests/distributed/predict.conf new file mode 100644 index 000000000000..4e6a0a80d710 --- /dev/null +++ b/tests/distributed/predict.conf @@ -0,0 +1,4 @@ +task = predict +data = train.txt +input_model = model.txt +output_result = predictions.txt diff --git a/tests/distributed/test_distributed.py b/tests/distributed/test_distributed.py new file mode 100644 index 000000000000..b862644f477f --- /dev/null +++ b/tests/distributed/test_distributed.py @@ -0,0 +1,108 @@ +import copy +import subprocess +import sys +from concurrent.futures import ThreadPoolExecutor + +import lightgbm as lgb +import numpy as np +from sklearn.datasets import make_blobs, make_regression +from sklearn.metrics import accuracy_score + + +def create_data(task, n_samples=1_000): + if task == 'binary-classification': + centers = [[-4, -4], [4, 4]] + X, y = make_blobs(n_samples, centers=centers, random_state=42) + elif task == 'regression': + X, y = make_regression(n_samples, n_features=4, n_informative=2, random_state=42) + dataset = np.hstack((y[:, None], X)) + return dataset + + +def run_and_log(cmd): + process = subprocess.Popen(cmd, stdout=subprocess.PIPE) + for c in iter(lambda: process.stdout.read(1), b''): + sys.stdout.buffer.write(c) + + +class DistributedMockup: + default_config = { + 'output_model': 'model.txt', + 'machine_list_file': 'mlist.txt', + 'tree_learner': 'data', + 'force_row_wise': True, + 'verbose': 0, + 'num_boost_round': 20, + 'num_leaves': 15, + 'num_threads': 2, + } + def __init__(self, config={}, n_workers=2): + self.config = copy.deepcopy(self.default_config) + self.config.update(config) + self.config['num_machines'] = n_workers + self.n_workers = n_workers + + def worker_train(self, i): + cmd = f'./lightgbm config=train{i}.conf'.split() + if i == 0: + return run_and_log(cmd) + subprocess.run(cmd) + + def _set_ports(self): + self.listen_ports = [lgb.dask._find_random_open_port() for _ in range(self.n_workers)] + with open('mlist.txt', 'wt') as f: + for port in self.listen_ports: + f.write(f'127.0.0.1 
{port}\n') + + def _write_data(self, data): + np.savetxt('train.txt', data, delimiter=',') + for i, partition in enumerate(np.array_split(data, self.n_workers)): + np.savetxt(f'train{i}.txt', partition, delimiter=',') + + def fit(self, data): + self._write_data(data) + self.label_ = data[:, 0] + self._set_ports() + futures = [] + with ThreadPoolExecutor(max_workers=self.n_workers) as executor: + for i in range(self.n_workers): + self.write_train_config(i) + futures.append(executor.submit(self.worker_train, i)) + results = [f.result() for f in futures] + + def predict(self): + cmd = './lightgbm config=predict.conf'.split() + run_and_log(cmd) + y_pred = np.loadtxt('predictions.txt') + return y_pred + + def write_train_config(self, i): + with open(f'train{i}.conf', 'wt') as f: + f.write('task = train\n') + f.write(f'local_listen_port = {self.listen_ports[i]}\n') + f.write(f'data = train{i}.txt\n') + for param, value in self.config.items(): + f.write(f'{param} = {value}\n') + + +def test_classifier(): + data = create_data(task='binary-classification') + params = { + 'objective': 'binary', + } + clf = DistributedMockup(params) + clf.fit(data) + y_probas = clf.predict() + y_pred = y_probas > 0.5 + assert accuracy_score(clf.label_, y_pred) == 1. + + +def test_regressor(): + data = create_data(task='regression') + params = { + 'objective': 'regression', + } + reg = DistributedMockup(params) + reg.fit(data) + y_pred = reg.predict() + np.testing.assert_allclose(y_pred, reg.label_, rtol=0.5, atol=50.) \ No newline at end of file From 711009c01b9de6115ad795b35a5c8f5bfed232ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20MZ?= Date: Tue, 4 May 2021 18:29:36 -0500 Subject: [PATCH 02/17] remove github action file --- .github/workflows/distributed.yml | 36 --------------------------- tests/distributed/test_distributed.py | 15 +++++------ 2 files changed, 8 insertions(+), 43 deletions(-) delete mode 100644 .github/workflows/distributed.yml diff --git a/.github/workflows/distributed.yml b/.github/workflows/distributed.yml deleted file mode 100644 index 5f92db0c6856..000000000000 --- a/.github/workflows/distributed.yml +++ /dev/null @@ -1,36 +0,0 @@ -name: Distributed test - -on: - push: - branches: [tests/distributed] - -defaults: - run: - shell: bash -l {0} - -jobs: - job: - runs-on: ubuntu-latest - steps: - - name: clone repo - uses: actions/checkout@v2 - with: - submodules: 'recursive' - - - name: set up environment - uses: conda-incubator/setup-miniconda@v2 - with: - auto-activate-base: true - activate-environment: "" - - - name: Compile binary - run: | - conda install -c conda-forge cxx-compiler cmake make pytest numpy scikit-learn - mkdir build && cd build - cmake .. 
&& make -j2 - - - name: run tests - run: | - cp lightgbm tests/distributed/ - cd python-package/ && python setup.py install --precompile - cd ../tests/distributed && pytest test_distributed.py diff --git a/tests/distributed/test_distributed.py b/tests/distributed/test_distributed.py index b862644f477f..c523c8ce5026 100644 --- a/tests/distributed/test_distributed.py +++ b/tests/distributed/test_distributed.py @@ -16,17 +16,18 @@ def create_data(task, n_samples=1_000): elif task == 'regression': X, y = make_regression(n_samples, n_features=4, n_informative=2, random_state=42) dataset = np.hstack((y[:, None], X)) - return dataset + return dataset def run_and_log(cmd): process = subprocess.Popen(cmd, stdout=subprocess.PIPE) - for c in iter(lambda: process.stdout.read(1), b''): + for c in iter(lambda: process.stdout.read(1), b''): sys.stdout.buffer.write(c) class DistributedMockup: default_config = { + 'task': 'train', 'output_model': 'model.txt', 'machine_list_file': 'mlist.txt', 'tree_learner': 'data', @@ -36,6 +37,7 @@ class DistributedMockup: 'num_leaves': 15, 'num_threads': 2, } + def __init__(self, config={}, n_workers=2): self.config = copy.deepcopy(self.default_config) self.config.update(config) @@ -47,7 +49,7 @@ def worker_train(self, i): if i == 0: return run_and_log(cmd) subprocess.run(cmd) - + def _set_ports(self): self.listen_ports = [lgb.dask._find_random_open_port() for _ in range(self.n_workers)] with open('mlist.txt', 'wt') as f: @@ -68,8 +70,8 @@ def fit(self, data): for i in range(self.n_workers): self.write_train_config(i) futures.append(executor.submit(self.worker_train, i)) - results = [f.result() for f in futures] - + _ = [f.result() for f in futures] + def predict(self): cmd = './lightgbm config=predict.conf'.split() run_and_log(cmd) @@ -78,7 +80,6 @@ def predict(self): def write_train_config(self, i): with open(f'train{i}.conf', 'wt') as f: - f.write('task = train\n') f.write(f'local_listen_port = {self.listen_ports[i]}\n') f.write(f'data = train{i}.txt\n') for param, value in self.config.items(): @@ -105,4 +106,4 @@ def test_regressor(): reg = DistributedMockup(params) reg.fit(data) y_pred = reg.predict() - np.testing.assert_allclose(y_pred, reg.label_, rtol=0.5, atol=50.) \ No newline at end of file + np.testing.assert_allclose(y_pred, reg.label_, rtol=0.5, atol=50.) From e3c96a8f25c568e37c527ca7e82c8895fffc2466 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20MZ?= Date: Wed, 5 May 2021 19:08:49 -0500 Subject: [PATCH 03/17] try CI --- .ci/test.sh | 8 ++++++++ .vsts-ci.yml | 2 ++ .../{test_distributed.py => _test_distributed.py} | 0 3 files changed, 10 insertions(+) rename tests/distributed/{test_distributed.py => _test_distributed.py} (100%) diff --git a/.ci/test.sh b/.ci/test.sh index 0ee4695f39c2..583c0990f6f1 100755 --- a/.ci/test.sh +++ b/.ci/test.sh @@ -107,6 +107,14 @@ fi conda install -q -y -n $CONDA_ENV cloudpickle dask distributed joblib matplotlib numpy pandas psutil pytest scikit-learn scipy pip install graphviz # python-graphviz from Anaconda is not allowed to be installed with Python 3.9 +if [[ $TASK == "cli-distributed" ]]; then + mkdir $BUILD_DIRECTORY/build && cd $BUILD_DIRECTORY/build && cmake .. 
&& make lightgbm -j4 || exit -1 + cp $BUILD_DIRECTORY/lightgbm $BUILD_DIRECTORY/tests/distributed/ || exit -1 + cd $BUILD_DIRECTORY/python-package/ && python setup.py install --precompile || exit -1 + cd $BUILD_DIRECTORY/tests/distributed && pytest _test_distributed.py || exit -1 + exit 0 +fi + if [[ $OS_NAME == "macos" ]] && [[ $COMPILER == "clang" ]]; then # fix "OMP: Error #15: Initializing libiomp5.dylib, but found libomp.dylib already initialized." (OpenMP library conflict due to conda's MKL) for LIBOMP_ALIAS in libgomp.dylib libiomp5.dylib libomp.dylib; do sudo ln -sf "$(brew --cellar libomp)"/*/lib/libomp.dylib $CONDA_PREFIX/lib/$LIBOMP_ALIAS || exit -1; done diff --git a/.vsts-ci.yml b/.vsts-ci.yml index 30702b31467b..ee93cae99726 100644 --- a/.vsts-ci.yml +++ b/.vsts-ci.yml @@ -119,6 +119,8 @@ jobs: TASK: gpu METHOD: wheel PYTHON_VERSION: 3.7 + distributed: + TASK: cli-distributed steps: - script: | echo "##vso[task.setvariable variable=BUILD_DIRECTORY]$BUILD_SOURCESDIRECTORY" diff --git a/tests/distributed/test_distributed.py b/tests/distributed/_test_distributed.py similarity index 100% rename from tests/distributed/test_distributed.py rename to tests/distributed/_test_distributed.py From eabb3b7db67b4c8fd230800c64386cdc1f64d254 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20MZ?= Date: Wed, 5 May 2021 19:18:27 -0500 Subject: [PATCH 04/17] build shared library and fix linting error --- .ci/test.sh | 2 +- tests/distributed/_test_distributed.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.ci/test.sh b/.ci/test.sh index 583c0990f6f1..1d3ae45f9d6f 100755 --- a/.ci/test.sh +++ b/.ci/test.sh @@ -108,7 +108,7 @@ conda install -q -y -n $CONDA_ENV cloudpickle dask distributed joblib matplotlib pip install graphviz # python-graphviz from Anaconda is not allowed to be installed with Python 3.9 if [[ $TASK == "cli-distributed" ]]; then - mkdir $BUILD_DIRECTORY/build && cd $BUILD_DIRECTORY/build && cmake .. && make lightgbm -j4 || exit -1 + mkdir $BUILD_DIRECTORY/build && cd $BUILD_DIRECTORY/build && cmake .. && make -j4 || exit -1 cp $BUILD_DIRECTORY/lightgbm $BUILD_DIRECTORY/tests/distributed/ || exit -1 cd $BUILD_DIRECTORY/python-package/ && python setup.py install --precompile || exit -1 cd $BUILD_DIRECTORY/tests/distributed && pytest _test_distributed.py || exit -1 diff --git a/tests/distributed/_test_distributed.py b/tests/distributed/_test_distributed.py index c523c8ce5026..96a271c2002c 100644 --- a/tests/distributed/_test_distributed.py +++ b/tests/distributed/_test_distributed.py @@ -3,11 +3,12 @@ import sys from concurrent.futures import ThreadPoolExecutor -import lightgbm as lgb import numpy as np from sklearn.datasets import make_blobs, make_regression from sklearn.metrics import accuracy_score +import lightgbm as lgb + def create_data(task, n_samples=1_000): if task == 'binary-classification': From 9a4ad3bce3f2c0ac3c088a7f804fc3217dc9da6c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20MZ?= Date: Mon, 10 May 2021 19:38:19 -0500 Subject: [PATCH 05/17] ignore files created for testing. add type hints and check with mypy. 
include docstrings --- .ci/test.sh | 2 + .gitignore | 4 ++ tests/distributed/_test_distributed.py | 64 +++++++++++++++++++------- 3 files changed, 53 insertions(+), 17 deletions(-) diff --git a/.ci/test.sh b/.ci/test.sh index 1d3ae45f9d6f..1bbed0e29c21 100755 --- a/.ci/test.sh +++ b/.ci/test.sh @@ -108,6 +108,8 @@ conda install -q -y -n $CONDA_ENV cloudpickle dask distributed joblib matplotlib pip install graphviz # python-graphviz from Anaconda is not allowed to be installed with Python 3.9 if [[ $TASK == "cli-distributed" ]]; then + pip install --user mypy + mypy --ignore-missing-imports tests/distributed/ mkdir $BUILD_DIRECTORY/build && cd $BUILD_DIRECTORY/build && cmake .. && make -j4 || exit -1 cp $BUILD_DIRECTORY/lightgbm $BUILD_DIRECTORY/tests/distributed/ || exit -1 cd $BUILD_DIRECTORY/python-package/ && python setup.py install --precompile || exit -1 diff --git a/.gitignore b/.gitignore index 66b8a9b4acff..2ee306e155a7 100644 --- a/.gitignore +++ b/.gitignore @@ -430,6 +430,10 @@ miktex*.zip **/lgb.Dataset.data **/model.txt **/lgb-model.txt +tests/distributed/mlist.txt +tests/distributed/train* +tests/distributed/predictions*.txt + # Files from interactive R sessions .Rproj.user diff --git a/tests/distributed/_test_distributed.py b/tests/distributed/_test_distributed.py index 96a271c2002c..6284001f51e4 100644 --- a/tests/distributed/_test_distributed.py +++ b/tests/distributed/_test_distributed.py @@ -2,6 +2,7 @@ import subprocess import sys from concurrent.futures import ThreadPoolExecutor +from typing import Dict, List import numpy as np from sklearn.datasets import make_blobs, make_regression @@ -10,19 +11,24 @@ import lightgbm as lgb -def create_data(task, n_samples=1_000): +def create_data(task: str, n_samples: int = 1_000) -> np.ndarray: + """Creates the appropiate data for the task. + The data is returned as a numpy array with the label as the first column.""" if task == 'binary-classification': centers = [[-4, -4], [4, 4]] X, y = make_blobs(n_samples, centers=centers, random_state=42) elif task == 'regression': X, y = make_regression(n_samples, n_features=4, n_informative=2, random_state=42) - dataset = np.hstack((y[:, None], X)) + dataset = np.hstack([y.reshape(-1, 1), X]) return dataset -def run_and_log(cmd): +def run_and_log(cmd: List[str]) -> None: + """Run `cmd` in another process and pipe its logs to this process' stdout.""" process = subprocess.Popen(cmd, stdout=subprocess.PIPE) - for c in iter(lambda: process.stdout.read(1), b''): + assert process.stdout is not None + stdout_stream = lambda: process.stdout.read(1) + for c in iter(stdout_stream, b''): sys.stdout.buffer.write(c) @@ -39,32 +45,45 @@ class DistributedMockup: 'num_threads': 2, } - def __init__(self, config={}, n_workers=2): + def __init__(self, config: Dict = {}, n_workers: int = 2): self.config = copy.deepcopy(self.default_config) self.config.update(config) self.config['num_machines'] = n_workers self.n_workers = n_workers - def worker_train(self, i): + def worker_train(self, i: int) -> None: + """Start the training process on the `i`-th worker. 
+ If this is the first worker, its logs are piped to stdout.""" cmd = f'./lightgbm config=train{i}.conf'.split() if i == 0: return run_and_log(cmd) subprocess.run(cmd) - def _set_ports(self): + def _set_ports(self) -> None: + """Randomly assign a port for training to each worker and save all ports to mlist.txt.""" self.listen_ports = [lgb.dask._find_random_open_port() for _ in range(self.n_workers)] with open('mlist.txt', 'wt') as f: for port in self.listen_ports: f.write(f'127.0.0.1 {port}\n') - def _write_data(self, data): - np.savetxt('train.txt', data, delimiter=',') - for i, partition in enumerate(np.array_split(data, self.n_workers)): + def _write_data(self, partitions: List[np.ndarray]) -> None: + """Write all training data as train.txt and each training partition as train{i}.txt.""" + all_data = np.vstack(partitions) + np.savetxt('train.txt', all_data, delimiter=',') + for i, partition in enumerate(partitions): np.savetxt(f'train{i}.txt', partition, delimiter=',') - def fit(self, data): - self._write_data(data) - self.label_ = data[:, 0] + def fit(self, partitions: List[np.ndarray]) -> None: + """Run the distributed training process on a single machine. + For each worker i: + 1. The i-th partition is saved as train{i}.txt + 2. A random port is assigned for training. + 3. A configuration file train{i}.conf is created. + 4. The lightgbm binary is called with config=train{i}.conf in another thread. + The whole training set is saved as train.txt and the logs from the first worker are piped to stdout. + """ + self._write_data(partitions) + self.label_ = np.hstack([partition[:, 0] for partition in partitions]) self._set_ports() futures = [] with ThreadPoolExecutor(max_workers=self.n_workers) as executor: @@ -73,13 +92,20 @@ def fit(self, data): futures.append(executor.submit(self.worker_train, i)) _ = [f.result() for f in futures] - def predict(self): + def predict(self) -> np.ndarray: + """Compute the predictions using the model created in the fit step. + model.txt is used to predict the training set train.txt using predict.conf. + The predictions are saved as predictions.txt and are then loaded to return them as a numpy array. + The logs are piped to stdout.""" cmd = './lightgbm config=predict.conf'.split() run_and_log(cmd) y_pred = np.loadtxt('predictions.txt') return y_pred - def write_train_config(self, i): + def write_train_config(self, i: int) -> None: + """Creates a file train{i}.txt with the required configuration to train. + Each worker gets a different port and piece of the data, the rest are the + model parameters contained in `self.config`.""" with open(f'train{i}.conf', 'wt') as f: f.write(f'local_listen_port = {self.listen_ports[i]}\n') f.write(f'data = train{i}.txt\n') @@ -88,23 +114,27 @@ def write_train_config(self, i): def test_classifier(): + num_machines = 2 data = create_data(task='binary-classification') + partitions = np.array_split(data, num_machines) params = { 'objective': 'binary', } clf = DistributedMockup(params) - clf.fit(data) + clf.fit(partitions) y_probas = clf.predict() y_pred = y_probas > 0.5 assert accuracy_score(clf.label_, y_pred) == 1. def test_regressor(): + num_machines = 2 data = create_data(task='regression') + partitions = np.array_split(data, num_machines) params = { 'objective': 'regression', } reg = DistributedMockup(params) - reg.fit(data) + reg.fit(partitions) y_pred = reg.predict() np.testing.assert_allclose(y_pred, reg.label_, rtol=0.5, atol=50.) 
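For orientation between revisions: the mockup above drives the lightgbm binary entirely through files it generates itself. mlist.txt maps each worker to a port (one "127.0.0.1 <port>" line per worker), and every worker reads its own train{i}.conf. With the defaults in DistributedMockup and two workers, worker 0's train0.conf would come out roughly as the sketch below; the listen port is drawn at random at run time, so the value shown is purely illustrative:

    local_listen_port = 12402
    data = train0.txt
    task = train
    output_model = model.txt
    machine_list_file = mlist.txt
    tree_learner = data
    force_row_wise = True
    verbose = 0
    num_boost_round = 20
    num_leaves = 15
    num_threads = 2
    objective = binary
    num_machines = 2
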
From 1d108e18b2974770f09a5c9b66ef3ce48b6a99fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20MZ?= Date: Mon, 10 May 2021 19:50:36 -0500 Subject: [PATCH 06/17] lint --- tests/distributed/_test_distributed.py | 30 ++++++++++++++++++++------ 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/tests/distributed/_test_distributed.py b/tests/distributed/_test_distributed.py index 6284001f51e4..0ea2b6c3e01a 100644 --- a/tests/distributed/_test_distributed.py +++ b/tests/distributed/_test_distributed.py @@ -12,8 +12,10 @@ def create_data(task: str, n_samples: int = 1_000) -> np.ndarray: - """Creates the appropiate data for the task. - The data is returned as a numpy array with the label as the first column.""" + """Create the appropiate data for the task. + + The data is returned as a numpy array with the label as the first column. + """ if task == 'binary-classification': centers = [[-4, -4], [4, 4]] X, y = make_blobs(n_samples, centers=centers, random_state=42) @@ -27,12 +29,17 @@ def run_and_log(cmd: List[str]) -> None: """Run `cmd` in another process and pipe its logs to this process' stdout.""" process = subprocess.Popen(cmd, stdout=subprocess.PIPE) assert process.stdout is not None - stdout_stream = lambda: process.stdout.read(1) + + def stdout_stream(): + return process.stdout.read(1) + for c in iter(stdout_stream, b''): sys.stdout.buffer.write(c) class DistributedMockup: + """Simulate distributed training.""" + default_config = { 'task': 'train', 'output_model': 'model.txt', @@ -53,7 +60,9 @@ def __init__(self, config: Dict = {}, n_workers: int = 2): def worker_train(self, i: int) -> None: """Start the training process on the `i`-th worker. - If this is the first worker, its logs are piped to stdout.""" + + If this is the first worker, its logs are piped to stdout. + """ cmd = f'./lightgbm config=train{i}.conf'.split() if i == 0: return run_and_log(cmd) @@ -75,6 +84,7 @@ def _write_data(self, partitions: List[np.ndarray]) -> None: def fit(self, partitions: List[np.ndarray]) -> None: """Run the distributed training process on a single machine. + For each worker i: 1. The i-th partition is saved as train{i}.txt 2. A random port is assigned for training. @@ -94,18 +104,22 @@ def fit(self, partitions: List[np.ndarray]) -> None: def predict(self) -> np.ndarray: """Compute the predictions using the model created in the fit step. + model.txt is used to predict the training set train.txt using predict.conf. The predictions are saved as predictions.txt and are then loaded to return them as a numpy array. - The logs are piped to stdout.""" + The logs are piped to stdout. + """ cmd = './lightgbm config=predict.conf'.split() run_and_log(cmd) y_pred = np.loadtxt('predictions.txt') return y_pred def write_train_config(self, i: int) -> None: - """Creates a file train{i}.txt with the required configuration to train. + """Create a file train{i}.txt with the required configuration to train. + Each worker gets a different port and piece of the data, the rest are the - model parameters contained in `self.config`.""" + model parameters contained in `self.config`. 
+ """ with open(f'train{i}.conf', 'wt') as f: f.write(f'local_listen_port = {self.listen_ports[i]}\n') f.write(f'data = train{i}.txt\n') @@ -114,6 +128,7 @@ def write_train_config(self, i: int) -> None: def test_classifier(): + """Test the classification task.""" num_machines = 2 data = create_data(task='binary-classification') partitions = np.array_split(data, num_machines) @@ -128,6 +143,7 @@ def test_classifier(): def test_regressor(): + """Test the regression task.""" num_machines = 2 data = create_data(task='regression') partitions = np.array_split(data, num_machines) From 9f51ad8fb5d19c02313558d4ef5574d0d0c32310 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20MZ?= Date: Tue, 11 May 2021 22:14:01 -0500 Subject: [PATCH 07/17] use pre_partition and write separate model files. remove mypy --- .ci/test.sh | 19 +++++++++---------- .gitignore | 1 + tests/distributed/_test_distributed.py | 3 ++- tests/distributed/predict.conf | 2 +- 4 files changed, 13 insertions(+), 12 deletions(-) diff --git a/.ci/test.sh b/.ci/test.sh index 1bbed0e29c21..191f70b2d811 100755 --- a/.ci/test.sh +++ b/.ci/test.sh @@ -85,6 +85,15 @@ if [[ $TASK == "if-else" ]]; then exit 0 fi +if [[ $TASK == "cli-distributed" ]]; then + conda install -q -y -n $CONDA_ENV numpy pytest scikit-learn + mkdir $BUILD_DIRECTORY/build && cd $BUILD_DIRECTORY/build && cmake .. && make -j4 || exit -1 + cp $BUILD_DIRECTORY/lightgbm $BUILD_DIRECTORY/tests/distributed/ || exit -1 + cd $BUILD_DIRECTORY/python-package/ && python setup.py install --precompile || exit -1 + cd $BUILD_DIRECTORY/tests/distributed && pytest _test_distributed.py || exit -1 + exit 0 +fi + if [[ $TASK == "swig" ]]; then mkdir $BUILD_DIRECTORY/build && cd $BUILD_DIRECTORY/build if [[ $OS_NAME == "macos" ]]; then @@ -107,16 +116,6 @@ fi conda install -q -y -n $CONDA_ENV cloudpickle dask distributed joblib matplotlib numpy pandas psutil pytest scikit-learn scipy pip install graphviz # python-graphviz from Anaconda is not allowed to be installed with Python 3.9 -if [[ $TASK == "cli-distributed" ]]; then - pip install --user mypy - mypy --ignore-missing-imports tests/distributed/ - mkdir $BUILD_DIRECTORY/build && cd $BUILD_DIRECTORY/build && cmake .. && make -j4 || exit -1 - cp $BUILD_DIRECTORY/lightgbm $BUILD_DIRECTORY/tests/distributed/ || exit -1 - cd $BUILD_DIRECTORY/python-package/ && python setup.py install --precompile || exit -1 - cd $BUILD_DIRECTORY/tests/distributed && pytest _test_distributed.py || exit -1 - exit 0 -fi - if [[ $OS_NAME == "macos" ]] && [[ $COMPILER == "clang" ]]; then # fix "OMP: Error #15: Initializing libiomp5.dylib, but found libomp.dylib already initialized." 
(OpenMP library conflict due to conda's MKL) for LIBOMP_ALIAS in libgomp.dylib libiomp5.dylib libomp.dylib; do sudo ln -sf "$(brew --cellar libomp)"/*/lib/libomp.dylib $CONDA_PREFIX/lib/$LIBOMP_ALIAS || exit -1; done diff --git a/.gitignore b/.gitignore index 2ee306e155a7..f990289b38d6 100644 --- a/.gitignore +++ b/.gitignore @@ -432,6 +432,7 @@ miktex*.zip **/lgb-model.txt tests/distributed/mlist.txt tests/distributed/train* +tests/distributed/model* tests/distributed/predictions*.txt diff --git a/tests/distributed/_test_distributed.py b/tests/distributed/_test_distributed.py index 0ea2b6c3e01a..31cad667958a 100644 --- a/tests/distributed/_test_distributed.py +++ b/tests/distributed/_test_distributed.py @@ -42,7 +42,7 @@ class DistributedMockup: default_config = { 'task': 'train', - 'output_model': 'model.txt', + 'pre_partition': True, 'machine_list_file': 'mlist.txt', 'tree_learner': 'data', 'force_row_wise': True, @@ -121,6 +121,7 @@ def write_train_config(self, i: int) -> None: model parameters contained in `self.config`. """ with open(f'train{i}.conf', 'wt') as f: + f.write(f'output_model = model{i}.txt\n') f.write(f'local_listen_port = {self.listen_ports[i]}\n') f.write(f'data = train{i}.txt\n') for param, value in self.config.items(): diff --git a/tests/distributed/predict.conf b/tests/distributed/predict.conf index 4e6a0a80d710..f926b2f81edd 100644 --- a/tests/distributed/predict.conf +++ b/tests/distributed/predict.conf @@ -1,4 +1,4 @@ task = predict data = train.txt -input_model = model.txt +input_model = model0.txt output_result = predictions.txt From 17ee6ab64d64cea6db928b6244a7d8a8a4065176 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20MZ?= Date: Tue, 11 May 2021 22:51:22 -0500 Subject: [PATCH 08/17] update docs --- tests/distributed/_test_distributed.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/distributed/_test_distributed.py b/tests/distributed/_test_distributed.py index 31cad667958a..4ab53e252376 100644 --- a/tests/distributed/_test_distributed.py +++ b/tests/distributed/_test_distributed.py @@ -90,6 +90,7 @@ def fit(self, partitions: List[np.ndarray]) -> None: 2. A random port is assigned for training. 3. A configuration file train{i}.conf is created. 4. The lightgbm binary is called with config=train{i}.conf in another thread. + 5. The trained model is saved as model{i}.txt. Each model file only differs in data and local_listen_port. The whole training set is saved as train.txt and the logs from the first worker are piped to stdout. """ self._write_data(partitions) @@ -105,7 +106,7 @@ def fit(self, partitions: List[np.ndarray]) -> None: def predict(self) -> np.ndarray: """Compute the predictions using the model created in the fit step. - model.txt is used to predict the training set train.txt using predict.conf. + model0.txt is used to predict the training set train.txt using predict.conf. The predictions are saved as predictions.txt and are then loaded to return them as a numpy array. The logs are piped to stdout. """ From a424374bcf2ef09c33d69804defba0487ebe300e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20MZ?= Date: Thu, 20 May 2021 19:25:03 -0500 Subject: [PATCH 09/17] remove ci. lower rtol. 
pass num_machines in config --- .ci/test.sh | 9 --------- .vsts-ci.yml | 2 -- tests/distributed/_test_distributed.py | 11 ++++++----- 3 files changed, 6 insertions(+), 16 deletions(-) diff --git a/.ci/test.sh b/.ci/test.sh index 191f70b2d811..0ee4695f39c2 100755 --- a/.ci/test.sh +++ b/.ci/test.sh @@ -85,15 +85,6 @@ if [[ $TASK == "if-else" ]]; then exit 0 fi -if [[ $TASK == "cli-distributed" ]]; then - conda install -q -y -n $CONDA_ENV numpy pytest scikit-learn - mkdir $BUILD_DIRECTORY/build && cd $BUILD_DIRECTORY/build && cmake .. && make -j4 || exit -1 - cp $BUILD_DIRECTORY/lightgbm $BUILD_DIRECTORY/tests/distributed/ || exit -1 - cd $BUILD_DIRECTORY/python-package/ && python setup.py install --precompile || exit -1 - cd $BUILD_DIRECTORY/tests/distributed && pytest _test_distributed.py || exit -1 - exit 0 -fi - if [[ $TASK == "swig" ]]; then mkdir $BUILD_DIRECTORY/build && cd $BUILD_DIRECTORY/build if [[ $OS_NAME == "macos" ]]; then diff --git a/.vsts-ci.yml b/.vsts-ci.yml index ee93cae99726..30702b31467b 100644 --- a/.vsts-ci.yml +++ b/.vsts-ci.yml @@ -119,8 +119,6 @@ jobs: TASK: gpu METHOD: wheel PYTHON_VERSION: 3.7 - distributed: - TASK: cli-distributed steps: - script: | echo "##vso[task.setvariable variable=BUILD_DIRECTORY]$BUILD_SOURCESDIRECTORY" diff --git a/tests/distributed/_test_distributed.py b/tests/distributed/_test_distributed.py index 4ab53e252376..fe77f48a8b5f 100644 --- a/tests/distributed/_test_distributed.py +++ b/tests/distributed/_test_distributed.py @@ -52,11 +52,10 @@ class DistributedMockup: 'num_threads': 2, } - def __init__(self, config: Dict = {}, n_workers: int = 2): + def __init__(self, config: Dict = {}): self.config = copy.deepcopy(self.default_config) self.config.update(config) - self.config['num_machines'] = n_workers - self.n_workers = n_workers + self.n_workers = self.config['num_machines'] def worker_train(self, i: int) -> None: """Start the training process on the `i`-th worker. @@ -116,7 +115,7 @@ def predict(self) -> np.ndarray: return y_pred def write_train_config(self, i: int) -> None: - """Create a file train{i}.txt with the required configuration to train. + """Create a file train{i}.conf with the required configuration to train. Each worker gets a different port and piece of the data, the rest are the model parameters contained in `self.config`. @@ -136,6 +135,7 @@ def test_classifier(): partitions = np.array_split(data, num_machines) params = { 'objective': 'binary', + 'num_machines': num_machines, } clf = DistributedMockup(params) clf.fit(partitions) @@ -151,8 +151,9 @@ def test_regressor(): partitions = np.array_split(data, num_machines) params = { 'objective': 'regression', + 'num_machines': num_machines, } reg = DistributedMockup(params) reg.fit(partitions) y_pred = reg.predict() - np.testing.assert_allclose(y_pred, reg.label_, rtol=0.5, atol=50.) + np.testing.assert_allclose(y_pred, reg.label_, rtol=0.2, atol=50.) From f5e9d496715cf409f5bdf38e38092e72d0764a3d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20MZ?= Date: Wed, 26 May 2021 21:47:17 -0500 Subject: [PATCH 10/17] write predict.conf in the predict method. more robust port setup. 
use subprocess.run and check returncode --- tests/distributed/_test_distributed.py | 107 +++++++++++++++---------- tests/distributed/predict.conf | 4 - 2 files changed, 63 insertions(+), 48 deletions(-) delete mode 100644 tests/distributed/predict.conf diff --git a/tests/distributed/_test_distributed.py b/tests/distributed/_test_distributed.py index fe77f48a8b5f..e3855454b86b 100644 --- a/tests/distributed/_test_distributed.py +++ b/tests/distributed/_test_distributed.py @@ -1,8 +1,8 @@ import copy +import io import subprocess -import sys from concurrent.futures import ThreadPoolExecutor -from typing import Dict, List +from typing import Dict, Generator, List import numpy as np from sklearn.datasets import make_blobs, make_regression @@ -11,6 +11,15 @@ import lightgbm as lgb +def _generate_n_ports(n: int) -> Generator[int, None, None]: + return (lgb.dask._find_random_open_port() for _ in range(n)) + + +def _write_dict(d: Dict, file: io.TextIOWrapper) -> None: + for k, v in d.items(): + file.write(f'{k} = {v}\n') + + def create_data(task: str, n_samples: int = 1_000) -> np.ndarray: """Create the appropiate data for the task. @@ -25,22 +34,10 @@ def create_data(task: str, n_samples: int = 1_000) -> np.ndarray: return dataset -def run_and_log(cmd: List[str]) -> None: - """Run `cmd` in another process and pipe its logs to this process' stdout.""" - process = subprocess.Popen(cmd, stdout=subprocess.PIPE) - assert process.stdout is not None - - def stdout_stream(): - return process.stdout.read(1) - - for c in iter(stdout_stream, b''): - sys.stdout.buffer.write(c) - - class DistributedMockup: """Simulate distributed training.""" - default_config = { + default_train_config = { 'task': 'train', 'pre_partition': True, 'machine_list_file': 'mlist.txt', @@ -52,24 +49,34 @@ class DistributedMockup: 'num_threads': 2, } - def __init__(self, config: Dict = {}): - self.config = copy.deepcopy(self.default_config) - self.config.update(config) - self.n_workers = self.config['num_machines'] + default_predict_config = { + 'task': 'predict', + 'data': 'train.txt', + 'input_model': 'model0.txt', + 'output_result': 'predictions.txt', + } - def worker_train(self, i: int) -> None: + def worker_train(self, i: int) -> subprocess.CompletedProcess: """Start the training process on the `i`-th worker. If this is the first worker, its logs are piped to stdout. 
""" cmd = f'./lightgbm config=train{i}.conf'.split() - if i == 0: - return run_and_log(cmd) - subprocess.run(cmd) + return subprocess.run(cmd) def _set_ports(self) -> None: """Randomly assign a port for training to each worker and save all ports to mlist.txt.""" - self.listen_ports = [lgb.dask._find_random_open_port() for _ in range(self.n_workers)] + ports = set(_generate_n_ports(self.n_workers)) + i = 0 + max_tries = 100 + while i < max_tries and len(ports) < self.n_workers: + n_ports_left = self.n_workers - len(ports) + candidates = _generate_n_ports(n_ports_left) + ports.update(candidates) + i += 1 + if i == max_tries: + raise RuntimeError('Unable to find non-colliding ports.') + self.listen_ports = list(ports) with open('mlist.txt', 'wt') as f: for port in self.listen_ports: f.write(f'127.0.0.1 {port}\n') @@ -81,36 +88,49 @@ def _write_data(self, partitions: List[np.ndarray]) -> None: for i, partition in enumerate(partitions): np.savetxt(f'train{i}.txt', partition, delimiter=',') - def fit(self, partitions: List[np.ndarray]) -> None: + def fit(self, partitions: List[np.ndarray], train_config: Dict = {}) -> None: """Run the distributed training process on a single machine. For each worker i: - 1. The i-th partition is saved as train{i}.txt + 1. The i-th partition is saved as train{i}.txt. 2. A random port is assigned for training. 3. A configuration file train{i}.conf is created. 4. The lightgbm binary is called with config=train{i}.conf in another thread. 5. The trained model is saved as model{i}.txt. Each model file only differs in data and local_listen_port. The whole training set is saved as train.txt and the logs from the first worker are piped to stdout. """ + self.train_config = copy.deepcopy(self.default_train_config) + self.train_config.update(train_config) + self.n_workers = self.train_config['num_machines'] + self._set_ports() self._write_data(partitions) self.label_ = np.hstack([partition[:, 0] for partition in partitions]) - self._set_ports() futures = [] with ThreadPoolExecutor(max_workers=self.n_workers) as executor: for i in range(self.n_workers): self.write_train_config(i) - futures.append(executor.submit(self.worker_train, i)) - _ = [f.result() for f in futures] - - def predict(self) -> np.ndarray: + train_future = executor.submit(self.worker_train, i) + futures.append(train_future) + results = [f.result() for f in futures] + for result in results: + if result.returncode != 0: + raise RuntimeError + + def predict(self, predict_config: Dict = {}) -> np.ndarray: """Compute the predictions using the model created in the fit step. model0.txt is used to predict the training set train.txt using predict.conf. The predictions are saved as predictions.txt and are then loaded to return them as a numpy array. The logs are piped to stdout. """ + self.predict_config = copy.deepcopy(self.default_predict_config) + self.predict_config.update(predict_config) + with open('predict.conf', 'wt') as file: + _write_dict(self.predict_config, file) cmd = './lightgbm config=predict.conf'.split() - run_and_log(cmd) + result = subprocess.run(cmd) + if result.returncode != 0: + raise RuntimeError y_pred = np.loadtxt('predictions.txt') return y_pred @@ -120,12 +140,11 @@ def write_train_config(self, i: int) -> None: Each worker gets a different port and piece of the data, the rest are the model parameters contained in `self.config`. 
""" - with open(f'train{i}.conf', 'wt') as f: - f.write(f'output_model = model{i}.txt\n') - f.write(f'local_listen_port = {self.listen_ports[i]}\n') - f.write(f'data = train{i}.txt\n') - for param, value in self.config.items(): - f.write(f'{param} = {value}\n') + with open(f'train{i}.conf', 'wt') as file: + file.write(f'output_model = model{i}.txt\n') + file.write(f'local_listen_port = {self.listen_ports[i]}\n') + file.write(f'data = train{i}.txt\n') + _write_dict(self.train_config, file) def test_classifier(): @@ -133,12 +152,12 @@ def test_classifier(): num_machines = 2 data = create_data(task='binary-classification') partitions = np.array_split(data, num_machines) - params = { + train_params = { 'objective': 'binary', 'num_machines': num_machines, } - clf = DistributedMockup(params) - clf.fit(partitions) + clf = DistributedMockup() + clf.fit(partitions, train_params) y_probas = clf.predict() y_pred = y_probas > 0.5 assert accuracy_score(clf.label_, y_pred) == 1. @@ -149,11 +168,11 @@ def test_regressor(): num_machines = 2 data = create_data(task='regression') partitions = np.array_split(data, num_machines) - params = { + train_params = { 'objective': 'regression', 'num_machines': num_machines, } - reg = DistributedMockup(params) - reg.fit(partitions) + reg = DistributedMockup() + reg.fit(partitions, train_params) y_pred = reg.predict() np.testing.assert_allclose(y_pred, reg.label_, rtol=0.2, atol=50.) diff --git a/tests/distributed/predict.conf b/tests/distributed/predict.conf deleted file mode 100644 index f926b2f81edd..000000000000 --- a/tests/distributed/predict.conf +++ /dev/null @@ -1,4 +0,0 @@ -task = predict -data = train.txt -input_model = model0.txt -output_result = predictions.txt From f7c6fddc2e82ea55773e72f7639ff5d6514a0cb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20MZ?= Date: Mon, 31 May 2021 20:21:00 -0500 Subject: [PATCH 11/17] add paths to tests and binary. remove lgb dependency. update .igtignore. 
--- .gitignore | 2 +- tests/distributed/_test_distributed.py | 58 ++++++++++++++++---------- 2 files changed, 36 insertions(+), 24 deletions(-) diff --git a/.gitignore b/.gitignore index f990289b38d6..c054aee9111a 100644 --- a/.gitignore +++ b/.gitignore @@ -433,7 +433,7 @@ miktex*.zip tests/distributed/mlist.txt tests/distributed/train* tests/distributed/model* -tests/distributed/predictions*.txt +tests/distributed/predict* # Files from interactive R sessions diff --git a/tests/distributed/_test_distributed.py b/tests/distributed/_test_distributed.py index e3855454b86b..326fccce3206 100644 --- a/tests/distributed/_test_distributed.py +++ b/tests/distributed/_test_distributed.py @@ -1,18 +1,30 @@ import copy import io +import socket import subprocess from concurrent.futures import ThreadPoolExecutor +from pathlib import Path from typing import Dict, Generator, List import numpy as np from sklearn.datasets import make_blobs, make_regression from sklearn.metrics import accuracy_score -import lightgbm as lgb + +TESTS_DIR = Path(__file__).absolute().parent +BINARY_DIR = TESTS_DIR.parents[1] + + +def _find_random_open_port() -> int: + """Find a random open port on localhost.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(('', 0)) + port = s.getsockname()[1] + return port def _generate_n_ports(n: int) -> Generator[int, None, None]: - return (lgb.dask._find_random_open_port() for _ in range(n)) + return (_find_random_open_port() for _ in range(n)) def _write_dict(d: Dict, file: io.TextIOWrapper) -> None: @@ -40,7 +52,7 @@ class DistributedMockup: default_train_config = { 'task': 'train', 'pre_partition': True, - 'machine_list_file': 'mlist.txt', + 'machine_list_file': TESTS_DIR / 'mlist.txt', 'tree_learner': 'data', 'force_row_wise': True, 'verbose': 0, @@ -51,17 +63,15 @@ class DistributedMockup: default_predict_config = { 'task': 'predict', - 'data': 'train.txt', - 'input_model': 'model0.txt', - 'output_result': 'predictions.txt', + 'data': TESTS_DIR / 'train.txt', + 'input_model': TESTS_DIR / 'model0.txt', + 'output_result': TESTS_DIR / 'predictions.txt', } def worker_train(self, i: int) -> subprocess.CompletedProcess: - """Start the training process on the `i`-th worker. - - If this is the first worker, its logs are piped to stdout. - """ - cmd = f'./lightgbm config=train{i}.conf'.split() + """Start the training process on the `i`-th worker.""" + config_path = TESTS_DIR / f'train{i}.conf' + cmd = [BINARY_DIR / 'lightgbm', f'config={config_path}'] return subprocess.run(cmd) def _set_ports(self) -> None: @@ -77,16 +87,16 @@ def _set_ports(self) -> None: if i == max_tries: raise RuntimeError('Unable to find non-colliding ports.') self.listen_ports = list(ports) - with open('mlist.txt', 'wt') as f: + with open(TESTS_DIR / 'mlist.txt', 'wt') as f: for port in self.listen_ports: f.write(f'127.0.0.1 {port}\n') def _write_data(self, partitions: List[np.ndarray]) -> None: """Write all training data as train.txt and each training partition as train{i}.txt.""" all_data = np.vstack(partitions) - np.savetxt('train.txt', all_data, delimiter=',') + np.savetxt(TESTS_DIR / 'train.txt', all_data, delimiter=',') for i, partition in enumerate(partitions): - np.savetxt(f'train{i}.txt', partition, delimiter=',') + np.savetxt(TESTS_DIR / f'train{i}.txt', partition, delimiter=',') def fit(self, partitions: List[np.ndarray], train_config: Dict = {}) -> None: """Run the distributed training process on a single machine. 
@@ -97,7 +107,7 @@ def fit(self, partitions: List[np.ndarray], train_config: Dict = {}) -> None: 3. A configuration file train{i}.conf is created. 4. The lightgbm binary is called with config=train{i}.conf in another thread. 5. The trained model is saved as model{i}.txt. Each model file only differs in data and local_listen_port. - The whole training set is saved as train.txt and the logs from the first worker are piped to stdout. + The whole training set is saved as train.txt. """ self.train_config = copy.deepcopy(self.default_train_config) self.train_config.update(train_config) @@ -119,19 +129,19 @@ def fit(self, partitions: List[np.ndarray], train_config: Dict = {}) -> None: def predict(self, predict_config: Dict = {}) -> np.ndarray: """Compute the predictions using the model created in the fit step. - model0.txt is used to predict the training set train.txt using predict.conf. + predict_config is used to predict the training set train.txt The predictions are saved as predictions.txt and are then loaded to return them as a numpy array. - The logs are piped to stdout. """ self.predict_config = copy.deepcopy(self.default_predict_config) self.predict_config.update(predict_config) - with open('predict.conf', 'wt') as file: + with open(TESTS_DIR / 'predict.conf', 'wt') as file: _write_dict(self.predict_config, file) - cmd = './lightgbm config=predict.conf'.split() + config_path = TESTS_DIR / 'predict.conf' + cmd = [BINARY_DIR / 'lightgbm', f'config={config_path}'] result = subprocess.run(cmd) if result.returncode != 0: raise RuntimeError - y_pred = np.loadtxt('predictions.txt') + y_pred = np.loadtxt(TESTS_DIR / 'predictions.txt') return y_pred def write_train_config(self, i: int) -> None: @@ -140,10 +150,12 @@ def write_train_config(self, i: int) -> None: Each worker gets a different port and piece of the data, the rest are the model parameters contained in `self.config`. 
""" - with open(f'train{i}.conf', 'wt') as file: - file.write(f'output_model = model{i}.txt\n') + with open(TESTS_DIR / f'train{i}.conf', 'wt') as file: + output_model = TESTS_DIR / f'model{i}.txt' + data = TESTS_DIR / f'train{i}.txt' + file.write(f'output_model = {output_model}\n') file.write(f'local_listen_port = {self.listen_ports[i]}\n') - file.write(f'data = train{i}.txt\n') + file.write(f'data = {data}\n') _write_dict(self.train_config, file) From c964fdf8ca4d449cbec79fcf864845cb95fbf2aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20MZ?= Date: Mon, 31 May 2021 20:27:30 -0500 Subject: [PATCH 12/17] lint --- tests/distributed/_test_distributed.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/distributed/_test_distributed.py b/tests/distributed/_test_distributed.py index 326fccce3206..4ef3156217c1 100644 --- a/tests/distributed/_test_distributed.py +++ b/tests/distributed/_test_distributed.py @@ -10,7 +10,6 @@ from sklearn.datasets import make_blobs, make_regression from sklearn.metrics import accuracy_score - TESTS_DIR = Path(__file__).absolute().parent BINARY_DIR = TESTS_DIR.parents[1] From 56113de5baf57c5ff793863a2e43d19452acbbb9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20MZ?= Date: Mon, 7 Jun 2021 21:32:17 -0500 Subject: [PATCH 13/17] allow to pass executable dir as argument to pytest --- tests/distributed/_test_distributed.py | 28 +++++++++++++++++++------- tests/distributed/conftest.py | 7 +++++++ 2 files changed, 28 insertions(+), 7 deletions(-) create mode 100644 tests/distributed/conftest.py diff --git a/tests/distributed/_test_distributed.py b/tests/distributed/_test_distributed.py index 4ef3156217c1..6f2db6e5ece5 100644 --- a/tests/distributed/_test_distributed.py +++ b/tests/distributed/_test_distributed.py @@ -4,14 +4,25 @@ import subprocess from concurrent.futures import ThreadPoolExecutor from pathlib import Path +from platform import system from typing import Dict, Generator, List import numpy as np +import pytest from sklearn.datasets import make_blobs, make_regression from sklearn.metrics import accuracy_score TESTS_DIR = Path(__file__).absolute().parent -BINARY_DIR = TESTS_DIR.parents[1] + + +@pytest.fixture(scope='module') +def executable(pytestconfig) -> str: + """Returns the path to the lightgbm executable.""" + exec_dir = Path(pytestconfig.getoption('execdir')) + exec_file = 'lightgbm' + if system() in {'Windows', 'Microsoft'}: + exec_file += '.exe' + return str(exec_dir / exec_file) def _find_random_open_port() -> int: @@ -67,10 +78,13 @@ class DistributedMockup: 'output_result': TESTS_DIR / 'predictions.txt', } + def __init__(self, executable: str): + self.executable = executable + def worker_train(self, i: int) -> subprocess.CompletedProcess: """Start the training process on the `i`-th worker.""" config_path = TESTS_DIR / f'train{i}.conf' - cmd = [BINARY_DIR / 'lightgbm', f'config={config_path}'] + cmd = [self.executable, f'config={config_path}'] return subprocess.run(cmd) def _set_ports(self) -> None: @@ -136,7 +150,7 @@ def predict(self, predict_config: Dict = {}) -> np.ndarray: with open(TESTS_DIR / 'predict.conf', 'wt') as file: _write_dict(self.predict_config, file) config_path = TESTS_DIR / 'predict.conf' - cmd = [BINARY_DIR / 'lightgbm', f'config={config_path}'] + cmd = [self.executable, f'config={config_path}'] result = subprocess.run(cmd) if result.returncode != 0: raise RuntimeError @@ -158,7 +172,7 @@ def write_train_config(self, i: int) -> None: _write_dict(self.train_config, file) -def test_classifier(): +def 
test_classifier(executable): """Test the classification task.""" num_machines = 2 data = create_data(task='binary-classification') @@ -167,14 +181,14 @@ def test_classifier(): 'objective': 'binary', 'num_machines': num_machines, } - clf = DistributedMockup() + clf = DistributedMockup(executable) clf.fit(partitions, train_params) y_probas = clf.predict() y_pred = y_probas > 0.5 assert accuracy_score(clf.label_, y_pred) == 1. -def test_regressor(): +def test_regressor(executable): """Test the regression task.""" num_machines = 2 data = create_data(task='regression') @@ -183,7 +197,7 @@ def test_regressor(): 'objective': 'regression', 'num_machines': num_machines, } - reg = DistributedMockup() + reg = DistributedMockup(executable) reg.fit(partitions, train_params) y_pred = reg.predict() np.testing.assert_allclose(y_pred, reg.label_, rtol=0.2, atol=50.) diff --git a/tests/distributed/conftest.py b/tests/distributed/conftest.py new file mode 100644 index 000000000000..5458cf5da822 --- /dev/null +++ b/tests/distributed/conftest.py @@ -0,0 +1,7 @@ +from pathlib import Path + +default_exec_dir = str(Path(__file__).absolute().parents[2]) + + +def pytest_addoption(parser): + parser.addoption('--execdir', action='store', default=default_exec_dir) From 83ccf6ace52f355ae79a2254fa5836ca58539562 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20MZ?= Date: Wed, 9 Jun 2021 20:53:49 -0500 Subject: [PATCH 14/17] pass execfile to pytest instead of execdir --- tests/distributed/_test_distributed.py | 7 +------ tests/distributed/conftest.py | 4 ++-- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/tests/distributed/_test_distributed.py b/tests/distributed/_test_distributed.py index 6f2db6e5ece5..36bb85f2c93f 100644 --- a/tests/distributed/_test_distributed.py +++ b/tests/distributed/_test_distributed.py @@ -4,7 +4,6 @@ import subprocess from concurrent.futures import ThreadPoolExecutor from pathlib import Path -from platform import system from typing import Dict, Generator, List import numpy as np @@ -18,11 +17,7 @@ @pytest.fixture(scope='module') def executable(pytestconfig) -> str: """Returns the path to the lightgbm executable.""" - exec_dir = Path(pytestconfig.getoption('execdir')) - exec_file = 'lightgbm' - if system() in {'Windows', 'Microsoft'}: - exec_file += '.exe' - return str(exec_dir / exec_file) + return pytestconfig.getoption('execfile') def _find_random_open_port() -> int: diff --git a/tests/distributed/conftest.py b/tests/distributed/conftest.py index 5458cf5da822..089bd6c598d8 100644 --- a/tests/distributed/conftest.py +++ b/tests/distributed/conftest.py @@ -1,7 +1,7 @@ from pathlib import Path -default_exec_dir = str(Path(__file__).absolute().parents[2]) +default_exec_file = str(Path(__file__).absolute().parents[2] / 'lightgbm') def pytest_addoption(parser): - parser.addoption('--execdir', action='store', default=default_exec_dir) + parser.addoption('--execfile', action='store', default=default_exec_file) From e84df750313c4d34b761482dfea8568377059369 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20MZ?= Date: Mon, 14 Jun 2021 19:04:42 -0500 Subject: [PATCH 15/17] add suggestions --- tests/distributed/_test_distributed.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/distributed/_test_distributed.py b/tests/distributed/_test_distributed.py index 36bb85f2c93f..7d95edca8dcc 100644 --- a/tests/distributed/_test_distributed.py +++ b/tests/distributed/_test_distributed.py @@ -38,7 +38,7 @@ def _write_dict(d: Dict, file: io.TextIOWrapper) -> None: 
def create_data(task: str, n_samples: int = 1_000) -> np.ndarray: - """Create the appropiate data for the task. + """Create the appropriate data for the task. The data is returned as a numpy array with the label as the first column. """ @@ -142,9 +142,9 @@ def predict(self, predict_config: Dict = {}) -> np.ndarray: """ self.predict_config = copy.deepcopy(self.default_predict_config) self.predict_config.update(predict_config) - with open(TESTS_DIR / 'predict.conf', 'wt') as file: - _write_dict(self.predict_config, file) config_path = TESTS_DIR / 'predict.conf' + with open(config_path, 'wt') as file: + _write_dict(self.predict_config, file) cmd = [self.executable, f'config={config_path}'] result = subprocess.run(cmd) if result.returncode != 0: From 45bfc79be1b0d3f244537c8d0a964c10758f5eeb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20MZ?= Date: Fri, 25 Jun 2021 18:22:17 -0500 Subject: [PATCH 16/17] use os.path and add type hint to predict_config --- tests/distributed/_test_distributed.py | 34 +++++++++++++------------- tests/distributed/conftest.py | 5 ++-- 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/tests/distributed/_test_distributed.py b/tests/distributed/_test_distributed.py index 7d95edca8dcc..16b6e2c54910 100644 --- a/tests/distributed/_test_distributed.py +++ b/tests/distributed/_test_distributed.py @@ -1,17 +1,17 @@ import copy import io +import os import socket import subprocess from concurrent.futures import ThreadPoolExecutor -from pathlib import Path -from typing import Dict, Generator, List +from typing import Any, Dict, Generator, List import numpy as np import pytest from sklearn.datasets import make_blobs, make_regression from sklearn.metrics import accuracy_score -TESTS_DIR = Path(__file__).absolute().parent +TESTS_DIR = os.path.abspath(os.path.dirname(__file__)) @pytest.fixture(scope='module') @@ -57,7 +57,7 @@ class DistributedMockup: default_train_config = { 'task': 'train', 'pre_partition': True, - 'machine_list_file': TESTS_DIR / 'mlist.txt', + 'machine_list_file': os.path.join(TESTS_DIR, 'mlist.txt'), 'tree_learner': 'data', 'force_row_wise': True, 'verbose': 0, @@ -68,9 +68,9 @@ class DistributedMockup: default_predict_config = { 'task': 'predict', - 'data': TESTS_DIR / 'train.txt', - 'input_model': TESTS_DIR / 'model0.txt', - 'output_result': TESTS_DIR / 'predictions.txt', + 'data': os.path.join(TESTS_DIR, 'train.txt'), + 'input_model': os.path.join(TESTS_DIR, 'model0.txt'), + 'output_result': os.path.join(TESTS_DIR, 'predictions.txt'), } def __init__(self, executable: str): @@ -78,7 +78,7 @@ def __init__(self, executable: str): def worker_train(self, i: int) -> subprocess.CompletedProcess: """Start the training process on the `i`-th worker.""" - config_path = TESTS_DIR / f'train{i}.conf' + config_path = os.path.join(TESTS_DIR, f'train{i}.conf') cmd = [self.executable, f'config={config_path}'] return subprocess.run(cmd) @@ -95,16 +95,16 @@ def _set_ports(self) -> None: if i == max_tries: raise RuntimeError('Unable to find non-colliding ports.') self.listen_ports = list(ports) - with open(TESTS_DIR / 'mlist.txt', 'wt') as f: + with open(os.path.join(TESTS_DIR, 'mlist.txt'), 'wt') as f: for port in self.listen_ports: f.write(f'127.0.0.1 {port}\n') def _write_data(self, partitions: List[np.ndarray]) -> None: """Write all training data as train.txt and each training partition as train{i}.txt.""" all_data = np.vstack(partitions) - np.savetxt(TESTS_DIR / 'train.txt', all_data, delimiter=',') + np.savetxt(os.path.join(TESTS_DIR, 'train.txt'), 
all_data, delimiter=',') for i, partition in enumerate(partitions): - np.savetxt(TESTS_DIR / f'train{i}.txt', partition, delimiter=',') + np.savetxt(os.path.join(TESTS_DIR, f'train{i}.txt'), partition, delimiter=',') def fit(self, partitions: List[np.ndarray], train_config: Dict = {}) -> None: """Run the distributed training process on a single machine. @@ -134,7 +134,7 @@ def fit(self, partitions: List[np.ndarray], train_config: Dict = {}) -> None: if result.returncode != 0: raise RuntimeError - def predict(self, predict_config: Dict = {}) -> np.ndarray: + def predict(self, predict_config: Dict[str, Any] = {}) -> np.ndarray: """Compute the predictions using the model created in the fit step. predict_config is used to predict the training set train.txt @@ -142,14 +142,14 @@ def predict(self, predict_config: Dict = {}) -> np.ndarray: """ self.predict_config = copy.deepcopy(self.default_predict_config) self.predict_config.update(predict_config) - config_path = TESTS_DIR / 'predict.conf' + config_path = os.path.join(TESTS_DIR, 'predict.conf') with open(config_path, 'wt') as file: _write_dict(self.predict_config, file) cmd = [self.executable, f'config={config_path}'] result = subprocess.run(cmd) if result.returncode != 0: raise RuntimeError - y_pred = np.loadtxt(TESTS_DIR / 'predictions.txt') + y_pred = np.loadtxt(os.path.join(TESTS_DIR, 'predictions.txt')) return y_pred def write_train_config(self, i: int) -> None: @@ -158,9 +158,9 @@ def write_train_config(self, i: int) -> None: Each worker gets a different port and piece of the data, the rest are the model parameters contained in `self.config`. """ - with open(TESTS_DIR / f'train{i}.conf', 'wt') as file: - output_model = TESTS_DIR / f'model{i}.txt' - data = TESTS_DIR / f'train{i}.txt' + with open(os.path.join(TESTS_DIR, f'train{i}.conf'), 'wt') as file: + output_model = os.path.join(TESTS_DIR, f'model{i}.txt') + data = os.path.join(TESTS_DIR, f'train{i}.txt') file.write(f'output_model = {output_model}\n') file.write(f'local_listen_port = {self.listen_ports[i]}\n') file.write(f'data = {data}\n') diff --git a/tests/distributed/conftest.py b/tests/distributed/conftest.py index 089bd6c598d8..d5db71c69513 100644 --- a/tests/distributed/conftest.py +++ b/tests/distributed/conftest.py @@ -1,6 +1,7 @@ -from pathlib import Path +import os -default_exec_file = str(Path(__file__).absolute().parents[2] / 'lightgbm') +TESTS_DIR = os.path.dirname(__file__) +default_exec_file = os.path.abspath(os.path.join(TESTS_DIR, '..', '..', 'lightgbm')) def pytest_addoption(parser): From d4cbb7a2c47496e254539cb2d9d1d7c9ad6f03a8 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Mon, 28 Jun 2021 04:12:59 +0100 Subject: [PATCH 17/17] Update tests/distributed/_test_distributed.py --- tests/distributed/_test_distributed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/distributed/_test_distributed.py b/tests/distributed/_test_distributed.py index 16b6e2c54910..64ffa2b22399 100644 --- a/tests/distributed/_test_distributed.py +++ b/tests/distributed/_test_distributed.py @@ -132,7 +132,7 @@ def fit(self, partitions: List[np.ndarray], train_config: Dict = {}) -> None: results = [f.result() for f in futures] for result in results: if result.returncode != 0: - raise RuntimeError + raise RuntimeError('Error in training') def predict(self, predict_config: Dict[str, Any] = {}) -> np.ndarray: """Compute the predictions using the model created in the fit step.
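
To try the suite against a locally built CLI binary, an invocation along the following lines should work (assuming the binary lands at the repository root, as in the CI snippet earlier in the series; paths are illustrative):

    mkdir build && cd build && cmake .. && make -j4
    cd ../tests/distributed
    pytest _test_distributed.py                                # picks up ../../lightgbm by default
    pytest _test_distributed.py --execfile=/path/to/lightgbm   # or point at any other binary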