diff --git a/.github/workflows/build_wheels_and_releases.yml b/.github/workflows/build_wheels_and_releases.yml
new file mode 100644
index 0000000000..bf33c56bd2
--- /dev/null
+++ b/.github/workflows/build_wheels_and_releases.yml
@@ -0,0 +1,190 @@
+name: Build-Wheels-PyPi
+# https://github.com/pypa/cibuildwheel
+# Controls when the workflow will run
+on:
+ pull_request:
+ branches: [ master ]
+
+ # Allows you to run this workflow manually from the Actions tab
+ workflow_dispatch:
+
+jobs:
+ # Build the wheels for Linux, Windows and macOS for Python 3.8 and newer
+ build_wheels:
+ name: Build wheel for cp${{ matrix.python }}-${{ matrix.platform_id }}-${{ matrix.manylinux_image }}
+ runs-on: ${{ matrix.os }}
+ defaults:
+ run:
+ shell: bash
+ working-directory: python
+
+ strategy:
+ # Ensure that a wheel builder finishes even if another fails
+ fail-fast: false
+ matrix:
+ include:
+ # Window 64 bit
+ - os: windows-2019
+ python: 38
+ bitness: 64
+ platform_id: win_amd64
+ - os: windows-latest
+ python: 39
+ bitness: 64
+ platform_id: win_amd64
+ - os: windows-latest
+ python: 310
+ bitness: 64
+ platform_id: win_amd64
+
+ # Window 32 bit
+ - os: windows-latest
+ python: 38
+ bitness: 32
+ platform_id: win32
+ - os: windows-latest
+ python: 39
+ bitness: 32
+ platform_id: win32
+
+ # Linux 64 bit manylinux2014
+ - os: ubuntu-latest
+ python: 38
+ bitness: 64
+ platform_id: manylinux_x86_64
+ manylinux_image: manylinux2014
+ - os: ubuntu-latest
+ python: 39
+ bitness: 64
+ platform_id: manylinux_x86_64
+ manylinux_image: manylinux2014
+
+ # NumPy on Python 3.10 only supports 64bit and is only available with manylinux2014
+ - os: ubuntu-latest
+ python: 310
+ bitness: 64
+ platform_id: manylinux_x86_64
+ manylinux_image: manylinux2014
+
+ # MacOS x86_64
+ - os: macos-latest
+ bitness: 64
+ python: 38
+ platform_id: macosx_x86_64
+ - os: macos-latest
+ bitness: 64
+ python: 39
+ platform_id: macosx_x86_64
+ - os: macos-latest
+ bitness: 64
+ python: 310
+ platform_id: macosx_x86_64
+
+ # MacOS arm64
+ - os: macos-latest
+ bitness: 64
+ python: 38
+ platform_id: macosx_arm64
+ - os: macos-latest
+ bitness: 64
+ python: 39
+ platform_id: macosx_arm64
+ - os: macos-latest
+ bitness: 64
+ python: 310
+ platform_id: macosx_arm64
+
+ steps:
+ - name: Checkout fedml
+ uses: actions/checkout@v3
+
+ - name: Setup Python
+ uses: actions/setup-python@v4
+ #with:
+ # python-version: '3.9'
+
+ - name: Build and test wheels
+ env:
+ CIBW_BUILD: cp${{ matrix.python }}-${{ matrix.platform_id }}
+ CIBW_ARCHS: all
+ CIBW_MANYLINUX_X86_64_IMAGE: ${{ matrix.manylinux_image }}
+ CIBW_MANYLINUX_I686_IMAGE: ${{ matrix.manylinux_image }}
+ CIBW_TEST_SKIP: "*-macosx_arm64"
+ CIBW_REPAIR_WHEEL_COMMAND_WINDOWS: bash build_tools/github/repair_windows_wheels.sh {wheel} {dest_dir} ${{ matrix.bitness }}
+ CIBW_BEFORE_TEST_WINDOWS: bash build_tools/github/build_minimal_windows_image.sh ${{ matrix.python }} ${{ matrix.bitness }}
+ CIBW_TEST_COMMAND: bash {project}/build_tools/github/test_wheels.sh
+ CIBW_TEST_COMMAND_WINDOWS: bash {project}/build_tools/github/test_windows_wheels.sh ${{ matrix.python }} ${{ matrix.bitness }}
+ CIBW_BUILD_VERBOSITY: 1
+
+ #run: bash build_tools/github/build_wheels.sh
+ run: |
+ python -m pip install -U wheel setuptools
+ python setup.py sdist bdist_wheel
+ pwd
+ ls dist/*.whl
+ ls dist/*.tar.gz
+
+ - name: Upload source zip file
+ uses: actions/upload-artifact@v3
+ with:
+ path: python/dist/*.tar.gz
+
+ - name: Upload Wheels
+ uses: actions/upload-artifact@v3
+ with:
+ path: python/dist/*.whl
+
+ # Build the source distribution under Linux
+# build_sdist:
+# name: Source distribution
+# needs: [ build_wheels ]
+# runs-on: ubuntu-latest
+# defaults:
+# run:
+# shell: bash
+# working-directory: python
+#
+# steps:
+# - name: Checkout fedml
+# uses: actions/checkout@v3
+#
+# - name: Setup Python
+# uses: actions/setup-python@v4
+# with:
+# python-version: '3.9' # update once build dependencies are available
+#
+# - name: Build source distribution
+# run: bash build_tools/github/build_source.sh
+#
+# - name: Test source distribution
+# run: bash build_tools/github/test_source.sh
+# env:
+# OMP_NUM_THREADS: 2
+# OPENBLAS_NUM_THREADS: 2
+#
+# - name: Store artifacts
+# uses: actions/upload-artifact@v3
+# with:
+# path: python/dist/*.tar.gz
+
+ upload_pypi:
+ name: Upload pypi
+ needs: [ build_wheels ]
+ runs-on: ubuntu-latest
+ # upload to PyPI on every tag starting with 'v'
+ # if: github.event_name == 'push' && contains(github.event.comment, 'release v')
+ # alternatively, to publish when a GitHub Release is created, use the following rule:
+ # if: github.event_name == 'release' && github.event.action == 'published'
+ steps:
+ - uses: actions/download-artifact@v3
+ with:
+ name: artifact
+ path: python/dist
+
+ - uses: pypa/gh-action-pypi-publish@v1.4.2
+ with:
+ skip_existing: true
+ packages_dir: python/dist
+ user: ${{ secrets.PYPI_USER_NAME }}
+ password: ${{ secrets.PYPI_PASSWORD }}
+ # To test: repository_url: https://test.pypi.org/legacy/
\ No newline at end of file
diff --git a/.github/workflows/build_wheels_and_releases.yml-backup b/.github/workflows/build_wheels_and_releases.yml-backup
new file mode 100644
index 0000000000..2a27b44e65
--- /dev/null
+++ b/.github/workflows/build_wheels_and_releases.yml-backup
@@ -0,0 +1,250 @@
+#name: Build_Wheels_and_Release
+## https://github.com/pypa/cibuildwheel
+## Controls when the workflow will run
+#on:
+# # Triggers the workflow on push or pull request events but only for the master branch
+# schedule:
+# # Nightly build at 12:12 A.M.
+# - cron: "12 12 */1 * *"
+## push:
+## branches: [ master, test/v0.7.0 ]
+## pull_request:
+## branches: [ master, test/v0.7.0 ]
+#
+# # Allows you to run this workflow manually from the Actions tab
+# workflow_dispatch:
+#
+#jobs:
+# build_wheels:
+# runs-on: [self-hosted, devops]
+# defaults:
+# run:
+# shell: bash
+# working-directory: python
+# strategy:
+# # Ensure that a wheel builder finishes even if another fails
+# fail-fast: false
+# matrix:
+# # Github Actions doesn't support pairing matrix values together, let's improvise
+# # https://github.com/github/feedback/discussions/7835#discussioncomment-1769026
+# buildplat:
+# - [ubuntu-20.04, manylinux_x86_64]
+# - [macos-10.15, macosx_*]
+# - [windows-2019, win_amd64]
+# - [windows-2019, win32]
+# # TODO: uncomment PyPy 3.9 builds once PyPy
+# # re-releases a new minor version
+# # NOTE: This needs a bump of cibuildwheel version, also, once that happens.
+# python: ["cp38", "cp39", "cp310", "pp38"] #, "pp39"]
+# exclude:
+# # Don't build PyPy 32-bit windows
+# - buildplat: [windows-2019, win32]
+# python: "pp38"
+# - buildplat: [windows-2019, win32]
+# python: "pp39"
+# env:
+# IS_32_BIT: ${{ matrix.buildplat[1] == 'win32' }}
+# IS_PUSH: ${{ github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') }}
+# IS_SCHEDULE_DISPATCH: ${{ github.event_name == 'schedule' || github.event_name == 'workflow_dispatch' }}
+# steps:
+# - name: Checkout numpy
+# - uses: actions/checkout@v3
+#
+# # Used to push the built wheels
+# - uses: actions/setup-python@v3
+# with:
+# python-version: '3.x'
+#
+# - name: Configure mingw for 32-bit builds
+# run: |
+# # Force 32-bit mingw
+# choco uninstall mingw
+# choco install -y mingw --forcex86 --force --version=7.3.0
+# echo "C:\ProgramData\chocolatey\lib\mingw\tools\install\mingw32\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
+# refreshenv
+# if: ${{ env.IS_32_BIT == 'true' }}
+#
+#
+# - name: Install cibuildwheel
+# run: |
+# pip install --upgrade setuptools
+# python3 -m pip install pipx
+# python3 -m pipx ensurepath
+# export PATH="${PATH}:$(python3 -c 'import site; print(site.USER_BASE)')/bin"
+## python3 -m pipx ensurepath --force
+#
+# - name: Build wheels
+# uses: pypa/cibuildwheel@v2.7.0
+# env:
+# CIBW_BUILD: ${{ matrix.python }}-${{ matrix.buildplat[1] }}
+#
+# - uses: actions/upload-artifact@v3
+# with:
+# name: ${{ matrix.python }}-${{ startsWith(matrix.buildplat[1], 'macosx') && 'macosx' || matrix.buildplat[1] }}
+# path: ./wheelhouse/*.whl
+
+# - name: Upload wheels
+# if: success()
+# shell: bash
+# env:
+# NUMPY_STAGING_UPLOAD_TOKEN: ${{ secrets.NUMPY_STAGING_UPLOAD_TOKEN }}
+# NUMPY_NIGHTLY_UPLOAD_TOKEN: ${{ secrets.NUMPY_NIGHTLY_UPLOAD_TOKEN }}
+# run: |
+# source tools/wheels/upload_wheels.sh
+# set_upload_vars
+# # trigger an upload to
+# # https://anaconda.org/scipy-wheels-nightly/numpy
+# # for cron jobs or "Run workflow" (restricted to main branch).
+# # Tags will upload to
+# # https://anaconda.org/multibuild-wheels-staging/numpy
+# # The tokens were originally generated at anaconda.org
+# upload_wheels
+# strategy:
+# fail-fast: false
+# matrix:
+# os: [ ubuntu-20.04, windows-2019, macOS-10.15 ]
+# arch: [X86, X64, ARM, ARM64]
+# python-version: ['3.6', '3.7', '3.8', '3.9']
+# exclude:
+# - os: macos-latest
+# python-version: '3.8'
+# - os: windows-latest
+# python-version: '3.6'
+# strategy:
+# # Ensure that a wheel builder finishes even if another fails
+# fail-fast: false
+# matrix:
+# include:
+# # from mpi4py import MPI
+# # ImportError: DLL load failed while importing MPI: The specified module could not be found.
+# - os: windows-2019-py38-amd64
+# python: 38
+# bitness: 64
+# platform_id: win_amd64
+# # from mpi4py import MPI
+# # ImportError: DLL load failed while importing MPI: The specified module could not be found.
+# - os: windows-latest-py39-amd64
+# python: 39
+# bitness: 64
+# platform_id: win_amd64
+# # ERROR: No matching distribution found for MNN==1.1.6
+# - os: windows-latest-py310-amd64
+# python: 310
+# bitness: 64
+# platform_id: win_amd64
+# # #RROR: No matching distribution found for torch==1.11.0
+# - os: windows-latest-py38-win32
+# python: 38
+# bitness: 32
+# platform_id: win32
+# # ERROR: No matching distribution found for torch==1.11.0
+# - os: windows-latest-py39-win32
+# python: 39
+# bitness: 32
+# platform_id: win32
+#
+# # auditwheel.main_repair:This does not look like a platform wheel
+# - os: ubuntu-latest-py38-x86
+# python: 38
+# bitness: 64
+# platform_id: manylinux_x86_64
+# manylinux_image: manylinux2014
+# # auditwheel.main_repair:This does not look like a platform wheel
+# - os: ubuntu-latest-py39-x86
+# python: 39
+# bitness: 64
+# platform_id: manylinux_x86_64
+# manylinux_image: manylinux2014
+# # auditwheel.main_repair:This does not look like a platform wheel
+# - os: ubuntu-latest-py310-x86
+# python: 310
+# bitness: 64
+# platform_id: manylinux_x86_64
+# manylinux_image: manylinux2014
+#
+# # _configtest.c:2:10: fatal error: 'mpi.h' file not found
+# - os: macos-latest-py38-x86
+# bitness: 64
+# python: 38
+# platform_id: macosx_x86_64
+# # _configtest.c:2:10: fatal error: 'mpi.h' file not found
+# - os: macos-latest-py39-x86
+# bitness: 64
+# python: 39
+# platform_id: macosx_x86_64
+# # _configtest.c:2:10: fatal error: 'mpi.h' file not found
+# - os: macos-latest-py310-x86
+# bitness: 64
+# python: 310
+# platform_id: macosx_x86_64
+#
+# # MacOS arm64
+# - os: macos-latest-py38-arm64
+# bitness: 64
+# python: 38
+# platform_id: macosx_arm64
+# - os: macos-latest-py39-arm64
+# bitness: 64
+# python: 39
+# platform_id: macosx_arm64
+# - os: macos-latest-py310-arm64
+# bitness: 64
+# python: 310
+# platform_id: macosx_arm64
+
+# steps:
+# - uses: actions/checkout@v3
+#
+# # Used to host cibuildwheel
+# - uses: actions/setup-python@v3
+#
+## - name: Install cibuildwheel
+## run: python -m pip install cibuildwheel==2.7.0
+#
+# - name: Build wheels
+# working-directory: ./python
+# env:
+# CIBW_BUILD: cp${{ matrix.python }}-${{ matrix.platform_id }}
+# CIBW_ARCHS: all
+# CIBW_REPAIR_WHEEL_COMMAND_WINDOWS: bash {project}/python/build_tools/github/repair_windows_wheels.sh {wheel} {dest_dir} ${{ matrix.bitness }}
+# CIBW_BEFORE_TEST_WINDOWS: bash {project}/python/build_tools/github/build_minimal_windows_image.sh ${{ matrix.python }} ${{ matrix.bitness }}
+# CIBW_TEST_COMMAND: bash {project}/python/build_tools/github/test_wheels.sh
+# CIBW_TEST_COMMAND_WINDOWS: bash {project}/python/build_tools/github/test_windows_wheels.sh ${{ matrix.python }} ${{ matrix.bitness }}
+# CIBW_BUILD_VERBOSITY: 1
+# run: |
+# python -m pip install -U wheel setuptools
+# python setup.py sdist bdist_wheel
+# run: python -m cibuildwheel --output-dir dist
+## run: cd {project}/python && python -m cibuildwheel --output-dir wheelhouse
+# # to supply options, put them in 'env', like (test)
+# # env:
+# # CIBW_SOME_OPTION: value
+#
+# - name: Upload source zip file
+# uses: actions/upload-artifact@v2
+# with:
+# path: python/dist/*.tar.gz
+# - name: Upload Wheels
+# uses: actions/upload-artifact@v2
+# with:
+# path: ./python/dist/*.whl
+
+# upload_pypi:
+# needs: [build_wheels]
+# runs-on: [self-hosted]
+# # upload to PyPI on every tag starting with 'v'
+## if: github.event_name == 'push' && contains(github.event.comment, 'release v')
+# # alternatively, to publish when a GitHub Release is created, use the following rule:
+# # if: github.event_name == 'release' && github.event.action == 'published'
+# steps:
+# - uses: actions/download-artifact@v2
+# with:
+# name: artifact
+# path: dist
+#
+# - uses: pypa/gh-action-pypi-publish@v1.4.2
+# with:
+# skip_existing: true
+# user: chaoyanghe
+# password: ${{ secrets.pypi_password }}
+# # To test: repository_url: https://test.pypi.org/legacy/
\ No newline at end of file
diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml
new file mode 100644
index 0000000000..83b30a6784
--- /dev/null
+++ b/.github/workflows/codeql-analysis.yml
@@ -0,0 +1,74 @@
+# For most projects, this workflow file will not need changing; you simply need
+# to commit it to your repository.
+#
+# You may wish to alter this file to override the set of languages analyzed,
+# or to provide custom queries or build logic.
+#
+# ******** NOTE ********
+# We have attempted to detect the languages in your repository. Please check
+# the `language` matrix defined below to confirm you have the correct set of
+# supported CodeQL languages.
+#
+name: "CodeQL"
+
+on:
+ push:
+ branches: [ "master" ]
+ pull_request:
+ # The branches below must be a subset of the branches above
+ branches: [ "master" ]
+ schedule:
+ - cron: '34 20 * * 4'
+
+jobs:
+ analyze:
+ name: Analyze
+ runs-on: ubuntu-latest
+ permissions:
+ actions: read
+ contents: read
+ security-events: write
+
+ strategy:
+ fail-fast: false
+ matrix:
+ language: [ 'python' ]
+ # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ]
+ # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support
+
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v3
+
+ # Initializes the CodeQL tools for scanning.
+ - name: Initialize CodeQL
+ uses: github/codeql-action/init@v2
+ with:
+ languages: ${{ matrix.language }}
+ # If you wish to specify custom queries, you can do so here or in a config file.
+ # By default, queries listed here will override any specified in a config file.
+ # Prefix the list here with "+" to use these queries and those in the config file.
+
+ # Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
+ # queries: security-extended,security-and-quality
+
+
+ # Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
+ # If this step fails, then you should remove it and run the build manually (see below)
+ - name: Autobuild
+ uses: github/codeql-action/autobuild@v2
+
+ # âšī¸ Command-line programs to run using the OS shell.
+ # đ See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun
+
+ # If the Autobuild fails above, remove it and uncomment the following three lines.
+ # modify them (or add more) to build your code if your project, please refer to the EXAMPLE below for guidance.
+
+ # - run: |
+ # echo "Run, Build Application using script"
+ # ./location_of_script_within_repo/buildscript.sh
+
+ - name: Perform CodeQL Analysis
+ uses: github/codeql-action/analyze@v2
+ with:
+ category: "/language:${{matrix.language}}"
diff --git a/.github/workflows/full_e2e_test.yml-bakcup b/.github/workflows/full_e2e_test.yml-bakcup
new file mode 100644
index 0000000000..4dd408fbd8
--- /dev/null
+++ b/.github/workflows/full_e2e_test.yml-bakcup
@@ -0,0 +1,104 @@
+# This is a basic workflow to help you get started with Actions
+
+name: Full End-to-end Test
+
+# Controls when the workflow will run
+on:
+ # Triggers the workflow on push or pull request events but only for the master branch
+ push:
+ branches: [ master, test/v0.7.0 ]
+ pull_request:
+ branches: [ master, test/v0.7.0 ]
+
+ # Allows you to run this workflow manually from the Actions tab
+ workflow_dispatch:
+
+# A workflow run is made up of one or more jobs that can run sequentially or in parallel
+jobs:
+ # This workflow contains a single job called "mlops-cli-test"
+ mlops-cli-test:
+ strategy:
+ matrix:
+ os: [ ubuntu-latest, windows-latest, macos-latest ]
+ arch: [X64, ARM64]
+ python-version: ['3.8']
+ # The type of runner that the job will run on
+ runs-on: [self-hosted, devops]
+ defaults:
+ run:
+ shell: bash
+ working-directory: python
+
+ # Steps represent a sequence of tasks that will be executed as part of the job
+ steps:
+ # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
+ - uses: actions/checkout@v3
+ - uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python-version }}
+ # - name: pip install -e ./
+ # run: |
+ # echo "install pip"
+ # sudo apt-get install python3-pip << eof
+ #
+ # y
+ #
+ # eof
+ #
+ # echo "pip install -e ./"
+ # pip install -e ./
+
+ - name: test sp_fedavg_mnist_lr_example
+ run: |
+ echo "this is for test sp_fedavg_mnist_lr_example"
+ cd examples/simulation/sp_fedavg_mnist_lr_example
+
+ python torch_fedavg_mnist_lr_step_by_step_example.py --cf fedml_config.yaml
+
+ - name: test sp_fedopt_mnist_lr_example
+ run: |
+ echo "this is for test sp_fedopt_mnist_lr_example"
+ cd examples/simulation/sp_fedopt_mnist_lr_example
+
+ python torch_fedopt_mnist_lr_step_by_step_example.py --cf fedml_config.yaml
+
+ - name: test sp_fednova_mnist_lr_example
+ run: |
+ echo "this is for test sp_fednova_mnist_lr_example"
+ cd examples/simulation/sp_fednova_mnist_lr_example
+
+ python torch_fednova_mnist_lr_step_by_step_example.py --cf fedml_config.yaml
+
+ - name: test sp_turboaggregate_mnist_lr_example
+ run: |
+ echo "this is for test sp_turboaggregate_mnist_lr_example"
+ cd examples/simulation/sp_turboaggregate_mnist_lr_example
+
+ python torch_turboaggregate_mnist_lr_step_by_step_example.py --cf fedml_config.yaml
+
+ - name: test sp_hierarchicalfl_mnist_lr_example
+ run: |
+ echo "this is for test sp_hierarchicalfl_mnist_lr_example"
+ cd examples/simulation/sp_hierarchicalfl_mnist_lr_example
+
+ python torch_hierarchicalfl_mnist_lr_step_by_step_example.py --cf fedml_config.yaml
+
+ - name: test sp_vertical_mnist_lr_example
+ run: |
+ echo "this is for test sp_vertical_mnist_lr_example"
+ cd examples/simulation/sp_vertical_mnist_lr_example
+
+ python torch_vertical_mnist_lr_step_by_step_example.py --cf fedml_config.yaml
+
+ - name: test sp_fedsgd_cifar10_resnet20_example
+ run: |
+ echo "this is for test sp_fedsgd_cifar10_resnet20_example"
+
+ cd examples/simulation/sp_fedsgd_cifar10_resnet20_example
+
+ python sp_fedsgd_cifar10_resnet20_example.py --cf eftopk_config.yaml
+
+ - name: test example B
+ run: |
+ echo "this is for test example B"
+ echo "second line of the script"
\ No newline at end of file
diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml
index a4f4bf6f58..88632bfafe 100644
--- a/.github/workflows/pylint.yml
+++ b/.github/workflows/pylint.yml
@@ -1,23 +1,39 @@
-name: Pylint
+name: Pylint - FedML
-on: [push]
+on:
+ pull_request:
+ branches: [ master, test/v0.7.0, dev/0.7.0 ]
+
+ # Allows you to run this workflow manually from the Actions tab
+ workflow_dispatch:
jobs:
build:
- runs-on: ubuntu-latest
+ defaults:
+ run:
+ shell: bash
+ working-directory: python
+ runs-on: [self-hosted, runner-linux, devops]
strategy:
matrix:
- python-version: ["3.7", "3.8", "3.9", "3.10"]
+ os: [ ubuntu-latest ]
+ arch: [ X64 ]
+ python-version: ["3.8"]
steps:
- - uses: actions/checkout@v3
- - name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v3
+ - uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- - name: Install dependencies
- run: |
- python -m pip install --upgrade pip
- pip install pylint
+ - uses: actions/checkout@v3
- name: Analysing the code with pylint
run: |
- pylint $(git ls-files '*.py')
+ python3 -m pip install --upgrade pip
+ pip install pylint
+ pip install "fedml[gRPC]"
+ pip install "fedml[tensorflow]"
+ pip install "fedml[jax]"
+ pip install "fedml[mxnet]"
+ pip install tensorflow_federated
+ pip install mxnet
+ pip install jax
+ pip install ptflops
+ pylint --rcfile=build_tools/lint/.pylintrc --disable=C,R,W,I ./
diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml
deleted file mode 100644
index ec703542be..0000000000
--- a/.github/workflows/python-publish.yml
+++ /dev/null
@@ -1,39 +0,0 @@
-# This workflow will upload a Python Package using Twine when a release is created
-# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
-
-# This workflow uses actions that are not certified by GitHub.
-# They are provided by a third-party and are governed by
-# separate terms of service, privacy policy, and support
-# documentation.
-
-name: Upload Python Package
-
-on:
- release:
- types: [published]
-
-permissions:
- contents: read
-
-jobs:
- deploy:
-
- runs-on: ubuntu-latest
-
- steps:
- - uses: actions/checkout@v3
- - name: Set up Python
- uses: actions/setup-python@v3
- with:
- python-version: '3.x'
- - name: Install dependencies
- run: |
- python -m pip install --upgrade pip
- pip install build
- - name: Build package
- run: python -m build
- - name: Publish package
- uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
- with:
- user: __token__
- password: ${{ secrets.PYPI_API_TOKEN }}
diff --git a/.github/workflows/runner.md b/.github/workflows/runner.md
new file mode 100644
index 0000000000..6c17e15ff0
--- /dev/null
+++ b/.github/workflows/runner.md
@@ -0,0 +1,71 @@
+# Install GitHub runner with your own computer:
+ssh -i amir-github-actions-key.cer ubuntu@54.183.200.162
+ssh -i amir-github-actions-key.cer ubuntu@52.53.164.162
+ssh -i github_actions.cer ubuntu@54.153.18.24
+
+sudo rpm -Uvh https://packages.microsoft.com/config/rhel/7/packages-microsoft-prod.rpm
+sudo apt-get update && sudo apt-get install -y dotnet6
+dotnet --version
+#install runner based on the following url: https://github.com/FedML-AI/FedML/settings/actions/runners/new?arch=x64&os=linux
+sudo ./svc.sh install
+sudo ./svc.sh start
+sudo ./svc.sh status
+
+# Install GitHub runner in Ubuntu from AWS:
+ssh -i "fedml-github-action.pem" ubuntu@ec2-54-176-61-229.us-west-1.compute.amazonaws.com
+ssh -i "fedml-github-action.pem" ubuntu@ec2-54-219-186-81.us-west-1.compute.amazonaws.com
+ssh -i "fedml-github-action.pem" ubuntu@ec2-54-219-187-134.us-west-1.compute.amazonaws.com
+
+sudo rpm -Uvh https://packages.microsoft.com/config/rhel/7/packages-microsoft-prod.rpm
+sudo apt-get update && sudo apt-get install -y dotnet6
+dotnet --version
+#install runner based on the following url: https://github.com/FedML-AI/FedML/settings/actions/runners/new?arch=x64&os=linux
+
+sudo ./svc.sh install
+sudo ./svc.sh start
+sudo ./svc.sh status
+
+
+# Install GitHub runner in Windows from AWS:
+1. You may connect to AWS Windows server by RDP client from MAC AppStore based on the url: https://docs.microsoft.com/en-us/windows-server/remote/remote-desktop-services/clients/remote-desktop-mac
+
+host: ec2-184-169-242-201.us-west-1.compute.amazonaws.com
+
+2. Enabling Windows Long Path on Windows based on the following url:
+ https://www.microfocus.com/documentation/filr/filr-4/filr-desktop/t47bx2ogpfz7.html
+
+3. install runner based on the following url: https://github.com/FedML-AI/FedML/settings/actions/runners/new?arch=x64&os=win
+
+# Runner List
+```
+# Windows:
+ec2-184-169-242-201.us-west-1.compute.amazonaws.com
+ec2-54-193-88-223.us-west-1.compute.amazonaws.com
+ec2-54-151-36-0.us-west-1.compute.amazonaws.com
+
+# Linux:
+ec2-54-176-61-229.us-west-1.compute.amazonaws.com
+ec2-54-219-186-81.us-west-1.compute.amazonaws.com
+ec2-54-219-187-134.us-west-1.compute.amazonaws.com
+ec2-13-57-8-59.us-west-1.compute.amazonaws.com
+ec2-3-101-104-5.us-west-1.compute.amazonaws.com
+ec2-13-57-240-161.us-west-1.compute.amazonaws.com
+ec2-3-101-61-77.us-west-1.compute.amazonaws.com
+
+ec2-54-215-107-43.us-west-1.compute.amazonaws.com
+ec2-13-56-228-205.us-west-1.compute.amazonaws.com
+ec2-13-57-49-67.us-west-1.compute.amazonaws.com
+ec2-18-144-32-82.us-west-1.compute.amazonaws.com
+```
+
+```
+# useful commands
+sudo apt update
+sudo apt install libopenmpi-dev openmpi-bin
+sudo apt install python3
+sudo apt install python-is-python3
+sudo apt install pip
+pip install -U fedml
+pip install mpi4py
+nohup bash run.sh > action.log 2>&1 &
+```
\ No newline at end of file
diff --git a/.github/workflows/smoke_test_cross_device_mnn_server_linux.yml b/.github/workflows/smoke_test_cross_device_mnn_server_linux.yml
new file mode 100644
index 0000000000..dd98309fa1
--- /dev/null
+++ b/.github/workflows/smoke_test_cross_device_mnn_server_linux.yml
@@ -0,0 +1,51 @@
+# This is a basic workflow to help you get started with Actions
+
+name: CROSS-DEVICE-MNN-Linux
+
+# Controls when the workflow will run
+on:
+ # Triggers the workflow on push or pull request events but only for the master branch
+ schedule:
+ # Nightly build at 12:12 A.M.
+ - cron: "12 12 */1 * *"
+ pull_request:
+ branches: [ master, test/v0.7.0 ]
+ types: [opened, reopened]
+
+ # Allows you to run this workflow manually from the Actions tab
+ workflow_dispatch:
+
+# A workflow run is made up of one or more jobs that can run sequentially or in parallel
+jobs:
+ cross-device-mnn-server:
+ defaults:
+ run:
+ shell: bash
+ working-directory: python
+ strategy:
+ fail-fast: false
+ matrix:
+ os: [ ubuntu-latest ]
+ arch: [X64]
+ python-version: ['3.8']
+# exclude:
+# - os: macos-latest
+# python-version: '3.8'
+# - os: windows-latest
+# python-version: '3.6'
+ runs-on: [self-hosted, runner-linux, devops, mnn]
+ timeout-minutes: 15
+ steps:
+ - uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python-version }}
+ - uses: actions/checkout@v3
+ - name: pip install -e ./
+ run: |
+ pip install -e ./
+
+ - name: test server of cross-device
+ run: |
+ cd quick_start/beehive
+ timeout 60 bash run_server.sh || code=$?; if [[ $code -ne 124 && $code -ne 0 ]]; then exit $code; fi
+
diff --git a/.github/workflows/smoke_test_cross_silo_fedavg_attack_linux.yml b/.github/workflows/smoke_test_cross_silo_fedavg_attack_linux.yml
new file mode 100644
index 0000000000..c82d2ef8cd
--- /dev/null
+++ b/.github/workflows/smoke_test_cross_silo_fedavg_attack_linux.yml
@@ -0,0 +1,86 @@
+# This is a basic workflow to help you get started with Actions
+
+name: Attacker-Linux
+
+# Controls when the workflow will run
+on:
+ # Triggers the workflow on push or pull request events but only for the master branch
+ schedule:
+ # Nightly build at 12:12 A.M.
+ - cron: "12 12 */1 * *"
+ pull_request:
+ branches: [ master, test/v0.7.0 ]
+ types: [opened, reopened]
+
+ # Allows you to run this workflow manually from the Actions tab
+ workflow_dispatch:
+
+# A workflow run is made up of one or more jobs that can run sequentially or in parallel
+jobs:
+ cross-silo-attack-test:
+ defaults:
+ run:
+ shell: bash
+ working-directory: python
+ strategy:
+ fail-fast: false
+ matrix:
+ os: [ ubuntu-latest]
+ arch: [X64]
+ python-version: ['3.8']
+ client-index: ['0', '1', '2', '3', '4']
+# exclude:
+# - os: macos-latest
+# python-version: '3.8'
+# - os: windows-latest
+# python-version: '3.6'
+ runs-on: [self-hosted, runner-linux, devops]
+ timeout-minutes: 15
+ steps:
+ - uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python-version }}
+ - uses: actions/checkout@v3
+ - name: pip install -e ./
+ run: |
+ pip install -e ./
+
+ - name: server - cross-silo - attack
+ run: |
+ cd examples/cross_silo/mqtt_s3_fedavg_attack_mnist_lr_example
+ run_id=cross-silo-attack-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ echo ${run_id}
+ bash run_server.sh $run_id
+ if: ${{ matrix.client-index == '0' }}
+
+ - name: client 1 - cross-silo - attack
+ run: |
+ cd examples/cross_silo/mqtt_s3_fedavg_attack_mnist_lr_example
+ run_id=cross-silo-attack-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ echo ${run_id}
+ bash run_client.sh 1 $run_id
+ if: ${{ matrix.client-index == '1' }}
+
+ - name: client 2 - cross-silo - attack
+ run: |
+ cd examples/cross_silo/mqtt_s3_fedavg_attack_mnist_lr_example
+ run_id=cross-silo-attack-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ echo ${run_id}
+ bash run_client.sh 2 $run_id
+ if: ${{ matrix.client-index == '2' }}
+
+ - name: client 3 - cross-silo - attack
+ run: |
+ cd examples/cross_silo/mqtt_s3_fedavg_attack_mnist_lr_example
+ run_id=cross-silo-attack-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ echo ${run_id}
+ bash run_client.sh 3 $run_id
+ if: ${{ matrix.client-index == '3' }}
+
+ - name: client 4 - cross-silo - attack
+ run: |
+ cd examples/cross_silo/mqtt_s3_fedavg_attack_mnist_lr_example
+ run_id=cross-silo-attack-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ echo ${run_id}
+ bash run_client.sh 4 $run_id
+ if: ${{ matrix.client-index == '4' }}
\ No newline at end of file
diff --git a/.github/workflows/smoke_test_cross_silo_fedavg_cdp_linux.yml b/.github/workflows/smoke_test_cross_silo_fedavg_cdp_linux.yml
new file mode 100644
index 0000000000..3dd3fa6a59
--- /dev/null
+++ b/.github/workflows/smoke_test_cross_silo_fedavg_cdp_linux.yml
@@ -0,0 +1,70 @@
+# This is a basic workflow to help you get started with Actions
+
+name: CDP-Linux
+
+# Controls when the workflow will run
+on:
+ # Triggers the workflow on push or pull request events but only for the master branch
+ schedule:
+ # Nightly build at 12:12 A.M.
+ - cron: "12 12 */1 * *"
+ pull_request:
+ branches: [ master, test/v0.7.0 ]
+ types: [opened, reopened]
+
+ # Allows you to run this workflow manually from the Actions tab
+ workflow_dispatch:
+
+# A workflow run is made up of one or more jobs that can run sequentially or in parallel
+jobs:
+ cross-silo-cdp-test:
+ defaults:
+ run:
+ shell: bash
+ working-directory: python
+ strategy:
+ fail-fast: false
+ matrix:
+ os: [ ubuntu-latest]
+ arch: [X64]
+ python-version: ['3.8']
+ client-index: ['0', '1', '2']
+# exclude:
+# - os: macos-latest
+# python-version: '3.8'
+# - os: windows-latest
+# python-version: '3.6'
+ runs-on: [self-hosted, runner-linux, devops]
+ timeout-minutes: 15
+ steps:
+ - uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python-version }}
+ - uses: actions/checkout@v3
+ - name: pip install -e ./
+ run: |
+ pip install -e ./
+
+ - name: server - cross-silo - cdp
+ run: |
+ cd examples/cross_silo/mqtt_s3_fedavg_cdp_mnist_lr_example
+ run_id=cross-silo-ho-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ echo ${run_id}
+ bash run_server.sh $run_id
+ if: ${{ matrix.client-index == '0' }}
+
+ - name: client 1 - cross-silo - cdp
+ run: |
+ cd examples/cross_silo/mqtt_s3_fedavg_cdp_mnist_lr_example
+ run_id=cross-silo-ho-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ echo ${run_id}
+ bash run_client.sh 1 $run_id
+ if: ${{ matrix.client-index == '1' }}
+
+ - name: client 2 - cross-silo - cdp
+ run: |
+ cd examples/cross_silo/mqtt_s3_fedavg_cdp_mnist_lr_example
+ run_id=cross-silo-ho-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ echo ${run_id}
+ bash run_client.sh 2 $run_id
+ if: ${{ matrix.client-index == '2' }}
\ No newline at end of file
diff --git a/.github/workflows/smoke_test_cross_silo_fedavg_defense_linux.yml b/.github/workflows/smoke_test_cross_silo_fedavg_defense_linux.yml
new file mode 100644
index 0000000000..d4d1dae9c2
--- /dev/null
+++ b/.github/workflows/smoke_test_cross_silo_fedavg_defense_linux.yml
@@ -0,0 +1,87 @@
+# This is a basic workflow to help you get started with Actions
+
+name: Defender-Linux
+
+# Controls when the workflow will run
+on:
+ # Triggers the workflow on push or pull request events but only for the master branch
+ schedule:
+ # Nightly build at 12:12 A.M.
+ - cron: "12 12 */1 * *"
+ pull_request:
+ branches: [ master, test/v0.7.0 ]
+ types: [opened, reopened]
+
+ # Allows you to run this workflow manually from the Actions tab
+ workflow_dispatch:
+
+# A workflow run is made up of one or more jobs that can run sequentially or in parallel
+jobs:
+ cross-silo-defense-test:
+ defaults:
+ run:
+ shell: bash
+ working-directory: python
+ strategy:
+ fail-fast: false
+ matrix:
+ os: [ ubuntu-latest]
+ arch: [X64]
+ python-version: ['3.8']
+ client-index: ['0', '1', '2', '3', '4']
+# exclude:
+# - os: macos-latest
+# python-version: '3.8'
+# - os: windows-latest
+# python-version: '3.6'
+ runs-on: [self-hosted, runner-linux, devops]
+ timeout-minutes: 15
+ steps:
+ - uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python-version }}
+ - uses: actions/checkout@v3
+ - name: pip install -e ./
+ run: |
+ pip install -e ./
+ pip install sklearn
+
+ - name: server - cross-silo - defense
+ run: |
+ cd examples/cross_silo/mqtt_s3_fedavg_defense_mnist_lr_example
+ run_id=cross-silo-defense-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ echo ${run_id}
+ bash run_server.sh $run_id
+ if: ${{ matrix.client-index == '0' }}
+
+ - name: client 1 - cross-silo - defense
+ run: |
+ cd examples/cross_silo/mqtt_s3_fedavg_defense_mnist_lr_example
+ run_id=cross-silo-defense-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ echo ${run_id}
+ bash run_client.sh 1 $run_id
+ if: ${{ matrix.client-index == '1' }}
+
+ - name: client 2 - cross-silo - defense
+ run: |
+ cd examples/cross_silo/mqtt_s3_fedavg_defense_mnist_lr_example
+ run_id=cross-silo-defense-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ echo ${run_id}
+ bash run_client.sh 2 $run_id
+ if: ${{ matrix.client-index == '2' }}
+
+ - name: client 3 - cross-silo - defense
+ run: |
+ cd examples/cross_silo/mqtt_s3_fedavg_defense_mnist_lr_example
+ run_id=cross-silo-defense-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ echo ${run_id}
+ bash run_client.sh 3 $run_id
+ if: ${{ matrix.client-index == '3' }}
+
+ - name: client 4 - cross-silo - defense
+ run: |
+ cd examples/cross_silo/mqtt_s3_fedavg_defense_mnist_lr_example
+ run_id=cross-silo-defense-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ echo ${run_id}
+ bash run_client.sh 4 $run_id
+ if: ${{ matrix.client-index == '4' }}
diff --git a/.github/workflows/smoke_test_cross_silo_fedavg_ldp_linux.yml b/.github/workflows/smoke_test_cross_silo_fedavg_ldp_linux.yml
new file mode 100644
index 0000000000..1e48b262bb
--- /dev/null
+++ b/.github/workflows/smoke_test_cross_silo_fedavg_ldp_linux.yml
@@ -0,0 +1,70 @@
+# This is a basic workflow to help you get started with Actions
+
+name: LDP-Linux
+
+# Controls when the workflow will run
+on:
+ # Triggers the workflow on push or pull request events but only for the master branch
+ schedule:
+ # Nightly build at 12:12 A.M.
+ - cron: "12 12 */1 * *"
+ pull_request:
+ branches: [ master, test/v0.7.0 ]
+ types: [opened, reopened]
+
+ # Allows you to run this workflow manually from the Actions tab
+ workflow_dispatch:
+
+# A workflow run is made up of one or more jobs that can run sequentially or in parallel
+jobs:
+ cross-silo-ldp-test:
+ defaults:
+ run:
+ shell: bash
+ working-directory: python
+ strategy:
+ fail-fast: false
+ matrix:
+ os: [ ubuntu-latest]
+ arch: [X64]
+ python-version: ['3.8']
+ client-index: ['0', '1', '2']
+# exclude:
+# - os: macos-latest
+# python-version: '3.8'
+# - os: windows-latest
+# python-version: '3.6'
+ runs-on: [self-hosted, runner-linux, devops]
+ timeout-minutes: 15
+ steps:
+ - uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python-version }}
+ - uses: actions/checkout@v3
+ - name: pip install -e ./
+ run: |
+ pip install -e ./
+
+ - name: server - cross-silo - ldp
+ run: |
+ cd examples/cross_silo/mqtt_s3_fedavg_ldp_mnist_lr_example
+ run_id=cross-silo-ho-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ echo ${run_id}
+ bash run_server.sh $run_id
+ if: ${{ matrix.client-index == '0' }}
+
+ - name: client 1 - cross-silo - ldp
+ run: |
+ cd examples/cross_silo/mqtt_s3_fedavg_ldp_mnist_lr_example
+ run_id=cross-silo-ho-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ echo ${run_id}
+ bash run_client.sh 1 $run_id
+ if: ${{ matrix.client-index == '1' }}
+
+ - name: client 2 - cross-silo - ldp
+ run: |
+ cd examples/cross_silo/mqtt_s3_fedavg_ldp_mnist_lr_example
+ run_id=cross-silo-ho-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ echo ${run_id}
+ bash run_client.sh 2 $run_id
+ if: ${{ matrix.client-index == '2' }}
\ No newline at end of file
diff --git a/.github/workflows/smoke_test_cross_silo_ho_linux.yml b/.github/workflows/smoke_test_cross_silo_ho_linux.yml
new file mode 100644
index 0000000000..4c3a7ce5ff
--- /dev/null
+++ b/.github/workflows/smoke_test_cross_silo_ho_linux.yml
@@ -0,0 +1,70 @@
+# This is a basic workflow to help you get started with Actions
+
+name: CROSS-SILO-HO-Linux
+
+# Controls when the workflow will run
+on:
+ # Triggers the workflow on push or pull request events but only for the master branch
+ schedule:
+ # Nightly build at 12:12 A.M.
+ - cron: "12 12 */1 * *"
+ pull_request:
+ branches: [ master, test/v0.7.0 ]
+ types: [opened, reopened]
+
+ # Allows you to run this workflow manually from the Actions tab
+ workflow_dispatch:
+
+# A workflow run is made up of one or more jobs that can run sequentially or in parallel
+jobs:
+ cross-silo-horizontal-test:
+ defaults:
+ run:
+ shell: bash
+ working-directory: python
+ strategy:
+ fail-fast: false
+ matrix:
+ os: [ ubuntu-latest]
+ arch: [X64]
+ python-version: ['3.8']
+ client-index: ['0', '1', '2']
+# exclude:
+# - os: macos-latest
+# python-version: '3.8'
+# - os: windows-latest
+# python-version: '3.6'
+ runs-on: [self-hosted, runner-linux, devops]
+ timeout-minutes: 15
+ steps:
+ - uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python-version }}
+ - uses: actions/checkout@v3
+ - name: pip install -e ./
+ run: |
+ pip install -e ./
+
+ - name: server - cross-silo - ho
+ run: |
+ cd quick_start/octopus
+ run_id=cross-silo-ho-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ echo ${run_id}
+ bash run_server.sh $run_id
+ if: ${{ matrix.client-index == '0' }}
+
+ - name: client 1 - cross-silo - ho
+ run: |
+ cd quick_start/octopus
+ run_id=cross-silo-ho-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ echo ${run_id}
+ bash run_client.sh 1 $run_id
+ if: ${{ matrix.client-index == '1' }}
+
+ - name: client 2 - cross-silo - ho
+ run: |
+ cd quick_start/octopus
+ run_id=cross-silo-ho-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ echo ${run_id}
+ bash run_client.sh 2 $run_id
+ if: ${{ matrix.client-index == '2' }}
\ No newline at end of file
diff --git a/.github/workflows/smoke_test_cross_silo_ho_win.yml b/.github/workflows/smoke_test_cross_silo_ho_win.yml
new file mode 100644
index 0000000000..1c6f59f0e9
--- /dev/null
+++ b/.github/workflows/smoke_test_cross_silo_ho_win.yml
@@ -0,0 +1,71 @@
+# This is a basic workflow to help you get started with Actions
+
+name: CROSS-SILO-HO-Win
+
+# Controls when the workflow will run
+on:
+ # Triggers the workflow on push or pull request events but only for the master branch
+ schedule:
+ # Nightly build at 12:12 A.M.
+ - cron: "12 12 */1 * *"
+ pull_request:
+ branches: [ master, test/v0.7.0 ]
+ types: [opened, reopened]
+
+ # Allows you to run this workflow manually from the Actions tab
+ workflow_dispatch:
+
+# A workflow run is made up of one or more jobs that can run sequentially or in parallel
+jobs:
+ cross-silo-horizontal-test:
+ defaults:
+ run:
+ shell: powershell
+ working-directory: fedml-devops\python
+ strategy:
+ fail-fast: false
+ matrix:
+ os: [ windows-2019 ]
+ arch: [X64]
+ python-version: ['3.8']
+ client-index: ['0', '1', '2']
+# exclude:
+# - os: macos-latest
+# python-version: '3.8'
+# - os: windows-latest
+# python-version: '3.6'
+ runs-on: [self-hosted, runner-windows, devops]
+ timeout-minutes: 15
+ steps:
+ - name: cleanup running processes
+ continue-on-error: true
+ run: |
+ wmic.exe /interactive:off process where "name='python.exe'" call terminate
+ - uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python-version }}
+ - uses: actions/checkout@v3
+ with:
+ path: fedml-devops
+ clean: true
+ - name: pip install -e ./
+ run: |
+ pip install -e ./
+
+ - name: server - cross-silo - ho
+ run: |
+ cd quick_start/octopus
+ .\run_server.bat ${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ if: ${{ matrix.client-index == '0' }}
+
+ - name: client 1 - cross-silo - ho
+ run: |
+ cd quick_start/octopus
+ .\run_client.bat 1 ${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ if: ${{ matrix.client-index == '1' }}
+
+ - name: client 2 - cross-silo - ho
+ run: |
+ cd quick_start/octopus
+ .\run_client.bat 2 ${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ if: ${{ matrix.client-index == '2' }}
\ No newline at end of file
diff --git a/.github/workflows/smoke_test_cross_silo_lightsecagg_linux.yml b/.github/workflows/smoke_test_cross_silo_lightsecagg_linux.yml
new file mode 100644
index 0000000000..ace40806f8
--- /dev/null
+++ b/.github/workflows/smoke_test_cross_silo_lightsecagg_linux.yml
@@ -0,0 +1,70 @@
+# This is a basic workflow to help you get started with Actions
+
+name: LightSecAgg-Linux
+
+# Controls when the workflow will run
+on:
+ # Triggers the workflow on push or pull request events but only for the master branch
+ schedule:
+ # Nightly build at 12:12 A.M.
+ - cron: "12 12 */1 * *"
+ pull_request:
+ branches: [ master, test/v0.7.0 ]
+ types: [opened, reopened]
+
+ # Allows you to run this workflow manually from the Actions tab
+ workflow_dispatch:
+
+# A workflow run is made up of one or more jobs that can run sequentially or in parallel
+jobs:
+ cross-silo-horizontal-test:
+ defaults:
+ run:
+ shell: bash
+ working-directory: python
+ strategy:
+ fail-fast: false
+ matrix:
+ os: [ ubuntu-latest]
+ arch: [X64]
+ python-version: ['3.8']
+ client-index: ['0', '1', '2']
+# exclude:
+# - os: macos-latest
+# python-version: '3.8'
+# - os: windows-latest
+# python-version: '3.6'
+ runs-on: [self-hosted, runner-linux, devops]
+ timeout-minutes: 15
+ steps:
+ - uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python-version }}
+ - uses: actions/checkout@v3
+ - name: pip install -e ./
+ run: |
+ pip install -e ./
+
+ - name: server - cross-silo - lightsecagg
+ run: |
+ cd examples/cross_silo/light_sec_agg_example
+ run_id=cross-silo-lightsecagg-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ echo ${run_id}
+ bash run_server.sh $run_id
+ if: ${{ matrix.client-index == '0' }}
+
+ - name: client 1 - cross-silo - lightsecagg
+ run: |
+ cd examples/cross_silo/light_sec_agg_example
+ run_id=cross-silo-lightsecagg-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ echo ${run_id}
+ bash run_client.sh 1 $run_id
+ if: ${{ matrix.client-index == '1' }}
+
+ - name: client 2 - cross-silo - lightsecagg
+ run: |
+ cd examples/cross_silo/light_sec_agg_example
+ run_id=cross-silo-lightsecagg-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ echo ${run_id}
+ bash run_client.sh 2 $run_id
+ if: ${{ matrix.client-index == '2' }}
\ No newline at end of file
diff --git a/.github/workflows/smoke_test_cross_silo_lightsecagg_win.yml b/.github/workflows/smoke_test_cross_silo_lightsecagg_win.yml
new file mode 100644
index 0000000000..e807294811
--- /dev/null
+++ b/.github/workflows/smoke_test_cross_silo_lightsecagg_win.yml
@@ -0,0 +1,71 @@
+# This is a basic workflow to help you get started with Actions
+
+name: LightSecAgg-Windows
+
+# Controls when the workflow will run
+on:
+ # Triggers the workflow on push or pull request events but only for the master branch
+ schedule:
+ # Nightly build at 12:12 A.M.
+ - cron: "12 12 */1 * *"
+ pull_request:
+ branches: [ master, test/v0.7.0 ]
+ types: [opened, reopened]
+
+ # Allows you to run this workflow manually from the Actions tab
+ workflow_dispatch:
+
+# A workflow run is made up of one or more jobs that can run sequentially or in parallel
+jobs:
+ cross-silo-horizontal-test:
+ defaults:
+ run:
+ shell: powershell
+ working-directory: fedml-devops\python
+ strategy:
+ fail-fast: false
+ matrix:
+ os: [ windows-2019 ]
+ arch: [X64]
+ python-version: ['3.8']
+ client-index: ['0', '1', '2']
+# exclude:
+# - os: macos-latest
+# python-version: '3.8'
+# - os: windows-latest
+# python-version: '3.6'
+ runs-on: [self-hosted, runner-windows, devops]
+ timeout-minutes: 15
+ steps:
+ - name: cleanup running processes
+ continue-on-error: true
+ run: |
+ wmic.exe /interactive:off process where "name='python.exe'" call terminate
+ - uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python-version }}
+ - uses: actions/checkout@v3
+ with:
+ path: fedml-devops
+ clean: true
+ - name: pip install -e ./
+ run: |
+ pip install -e ./
+
+ - name: server - cross-silo - ho
+ run: |
+ cd examples/cross_silo/light_sec_agg_example
+ .\run_server.bat cross-silo-lightsecagg-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ if: ${{ matrix.client-index == '0' }}
+
+ - name: client 1 - cross-silo - ho
+ run: |
+ cd examples/cross_silo/light_sec_agg_example
+ .\run_client.bat 1 cross-silo-lightsecagg-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ if: ${{ matrix.client-index == '1' }}
+
+ - name: client 2 - cross-silo - lightsecagg
+ run: |
+ cd examples/cross_silo/light_sec_agg_example
+ .\run_client.bat 2 cross-silo-lightsecagg-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ if: ${{ matrix.client-index == '2' }}
\ No newline at end of file
diff --git a/.github/workflows/smoke_test_flow_linux.yml b/.github/workflows/smoke_test_flow_linux.yml
new file mode 100644
index 0000000000..5b03a511f0
--- /dev/null
+++ b/.github/workflows/smoke_test_flow_linux.yml
@@ -0,0 +1,70 @@
+# This is a basic workflow to help you get started with Actions
+
+name: Flow-Linux
+
+# Controls when the workflow will run
+on:
+ # Triggers the workflow on push or pull request events but only for the master branch
+ schedule:
+ # Nightly build at 12:12 A.M.
+ - cron: "12 12 */1 * *"
+ pull_request:
+ branches: [ master, test/v0.7.0 ]
+ types: [opened, reopened]
+
+ # Allows you to run this workflow manually from the Actions tab
+ workflow_dispatch:
+
+# A workflow run is made up of one or more jobs that can run sequentially or in parallel
+jobs:
+ Flow-test:
+ defaults:
+ run:
+ shell: bash
+ working-directory: python
+ strategy:
+ fail-fast: false
+ matrix:
+ os: [ ubuntu-latest]
+ arch: [X64]
+ python-version: ['3.8']
+ client-index: ['0', '1', '2']
+# exclude:
+# - os: macos-latest
+# python-version: '3.8'
+# - os: windows-latest
+# python-version: '3.6'
+ runs-on: [self-hosted, runner-linux, devops]
+ timeout-minutes: 15
+ steps:
+ - uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python-version }}
+ - uses: actions/checkout@v3
+ - name: pip install -e ./
+ run: |
+ pip install -e ./
+
+ - name: server - Flow
+ run: |
+ cd fedml/core/distributed/flow
+ run_id=cross-silo-ho-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ echo ${run_id}
+ bash test_run_server.sh $run_id
+ if: ${{ matrix.client-index == '0' }}
+
+ - name: client 1 - Flow
+ run: |
+ cd fedml/core/distributed/flow
+ run_id=cross-silo-ho-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ echo ${run_id}
+ bash test_run_client.sh 1 $run_id
+ if: ${{ matrix.client-index == '1' }}
+
+ - name: client 2 - Flow
+ run: |
+ cd fedml/core/distributed/flow
+ run_id=cross-silo-ho-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ echo ${run_id}
+ bash test_run_client.sh 2 $run_id
+ if: ${{ matrix.client-index == '2' }}
\ No newline at end of file
diff --git a/.github/workflows/smoke_test_ml_engines_linux_jax.yml b/.github/workflows/smoke_test_ml_engines_linux_jax.yml
new file mode 100644
index 0000000000..976381e1d5
--- /dev/null
+++ b/.github/workflows/smoke_test_ml_engines_linux_jax.yml
@@ -0,0 +1,71 @@
+# This is a basic workflow to help you get started with Actions
+
+name: ML-Engines-Linux
+
+# Controls when the workflow will run
+on:
+ # Triggers the workflow on push or pull request events but only for the master branch
+ schedule:
+ # Nightly build at 12:12 A.M.
+ - cron: "12 12 */1 * *"
+ pull_request:
+ branches: [ master, test/v0.7.0 ]
+ types: [opened, reopened]
+
+ # Allows you to run this workflow manually from the Actions tab
+ workflow_dispatch:
+
+# A workflow run is made up of one or more jobs that can run sequentially or in parallel
+jobs:
+ jax-ml-engines-test:
+ defaults:
+ run:
+ shell: bash
+ working-directory: python
+ strategy:
+ fail-fast: false
+ matrix:
+ os: [ ubuntu-latest ]
+ arch: [ X64 ]
+ python-version: [ '3.8' ]
+ client-index: [ '0', '1', '2' ]
+ # exclude:
+ # - os: macos-latest
+ # python-version: '3.8'
+ # - os: windows-latest
+ # python-version: '3.6'
+ runs-on: [ self-hosted, runner-linux, devops ]
+ timeout-minutes: 15
+ steps:
+ - uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python-version }}
+ - uses: actions/checkout@v3
+ - name: pip install -e ./
+ run: |
+ pip install -e ./
+ pip install -e '.[jax]'
+
+ - name: server - jax - fedavg
+ run: |
+ cd examples/cross_silo/jax_haiku_mqtt_s3_fedavg_mnist_lr_example
+ run_id=jax-ml-engine-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ echo ${run_id}
+ bash run_server.sh $run_id
+ if: ${{ matrix.client-index == '0' }}
+
+ - name: client 1 - jax - fedavg
+ run: |
+ cd examples/cross_silo/jax_haiku_mqtt_s3_fedavg_mnist_lr_example
+ run_id=jax-ml-engine-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ echo ${run_id}
+ bash run_client.sh 1 $run_id
+ if: ${{ matrix.client-index == '1' }}
+
+ - name: client 2 - jax - fedavg
+ run: |
+ cd examples/cross_silo/jax_haiku_mqtt_s3_fedavg_mnist_lr_example
+ run_id=jax-ml-engine-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ echo ${run_id}
+ bash run_client.sh 2 $run_id
+ if: ${{ matrix.client-index == '2' }}
diff --git a/.github/workflows/smoke_test_ml_engines_linux_maxnet.yml b/.github/workflows/smoke_test_ml_engines_linux_maxnet.yml
new file mode 100644
index 0000000000..e13eb8a92f
--- /dev/null
+++ b/.github/workflows/smoke_test_ml_engines_linux_maxnet.yml
@@ -0,0 +1,71 @@
+# This is a basic workflow to help you get started with Actions
+
+name: ML-Engines-Linux
+
+# Controls when the workflow will run
+on:
+ # Triggers the workflow on push or pull request events but only for the master branch
+ schedule:
+ # Nightly build at 12:12 A.M.
+ - cron: "12 12 */1 * *"
+ pull_request:
+ branches: [ master, test/v0.7.0 ]
+ types: [opened, reopened]
+
+ # Allows you to run this workflow manually from the Actions tab
+ workflow_dispatch:
+
+# A workflow run is made up of one or more jobs that can run sequentially or in parallel
+jobs:
+ mxnet-ml-engines-test:
+ defaults:
+ run:
+ shell: bash
+ working-directory: python
+ strategy:
+ fail-fast: false
+ matrix:
+ os: [ ubuntu-latest ]
+ arch: [ X64 ]
+ python-version: [ '3.8' ]
+ client-index: [ '0', '1', '2' ]
+ # exclude:
+ # - os: macos-latest
+ # python-version: '3.8'
+ # - os: windows-latest
+ # python-version: '3.6'
+ runs-on: [ self-hosted, runner-linux, devops ]
+ timeout-minutes: 15
+ steps:
+ - uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python-version }}
+ - uses: actions/checkout@v3
+ - name: pip install -e ./
+ run: |
+ pip install -e ./
+ pip install -e '.[mxnet]'
+
+ - name: server - mxnet - fedavg
+ run: |
+ cd examples/cross_silo/mxnet_mqtt_s3_fedavg_mnist_lr_example
+ run_id=mxnet-ml-engine-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ echo ${run_id}
+ bash run_server.sh $run_id
+ if: ${{ matrix.client-index == '0' }}
+
+ - name: client 1 - mxnet - fedavg
+ run: |
+ cd examples/cross_silo/mxnet_mqtt_s3_fedavg_mnist_lr_example
+ run_id=mxnet-ml-engine-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ echo ${run_id}
+ bash run_client.sh 1 $run_id
+ if: ${{ matrix.client-index == '1' }}
+
+ - name: client 2 - mxnet - fedavg
+ run: |
+ cd examples/cross_silo/mxnet_mqtt_s3_fedavg_mnist_lr_example
+ run_id=mxnet-ml-engine-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ echo ${run_id}
+ bash run_client.sh 2 $run_id
+ if: ${{ matrix.client-index == '2' }}
diff --git a/.github/workflows/smoke_test_ml_engines_linux_tf.yml b/.github/workflows/smoke_test_ml_engines_linux_tf.yml
new file mode 100644
index 0000000000..2fa0a0f3d8
--- /dev/null
+++ b/.github/workflows/smoke_test_ml_engines_linux_tf.yml
@@ -0,0 +1,71 @@
+# This is a basic workflow to help you get started with Actions
+
+name: ML-Engines-Linux
+
+# Controls when the workflow will run
+on:
+ # Triggers the workflow on push or pull request events but only for the master branch
+ schedule:
+ # Nightly build at 12:12 A.M.
+ - cron: "12 12 */1 * *"
+ pull_request:
+ branches: [ master, test/v0.7.0 ]
+ types: [opened, reopened]
+
+ # Allows you to run this workflow manually from the Actions tab
+ workflow_dispatch:
+
+# A workflow run is made up of one or more jobs that can run sequentially or in parallel
+jobs:
+ tf-ml-engines-test:
+ defaults:
+ run:
+ shell: bash
+ working-directory: python
+ strategy:
+ fail-fast: false
+ matrix:
+ os: [ ubuntu-latest]
+ arch: [X64]
+ python-version: ['3.8']
+ client-index: ['0', '1', '2']
+# exclude:
+# - os: macos-latest
+# python-version: '3.8'
+# - os: windows-latest
+# python-version: '3.6'
+ runs-on: [self-hosted, runner-linux, devops]
+ timeout-minutes: 15
+ steps:
+ - uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python-version }}
+ - uses: actions/checkout@v3
+ - name: pip install -e ./
+ run: |
+ pip install -e ./
+ pip install -e '.[tensorflow]'
+
+ - name: server - tensorflow - fedavg
+ run: |
+ cd examples/cross_silo/tf_mqtt_s3_fedavg_mnist_lr_example
+ run_id=tf-ml-engine-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ echo ${run_id}
+ bash run_server.sh $run_id
+ if: ${{ matrix.client-index == '0' }}
+
+ - name: client 1 - tensorflow - fedavg
+ run: |
+ cd examples/cross_silo/tf_mqtt_s3_fedavg_mnist_lr_example
+ run_id=tf-ml-engine-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ echo ${run_id}
+ bash run_client.sh 1 $run_id
+ if: ${{ matrix.client-index == '1' }}
+
+ - name: client 2 - tensorflow - fedavg
+ run: |
+ cd examples/cross_silo/tf_mqtt_s3_fedavg_mnist_lr_example
+ run_id=tf-ml-engine-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ echo ${run_id}
+ bash run_client.sh 2 $run_id
+ if: ${{ matrix.client-index == '2' }}
diff --git a/.github/workflows/smoke_test_ml_engines_win.yml b/.github/workflows/smoke_test_ml_engines_win.yml
new file mode 100644
index 0000000000..169b4ac9c6
--- /dev/null
+++ b/.github/workflows/smoke_test_ml_engines_win.yml
@@ -0,0 +1,162 @@
+# This is a basic workflow to help you get started with Actions
+
+name: ML-Engines-Win
+
+# Controls when the workflow will run
+on:
+ # Triggers the workflow on push or pull request events but only for the master branch
+ schedule:
+ # Nightly build at 12:12 A.M.
+ - cron: "12 12 */1 * *"
+ pull_request:
+ branches: [ master, test/v0.7.0 ]
+ types: [opened, reopened]
+
+ # Allows you to run this workflow manually from the Actions tab
+ workflow_dispatch:
+
+# A workflow run is made up of one or more jobs that can run sequentially or in parallel
+jobs:
+ tf-ml-engines-test:
+ defaults:
+ run:
+ shell: powershell
+ working-directory: fedml-devops\python
+ strategy:
+ fail-fast: false
+ matrix:
+ os: [ windows-2019 ]
+ arch: [ X64 ]
+ python-version: [ '3.8' ]
+ runs-on: [self-hosted, runner-windows, devops]
+ timeout-minutes: 15
+ steps:
+ - name: cleanup running processes
+ continue-on-error: true
+ run: |
+ wmic.exe /interactive:off process where "name='python.exe'" call terminate
+ - uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python-version }}
+ - uses: actions/checkout@v3
+ with:
+ path: fedml-devops
+ clean: true
+ - name: pip install -e ./
+ run: |
+ pip install -e ./
+ pip install -e '.[tensorflow]'
+
+ - name: server - tensorflow - fedavg
+ run: |
+ cd examples/cross_silo/tf_mqtt_s3_fedavg_mnist_lr_example
+ python tf_server.py --cf config/fedml_config.yaml --rank 0 --role server --run_id tf-ml-engine-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ if: ${{ matrix.client-index == '0' }}
+
+ - name: client 1 - tensorflow - fedavg
+ run: |
+ cd examples/cross_silo/tf_mqtt_s3_fedavg_mnist_lr_example
+ python3 tf_client.py --cf config/fedml_config.yaml --rank 1 --role client --run_id tf-ml-engine-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ if: ${{ matrix.client-index == '1' }}
+
+ - name: client 2 - tensorflow - fedavg
+ run: |
+ cd examples/cross_silo/tf_mqtt_s3_fedavg_mnist_lr_example
+ python3 tf_client.py --cf config/fedml_config.yaml --rank 2 --role client --run_id tf-ml-engine-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ if: ${{ matrix.client-index == '2' }}
+
+ jax-ml-engines-test:
+ defaults:
+ run:
+ shell: powershell
+ working-directory: fedml-devops\python
+ strategy:
+ fail-fast: false
+ matrix:
+ os: [ windows-2019 ]
+ arch: [ X64 ]
+ python-version: [ '3.8' ]
+ runs-on: [ self-hosted, runner-windows, devops ]
+ timeout-minutes: 15
+ steps:
+ - name: cleanup running processes
+ continue-on-error: true
+ run: |
+ wmic.exe /interactive:off process where "name='python.exe'" call terminate
+ - uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python-version }}
+ - uses: actions/checkout@v3
+ with:
+ path: fedml-devops
+ clean: true
+ - name: pip install -e ./
+ run: |
+ pip install -e ./
+ pip install -e '.[jax]'
+
+ - name: server - jax - fedavg
+ run: |
+ cd examples/cross_silo/jax_haiku_mqtt_s3_fedavg_mnist_lr_example
+ python tf_server.py --cf config/fedml_config.yaml --rank 0 --role server --run_id jax-ml-engine-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ if: ${{ matrix.client-index == '0' }}
+
+ - name: client 1 - jax - fedavg
+ run: |
+ cd examples/cross_silo/jax_haiku_mqtt_s3_fedavg_mnist_lr_example
+ python3 tf_client.py --cf config/fedml_config.yaml --rank 1 --role client --run_id jax-ml-engine-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ if: ${{ matrix.client-index == '1' }}
+
+ - name: client 2 - jax - fedavg
+ run: |
+ cd examples/cross_silo/jax_haiku_mqtt_s3_fedavg_mnist_lr_example
+ python3 tf_client.py --cf config/fedml_config.yaml --rank 2 --role client --run_id jax-ml-engine-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ if: ${{ matrix.client-index == '2' }}
+
+ mxnet-ml-engines-test:
+ defaults:
+ run:
+ shell: powershell
+ working-directory: fedml-devops\python
+ strategy:
+ fail-fast: false
+ matrix:
+ os: [ windows-2019 ]
+ arch: [ X64 ]
+ python-version: [ '3.8' ]
+ runs-on: [ self-hosted, runner-windows, devops ]
+ timeout-minutes: 15
+ steps:
+ - name: cleanup running processes
+ continue-on-error: true
+ run: |
+ wmic.exe /interactive:off process where "name='python.exe'" call terminate
+ - uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python-version }}
+ - uses: actions/checkout@v3
+ with:
+ path: fedml-devops
+ clean: true
+ - name: pip install -e ./
+ run: |
+ pip install -e ./
+ pip install -e '.[mxnet]'
+
+ - name: server - mxnet - fedavg
+ run: |
+ cd examples/cross_silo/mxnet_mqtt_s3_fedavg_mnist_lr_example
+ python tf_server.py --cf config/fedml_config.yaml --rank 0 --role server --run_id mxnet-ml-engine-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ if: ${{ matrix.client-index == '0' }}
+
+ - name: client 1 - mxnet - fedavg
+ run: |
+ cd examples/cross_silo/mxnet_mqtt_s3_fedavg_mnist_lr_example
+ python3 tf_client.py --cf config/fedml_config.yaml --rank 1 --role client --run_id mxnet-ml-engine-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ if: ${{ matrix.client-index == '1' }}
+
+ - name: client 2 - mxnet - fedavg
+ run: |
+ cd examples/cross_silo/mxnet_mqtt_s3_fedavg_mnist_lr_example
+ python3 tf_client.py --cf config/fedml_config.yaml --rank 2 --role client --run_id mxnet-ml-engine-${{ format('{0}{1}{2}{3}', github.run_id, matrix.os, matrix.arch, matrix.python-version) }}
+ if: ${{ matrix.client-index == '2' }}
diff --git a/.github/workflows/smoke_test_pip_cli_sp_linux.yml b/.github/workflows/smoke_test_pip_cli_sp_linux.yml
new file mode 100644
index 0000000000..d5381ccb4c
--- /dev/null
+++ b/.github/workflows/smoke_test_pip_cli_sp_linux.yml
@@ -0,0 +1,88 @@
+# This is a basic workflow to help you get started with Actions
+
+name: PIP, CLI, SP - On Linux
+
+# Controls when the workflow will run
+on:
+ # Triggers the workflow on push or pull request events but only for the master branch
+ schedule:
+ # Nightly build at 12:12 A.M.
+ - cron: "12 12 */1 * *"
+ pull_request:
+ branches: [ master, test/v0.7.0, dev/0.7.0 ]
+
+ # Allows you to run this workflow manually from the Actions tab
+ workflow_dispatch:
+
+permissions: write-all
+
+# A workflow run is made up of one or more jobs that can run sequentially or in parallel
+jobs:
+ pip-install-fedml-and-test-sp:
+ defaults:
+ run:
+ shell: bash
+ working-directory: python
+ strategy:
+ fail-fast: false
+ matrix:
+ os: [ ubuntu-latest]
+ arch: [X64]
+ python-version: ['3.8']
+# exclude:
+# - os: macos-latest
+# python-version: '3.8'
+# - os: windows-latest
+# python-version: '3.6'
+ runs-on: [self-hosted, runner-linux, devops]
+ timeout-minutes: 15
+ steps:
+ - uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python-version }}
+ # https://github.com/actions/checkout/issues/116#issuecomment-644419389
+ - uses: actions/checkout@v3
+ - name: pip install -e ./
+ run: |
+ pip install -e ./
+
+ - name: test "fedml login" and "fedml build"
+ run: |
+ cd tests/smoke_test/cli
+ bash login.sh
+ bash build.sh
+ - name: test simulation-sp
+ run: |
+ cd quick_start/parrot
+ python torch_fedavg_mnist_lr_one_line_example.py --cf fedml_config.yaml
+ python torch_fedavg_mnist_lr_custum_data_and_model_example.py --cf fedml_config.yaml
+
+ - name: test sp - sp_decentralized_mnist_lr_example
+ run: |
+ cd examples/simulation/sp_decentralized_mnist_lr_example
+ python torch_fedavg_mnist_lr_step_by_step_example.py --cf fedml_config.yaml
+
+ - name: test sp - sp_fednova_mnist_lr_example
+ run: |
+ cd examples/simulation/sp_fednova_mnist_lr_example
+ python torch_fednova_mnist_lr_step_by_step_example.py --cf fedml_config.yaml
+
+ - name: test sp - sp_fedopt_mnist_lr_example
+ run: |
+ cd examples/simulation/sp_fedopt_mnist_lr_example
+ python torch_fedopt_mnist_lr_step_by_step_example.py --cf fedml_config.yaml
+
+ - name: test sp - sp_hierarchicalfl_mnist_lr_example
+ run: |
+ cd examples/simulation/sp_hierarchicalfl_mnist_lr_example
+ python torch_hierarchicalfl_mnist_lr_step_by_step_example.py --cf fedml_config.yaml
+
+ - name: test sp - sp_turboaggregate_mnist_lr_example
+ run: |
+ cd examples/simulation/sp_turboaggregate_mnist_lr_example
+ python torch_turboaggregate_mnist_lr_step_by_step_example.py --cf fedml_config.yaml
+
+ - name: test sp - sp_vertical_mnist_lr_example
+ run: |
+ cd examples/simulation/sp_vertical_mnist_lr_example
+ python torch_vertical_mnist_lr_step_by_step_example.py --cf fedml_config.yaml
diff --git a/.github/workflows/smoke_test_pip_cli_sp_win.yml b/.github/workflows/smoke_test_pip_cli_sp_win.yml
new file mode 100644
index 0000000000..a010eabd57
--- /dev/null
+++ b/.github/workflows/smoke_test_pip_cli_sp_win.yml
@@ -0,0 +1,64 @@
+# This is a basic workflow to help you get started with Actions
+
+name: PIP, CLI, SP - On Windows
+
+# Controls when the workflow will run
+on:
+ # Triggers the workflow on push or pull request events but only for the master branch
+ schedule:
+ # Nightly build at 12:12 A.M.
+ - cron: "12 12 */1 * *"
+ pull_request:
+ branches: [ master, test/v0.7.0, dev/0.7.0 ]
+ types: [opened, reopened]
+
+ # Allows you to run this workflow manually from the Actions tab
+ workflow_dispatch:
+
+# A workflow run is made up of one or more jobs that can run sequentially or in parallel
+jobs:
+ pip-install-fedml-and-test-sp:
+ defaults:
+ run:
+ shell: powershell
+ working-directory: fedml-devops\python
+ strategy:
+ fail-fast: false
+ matrix:
+ os: [ windows-2019 ]
+ arch: [X64]
+ python-version: ['3.8']
+# exclude:
+# - os: macos-latest
+# python-version: '3.8'
+# - os: windows-latest
+# python-version: '3.6'
+ runs-on: [self-hosted, runner-windows, devops]
+ timeout-minutes: 15
+ steps:
+ - name: cleanup running processes
+ continue-on-error: true
+ run: |
+ wmic.exe /interactive:off process where "name='python.exe'" call terminate
+ - uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python-version }}
+ # https://github.com/actions/checkout/issues/116#issuecomment-644419389
+ - uses: actions/checkout@v3
+ with:
+ path: fedml-devops
+ clean: true
+ - name: pip install -e ./
+ run: |
+ pip install -e ./
+
+ - name: test "fedml login" and "fedml build"
+ run: |
+ cd tests/smoke_test/cli
+ .\login.bat
+ .\build.bat
+ - name: test simulation-sp
+ run: |
+ cd quick_start/parrot
+ python torch_fedavg_mnist_lr_one_line_example.py --cf fedml_config.yaml
+ python torch_fedavg_mnist_lr_custum_data_and_model_example.py --cf fedml_config.yaml
diff --git a/.github/workflows/smoke_test_security.yml b/.github/workflows/smoke_test_security.yml
new file mode 100644
index 0000000000..7c72d69789
--- /dev/null
+++ b/.github/workflows/smoke_test_security.yml
@@ -0,0 +1,59 @@
+# This is a basic workflow to help you get started with Actions
+
+name: Security(attack/defense) on Linux
+
+# Controls when the workflow will run
+on:
+ # Triggers the workflow on push or pull request events but only for the master branch
+ schedule:
+ # Nightly build at 12:12 A.M.
+ - cron: "12 12 */1 * *"
+ pull_request:
+ branches: [ master, test/v0.7.0 ]
+ types: [opened, reopened]
+
+ # Allows you to run this workflow manually from the Actions tab
+ workflow_dispatch:
+
+permissions: write-all
+
+# A workflow run is made up of one or more jobs that can run sequentially or in parallel
+jobs:
+ security-attack-defense-tests:
+ defaults:
+ run:
+ shell: bash
+ working-directory: python
+ strategy:
+ fail-fast: false
+ matrix:
+ os: [ ubuntu-latest]
+ arch: [X64]
+ python-version: ['3.8']
+# exclude:
+# - os: macos-latest
+# python-version: '3.8'
+# - os: windows-latest
+# python-version: '3.6'
+ runs-on: [self-hosted, runner-linux, devops]
+ timeout-minutes: 15
+ steps:
+ - uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python-version }}
+ # https://github.com/actions/checkout/issues/116#issuecomment-644419389
+ - uses: actions/checkout@v3
+ - name: pip install -e ./
+ run: |
+ pip install -e ./
+ pip install sklearn
+
+ - name: attack tests
+ run: |
+ cd tests/security
+ sh run_attacker_tests.sh
+
+ - name: defense tests
+ run: |
+ cd tests/security
+ sh run_defender_tests.sh
\ No newline at end of file
diff --git a/.github/workflows/smoke_test_simulation_mpi_linux.yml b/.github/workflows/smoke_test_simulation_mpi_linux.yml
new file mode 100644
index 0000000000..b4e43799b2
--- /dev/null
+++ b/.github/workflows/smoke_test_simulation_mpi_linux.yml
@@ -0,0 +1,93 @@
+# This is a basic workflow to help you get started with Actions
+
+name: MPI - On Linux
+
+# Controls when the workflow will run
+on:
+ # Triggers the workflow on push or pull request events but only for the master branch
+ schedule:
+ # Nightly build at 12:12 A.M.
+ - cron: "12 12 */1 * *"
+ pull_request:
+ branches: [ master, test/v0.7.0, dev/0.7.0 ]
+
+ # Allows you to run this workflow manually from the Actions tab
+ workflow_dispatch:
+
+# A workflow run is made up of one or more jobs that can run sequentially or in parallel
+
+jobs:
+ # https://github.com/mpi4py/mpi4py/actions/runs/34979774/workflow
+ mpi_run:
+ runs-on: [self-hosted, runner-linux, devops]
+ timeout-minutes: 15
+ defaults:
+ run:
+ shell: bash
+ working-directory: python
+ strategy:
+ matrix:
+ python-version: [3.8]
+ mpi: [mpich]
+# mpi: [mpich, openmpi]
+ os: [ ubuntu-latest ]
+ include:
+ - os: ubuntu-latest
+ mpi: mpich
+ install-mpi: sudo apt install -y mpich libmpich-dev
+# - os: ubuntu-latest
+# mpi: openmpi
+# install-mpi: sudo apt install -y openmpi-bin libopenmpi-dev
+ steps:
+ - uses: actions/checkout@v3
+ - name: Install MPI
+ run: ${{ matrix.install-mpi }}
+ - name: Use Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install packaging tools
+ run: python -m pip install --upgrade setuptools pip wheel
+ - name: Install build dependencies
+ run: python -m pip install --upgrade cython
+ - name: Build package
+ run: python -m pip wheel -vvv --wheel-dir=dist .
+ - name: Install test dependencies
+ run: python -m pip install --upgrade numpy
+ - name: pip install -e ./
+ run: |
+ pip install -e ./
+ pip install mpi4py==3.1.3
+
+ - name: Test package - FedAvg
+ run: |
+ cd examples/simulation/mpi_torch_fedavg_mnist_lr_example
+ sh run_custom_data_and_model_example.sh 4
+
+ - name: Test package - Base
+ run: |
+ cd examples/simulation/mpi_base_framework_example
+ sh run.sh 4
+
+ - name: Test package - Decentralized
+ run: |
+ cd examples/simulation/mpi_decentralized_fl_example
+ sh run.sh 4
+
+ - name: Test package - FedOPT
+ run: |
+ cd examples/simulation/mpi_fedopt_datasets_and_models_example
+ sh run_step_by_step_example.sh 4 config/mnist_lr/fedml_config.yaml
+
+ - name: Test package - FedProx
+ run: |
+ cd examples/simulation/mpi_fedprox_datasets_and_models_example
+ sh run_step_by_step_example.sh 4 config/mnist_lr/fedml_config.yaml
+
+ - name: Test package - FedGAN
+ run: |
+ cd examples/simulation/mpi_torch_fedgan_mnist_gan_example
+ sh run_step_by_step_example.sh 4
+
+ - name: Uninstall package after testing
+ run: python -m pip uninstall --yes mpi4py
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index ae1f7929c1..74c17e7ba5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,6 +2,14 @@
.idea
.idea/*
.vscode/*
+*.pkl
+
+data/*
+python/tests/smoke_test/data
+python/tests/smoke_test/data/*
+
+android/fedmlsdk/MobileNN/MNN
+android/fedmlsdk/MobileNN/pytorch
*.pyc
./*/*.pyc
@@ -10,6 +18,12 @@
./*/*/*/*/*.pyc
./*/*/*/*/*/*.pyc
+python/data
+python/fedml.egg-info
+python/examples/simulation/sp_fedavg_mnist_lr_example/data
+python/examples/simulation/sp_fedsgd_cifar10_resnet20_example
+python/examples/simulation/sp_fedsgd_cifar10_resnet20_example/*
+
*.log
wandb
@@ -24,7 +38,7 @@ python/dist
python/FedML.egg-info
doc/deploy
-doc/en/_build
+doc/en/_build/doctrees
*.h5
@@ -37,6 +51,7 @@ cifar-10-batches-py
*.zip
*.json
*.tar.bz2
+*.stats
data/cifar10/*-python
data/cifar100/*-python
@@ -61,31 +76,117 @@ mimiconda.sh
data/fednlp/*
*.npz
-test/fedml_user_code/simulation_sp/data/mnist/MNIST/stats.sh
-test/fedml_user_code/simulation_sp/data/mnist/MNIST/stats.py
-test/fedml_user_code/simulation_sp/data/mnist/MNIST/README.md
-test/fedml_user_code/simulation_sp/data/mnist/MNIST/download_and_unzip.sh
-test/fedml_user_code/simulation_sp/data/mnist/__MACOSX/MNIST/._train
-test/fedml_user_code/simulation_sp/data/mnist/__MACOSX/MNIST/._test
-test/fedml_user_code/simulation_sp/data/mnist/__MACOSX/MNIST/._.DS_Store
-test/fedml_user_code/simulation_mpi/data/mnist/MNIST/stats.sh
-test/fedml_user_code/simulation_mpi/data/mnist/MNIST/stats.py
-test/fedml_user_code/simulation_mpi/data/mnist/MNIST/README.md
-test/fedml_user_code/simulation_mpi/data/mnist/MNIST/download_and_unzip.sh
-test/fedml_user_code/simulation_mpi/data/mnist/__MACOSX/MNIST/._train
-test/fedml_user_code/simulation_mpi/data/mnist/__MACOSX/MNIST/._test
-test/fedml_user_code/simulation_mpi/data/mnist/__MACOSX/MNIST/._.DS_Store
-test/fedml_user_code/simulation_sp/data/mnist/__MACOSX/MNIST/._.DS_Store
-test/fedml_user_code/simulation_sp/data/mnist/__MACOSX/MNIST/._test
-test/fedml_user_code/simulation_sp/data/mnist/__MACOSX/MNIST/._train
-test/fedml_user_code/simulation_sp/data/mnist/MNIST/download_and_unzip.sh
-test/fedml_user_code/simulation_sp/data/mnist/MNIST/README.md
-test/fedml_user_code/simulation_sp/data/mnist/MNIST/stats.py
-test/fedml_user_code/simulation_sp/data/mnist/MNIST/stats.sh
-test/fedml_user_code/simulation_mpi/data/mnist/__MACOSX/MNIST/._.DS_Store
-test/fedml_user_code/simulation_mpi/data/mnist/__MACOSX/MNIST/._test
-test/fedml_user_code/simulation_mpi/data/mnist/__MACOSX/MNIST/._train
-test/fedml_user_code/simulation_mpi/data/mnist/MNIST/download_and_unzip.sh
-test/fedml_user_code/simulation_mpi/data/mnist/MNIST/README.md
-test/fedml_user_code/simulation_mpi/data/mnist/MNIST/stats.py
-test/fedml_user_code/simulation_mpi/data/mnist/MNIST/stats.sh
+python/tests/smoke_test/simulation_sp/data/mnist/MNIST/stats.sh
+python/tests/smoke_test/simulation_sp/data/mnist/MNIST/stats.py
+python/tests/smoke_test/simulation_sp/data/mnist/MNIST/README.md
+python/tests/smoke_test/simulation_sp/data/mnist/MNIST/download_and_unzip.sh
+python/tests/smoke_test/simulation_sp/data/mnist/__MACOSX/MNIST/._train
+python/tests/smoke_test/simulation_sp/data/mnist/__MACOSX/MNIST/._test
+python/tests/smoke_test/simulation_sp/data/mnist/__MACOSX/MNIST/._.DS_Store
+python/tests/smoke_test/simulation_mpi/data/mnist/MNIST/stats.sh
+python/tests/smoke_test/simulation_mpi/data/mnist/MNIST/stats.py
+python/tests/smoke_test/simulation_mpi/data/mnist/MNIST/README.md
+python/tests/smoke_test/simulation_mpi/data/mnist/MNIST/download_and_unzip.sh
+python/tests/smoke_test/simulation_mpi/data/mnist/__MACOSX/MNIST/._train
+python/tests/smoke_test/simulation_mpi/data/mnist/__MACOSX/MNIST/._test
+python/tests/smoke_test/simulation_mpi/data/mnist/__MACOSX/MNIST/._.DS_Store
+
+python/app/fednlp/text_classification/cache_dir
+python/app/fednlp/seq2seq/cache_dir
+python/app/fednlp/span_extraction/cache_dir
+
+app/fedgraphnn/moleculenet_graph_clf/data/bace/*.pkl
+app/fedgraphnn/moleculenet_graph_clf/data/bace/*.npy
+app/fedgraphnn/moleculenet_graph_clf/data/bbbp/*.pkl
+app/fedgraphnn/moleculenet_graph_clf/data/bbbp/*.npy
+app/fedgraphnn/moleculenet_graph_clf/data/clintox/*.pkl
+app/fedgraphnn/moleculenet_graph_clf/data/clintox/*.npy
+app/fedgraphnn/moleculenet_graph_clf/data/hiv/*.pkl
+app/fedgraphnn/moleculenet_graph_clf/data/hiv/*.npy
+app/fedgraphnn/moleculenet_graph_clf/data/lipo/*.pkl
+app/fedgraphnn/moleculenet_graph_clf/data/lipo/*.npy
+app/fedgraphnn/moleculenet_graph_clf/data/muv/*.pkl
+app/fedgraphnn/moleculenet_graph_clf/data/muv/*.npy
+app/fedgraphnn/moleculenet_graph_clf/data/pcba/*.pkl
+app/fedgraphnn/moleculenet_graph_clf/data/pcba/*.npy
+app/fedgraphnn/moleculenet_graph_clf/data/sider/*.pkl
+app/fedgraphnn/moleculenet_graph_clf/data/sider/*.npy
+app/fedgraphnn/moleculenet_graph_clf/data/tox21/*.pkl
+app/fedgraphnn/moleculenet_graph_clf/data/tox21/*.npy
+app/fedgraphnn/moleculenet_graph_clf/data/toxcast/*.pkl
+app/fedgraphnn/moleculenet_graph_clf/data/toxcast/*.npy
+
+app/fedgraphnn/moleculenet_graph_reg/data/esol/*.pkl
+app/fedgraphnn/moleculenet_graph_reg/data/esol/*.npy
+app/fedgraphnn/moleculenet_graph_reg/data/freesolv/*.pkl
+app/fedgraphnn/moleculenet_graph_reg/data/freesolv/*.npy
+app/fedgraphnn/moleculenet_graph_reg/data/herg/*.pkl
+app/fedgraphnn/moleculenet_graph_reg/data/herg/*.npy
+app/fedgraphnn/moleculenet_graph_reg/data/lipo/*.pkl
+app/fedgraphnn/moleculenet_graph_reg/data/lipo/*.npy
+app/fedgraphnn/moleculenet_graph_reg/data/qm9/*.pkl
+app/fedgraphnn/moleculenet_graph_reg/data/qm9/*.npy
+
+
+python/app/fedgraphnn/social_networks_graph_clf/data/TUDataset/*
+python/app/fedgraphnn/social_networks_graph_clf/data/TUDataset/collab.pickle
+python/app/fedgraphnn/social_networks_graph_clf/x_hist.png
+
+python/app/fedgraphnn/social_networks_graph_clf/x_hist.png
+python/app/healthcare/fed_lidc_idri /*
+
+
+
+python/app/fedgraphnn/social_networks_graph_clf/data/TUDataset/*
+python/app/fedgraphnn/social_networks_graph_clf/data/TUDataset/collab.pickle
+python/app/fedgraphnn/social_networks_graph_clf/x_hist.png
+
+app/fedgraphnn/subgraph_relation_pred/data/FB15k-237/*.dict
+app/fedgraphnn/subgraph_relation_pred/data/FB15k-237/*.txt
+app/fedgraphnn/subgraph_relation_pred/data/wn18rr/*.dict
+app/fedgraphnn/subgraph_relation_pred/data/wn18rr/*.txt
+app/fedgraphnn/subgraph_relation_pred/data/YAGO3-10/*.dict
+app/fedgraphnn/subgraph_relation_pred/data/YAGO3-10/*.txt
+
+app/fedgraphnn/subgraph_link_pred/data/FB15k-237/*.dict
+app/fedgraphnn/subgraph_link_pred/data/FB15k-237/*.txt
+app/fedgraphnn/subgraph_link_pred/data/wn18rr/*.dict
+app/fedgraphnn/subgraph_link_pred/data/wn18rr/*.txt
+app/fedgraphnn/subgraph_link_pred/data/YAGO3-10/*.dict
+app/fedgraphnn/subgraph_link_pred/data/YAGO3-10/*.txt
+
+python/app/fedgraphnn/ego_networks_link_pred/data/ego-networks
+python/app/fedgraphnn/ego_networks_node_clf/data/ego-networks
+/devops/scripts/aws/
+devops/scripts/docker
+devops/scripts/kubectl
+android/gradlew
+android/gradlew.bat
+
+/iot/anomaly_detection_for_cybersecurity/data/Danmini_Doorbell/
+/iot/anomaly_detection_for_cybersecurity/data/Ecobee_Thermostat/
+/iot/anomaly_detection_for_cybersecurity/data/Ennio_Doorbell/
+/iot/anomaly_detection_for_cybersecurity/data/Philips_B120N10_Baby_Monitor/
+/iot/anomaly_detection_for_cybersecurity/data/Provision_PT_737E_Security_Camera/
+/iot/anomaly_detection_for_cybersecurity/data/Provision_PT_838_Security_Camera/
+/iot/anomaly_detection_for_cybersecurity/data/Samsung_SNH_1011_N_Webcam/
+/iot/anomaly_detection_for_cybersecurity/data/SimpleHome_XCS7_1002_WHT_Security_Camera/
+/iot/anomaly_detection_for_cybersecurity/data/SimpleHome_XCS7_1003_WHT_Security_Camera/
+
+
+doc/en/_build
+/docker-2/
+/swap/actions-runner/
+*.txt
+python/examples/cross_silo/light_sec_agg_example/mpi_host_file
+/python/examples/cross_silo/mqtt_thetastore_fedavg_mnist_lr_example/custom_data_and_model/mlops/dist-packages/
+/python/examples/cross_silo/mqtt_web3storage_fedavg_mnist_lr_example/custom_data_and_model/mlops/dist-packages/
+/python/examples/cross_silo/mqtt_s3_fedavg_mnist_lr_example/custom_data_and_model/mlops/dist-packages/
+/python/quick_start/parrot/fedml_data/
+/python/tests/smoke_test/simulation_sp/mnist/__MACOSX/MNIST/
+/python/tests/smoke_test/simulation_sp/mnist/MNIST/
+/swap/
+/.github/workflows/build_wheels_and_releaseså¯æŦ.yml
+/devops/dockerfile/device-image/Dockerfile-Local
+/FedML-dev-v0.7.0.iml
+/python/fedml/cli/debug-cli.py
diff --git a/README.md b/README.md
index 813509a91b..c64bcca6a2 100644
--- a/README.md
+++ b/README.md
@@ -1,24 +1,44 @@
-# FedML: The Community Connecting and Building AI Anywhere at Any Scale
+# FedML: The Community Building Open and Collaborative AI Anywhere at Any Scale
-
-
+
At the current stage, FedML library provides a research and production integrated edge-cloud platform for Federated/Distributed Machine Learning at anywhere at any scale.
-Homepage: [https://FedML.ai](https://FedML.ai)
-
)
# News
+* [2022/08/01] (Product Introduction) FedML AI platform releases the worldâs federated learning open platform on the public cloud with an in-depth introduction of products and technologies! Please visit this blog for details.
+
* [2022/03/15] (Fundraising): FedML, Inc. has finished the 1st-round fundraising. We are backed by top VCs who focus on AI, SaaS, and Blockchain/Web3/Crypto from the Bay Area, California of the USA.
* [2022/02/14] (Company): FedML is upgraded as a Delaware-registered C-Corp company. Our headquarter is in California, USA. The two co-founders are CEO Salman Avestimehr (https://www.avestimehr.com/) and CTO Chaoyang He (https://chaoyanghe.com). We welcome contributors anywhere in the world.
* [2021/02/01] (Award): #NeurIPS 2020# FedML won Best Paper Award at NeurIPS Federated Learning workshop 2020
@@ -28,7 +48,7 @@ FedML is hiring! [Come and join us](https://fedml.ai/careers/)!
# **FedML Feature Overview**
-![image](/doc/en/_static/image/4animals.png)
+![image](./doc/en/_static/image/4animals.png)
FedML logo reflects the mission of FedML Inc. FedML aims to build simple and versatile APIs for machine learning running anywhere at any scale.
In other words, FedML supports both federated learning for data silos and distributed training for acceleration with MLOps and Open Source support, covering cutting-edge academia research and industrial grade use cases.
@@ -165,7 +185,7 @@ model_args:
train_args:
federated_optimizer: "FedAvg"
- client_id_list: "[1, 2]"
+ client_id_list:
client_num_in_total: 1000
client_num_per_round: 2
comm_round: 50
@@ -210,7 +230,7 @@ Simulation with Message Passing Interface (MPI):
Simulation with NCCL-based MPI (the fastest training):
- In case your cross-GPU bandwidth is high (e.g., InfiniBand, NVLink, EFA, etc.), we suggest using this NCCL-based MPI FL simulator to accelerate your development.
-## **FedML Octopu Exampless**
+## **FedML Octopus Examples**
Horizontal Federated Learning:
- [mqtt_s3_fedavg_mnist_lr_example](./doc/en/cross-silo/examples/mqtt_s3_fedavg_mnist_lr_example.md): an example to illustrate how to run horizontal federated learning in data silos (hospitals, banks, etc.)
@@ -226,6 +246,25 @@ Here `hierarchical` means that inside each FL Client (data silo), there are mult
- [Federated Learning on Android Smartphones](./doc/en/cross-device/examples/mqtt_s3_fedavg_mnist_lr_example.md)
+# FedML on Smartphone and IoTs
+
+
+
+
+
+
+
+
+
+
+
+
+See the introduction and tutorial at [FedML/android](./android).
+
+
+
+See the introduction and tutorial at [FedML/iot](./iot)
+
# **MLOps User Guide**
[https://open.fedml.ai](https://open.fedml.ai)
@@ -236,14 +275,13 @@ FedML MLOps Platform simplifies the workflow of federated learning anywhere at a
It enables zero-code, lightweight, cross-platform, and provably secure federated learning.
It enables machine learning from decentralized data at various users/silos/edge nodes, without the need to centralize any data to the cloud, hence providing maximum privacy and efficiency.
-![image](./doc/en/_static/image/mlops_workflow.png)
-
+![image](./doc/en/_static/image/MLOps_workflow.png)
The above figure shows the workflow. Such a workflow is handled by web UI without the need to handle complex deployment. Check the following live demo for details:
![image](./doc/en/_static/image/mlops_invite.png)
-3 Minutes Introduction: [https://www.youtube.com/watch?v=E1k05jd1Tyw](https://www.youtube.com/watch?v=E1k05jd1Tyw)
+3 Minutes Introduction: [https://www.youtube.com/watch?v=E1k05jd1Tyw](https://www.youtube.com/watch?v=Xgm0XEaMlVQ)
A detailed guidance for the MLOps can be found at [FedML MLOps User Guide](./doc/en/mlops/user_guide.md).
@@ -275,7 +313,7 @@ FedMLâs core technology is backed by years of cutting-edge research represente
5. AI Applications
A Full-stack of Scientific Publications in ML Algorithms, Security/Privacy, Systems, Applications, and Visionary Impacts
-Please check [this publication list](./doc/en/resource/papers.md) for details.
+Please check [this publication list](./doc/en/resources/papers.md) for details.
## Video (Invited Talks)
@@ -298,12 +336,7 @@ Please check [this publication list](./doc/en/resource/papers.md) for details.
Our WeChat group exceeds 200 members, please add the following account and ask him to invite you to join.
-
-
-## FAQ
-
-We organize the frequently asked questions at [https://github.com/FedML-AI/FedML/discussions](https://github.com/FedML-AI/FedML/discussions).
-Please feel free to ask questions there. We are happy to discuss on supporting your special demands.
+
# Contributing
diff --git a/android/.gitignore b/android/.gitignore
new file mode 100644
index 0000000000..92f7d6fef9
--- /dev/null
+++ b/android/.gitignore
@@ -0,0 +1,22 @@
+.gradle
+/.idea
+.cxx
+
+cpp/build/lightsecagg/build_x86_linux
+cpp/build/lightsecagg/build_arm_android_64
+cpp/build/train/build_x86_linux
+cpp/build/train/build_arm_android_64
+
+*.iml
+/local.properties
+.DS_Store
+/captures
+.externalNativeBuild
+./idea
+
+cpp/build_arm_android_64
+cpp/build_x86_linux
+
+cpp/build/FedMLTrainer
+secring.gpg
+/build/
diff --git a/android/README.md b/android/README.md
new file mode 100644
index 0000000000..3c8b69a43c
--- /dev/null
+++ b/android/README.md
@@ -0,0 +1,155 @@
+# FedML Android App and SDK
+
+
+
+
+
+
+
+
+- Android project root path: https://github.com/FedML-AI/FedML/tree/master/android
+
+
+
+The architecture is divided into three vertical layers and multiple horizontal modules:
+
+### 1. Android APK Layer
+- app
+
+https://github.com/FedML-AI/FedML/tree/master/android/app
+
+
+- fedmlsdk_demo
+
+https://github.com/FedML-AI/FedML/tree/master/android/fedmlsdk_demo
+
+### 2. Android SDK layer (Java API + JNI + So library)
+
+https://github.com/FedML-AI/FedML/tree/master/android/fedmlsdk
+
+
+### 3. MobileNN: FedML Mobile Training Engine Layer (C++, MNN, PyTorch, etc.)
+
+https://github.com/FedML-AI/FedML/tree/master/android/fedmlsdk/MobileNN
+
+https://github.com/FedML-AI/MNN
+
+https://github.com/FedML-AI/pytorch
+
+## Get Started with FedML Android APP
+[https://doc.fedml.ai/cross-device/examples/cross_device_android_example.html](https://doc.fedml.ai/cross-device/examples/cross_device_android_example.html)
+
+## Get Started with FedML Android SDK
+
+`android/fedmlsdk_demo` is a short tutorial for integrating Android SDK for your host App.
+
+1. add repositories by maven
+
+```groovy
+ maven { url 'https://s01.oss.sonatype.org/content/repositories/snapshots' }
+```
+
+2. add dependency in build.gradle
+
+check `android/fedmlsdk_demo/build.gradle` as an example:
+
+```groovy
+ implementation 'ai.fedml:fedml-edge-android:1.0.0-SNAPSHOT'
+```
+
+3. add FedML account id to meta-data in AndroidManifest.xml
+
+check `android/fedmlsdk_demo/src/main/AndroidManifest.xml` as an example:
+
+
+```xml
+
+
+```
+
+or
+
+```xml
+
+
+```
+
+You can find your account ID at FedML Open Platform (https://open.fedml.ai):
+![account](./doc/beehive_account.png)
+
+4. initial FedML Android SDK on your `Application` class.
+
+Taking `android/fedmlsdk_demo/src/main/java/ai/fedml/edgedemo/App.java` as an example:
+```java
+package ai.fedml.edgedemo;
+
+import android.app.Application;
+import android.os.Handler;
+import android.os.Looper;
+
+import ai.fedml.edge.FedEdgeManager;
+
+public class App extends Application {
+ private static Handler sHandler = new Handler(Looper.getMainLooper());
+
+ @Override
+ public void onCreate() {
+ super.onCreate();
+
+ // initial Edge SDK
+ FedEdgeManager.getFedEdgeApi().init(this);
+
+ // set data path (to prepare data, please check this script `android/data/prepare.sh`)
+ FedEdgeManager.getFedEdgeApi().setPrivatePath(Environment.getExternalStorageDirectory().getPath()
+ + "/ai.fedml/device_1/user_0");
+ }
+}
+```
+
+## Android SDK APIs
+At the current stage, we provide high-level APIs with the following three classes.
+
+
+- ai.fedml.edge.FedEdgeManager
+
+This is the top APIs in FedML Android SDK, it supports core training engine and related control commands on your Android devices.
+
+- ai.fedml.edge.OnTrainProgressListener
+
+This is the message flow to interact between FedML Android SDK and your host APP.
+
+- ai.fedml.edge.request.RequestManager
+
+This is used to connect your Android SDK with FedML Open Platform (https://open.fedml.ai), which helps you to simplify the deployment, edge collaborative training, experimental tracking, and more.
+
+You can import them in your Java/Android projects as follows. See [android/fedmlsdk_demo/src/main/java/ai/fedml/edgedemo/ui/main/MainFragment.java](fedmlsdk_demo/src/main/java/ai/fedml/edgedemo/ui/main/MainFragment.java) as an example.
+```
+import ai.fedml.edge.FedEdgeManager;
+import ai.fedml.edge.OnTrainProgressListener;
+import ai.fedml.edge.request.RequestManager;
+```
+
+4. Running Android SDK Demo with MLOps (https://open.fedml.ai)
+
+Please follow this tutorial (https://doc.fedml.ai/mlops/user_guide.html) to start training using FedML BeeHive Platform.
+
+
+
+
+
+
+
+
+
+
+
+
+## How to Run?
+https://doc.fedml.ai/cross-device/examples/cross_device_android_example.html
+
+
+## Want More Advanced APIs or Features?
+We'd love to listen to your feedback!
+
+FedML team has rich experience in Android Platform and Federated Learning Algorithmic Research.
+If you want advanced feature supports, please send emails to avestimehr@fedml.ai and ch@fedml.ai
diff --git a/android/SUBMODULE.md b/android/SUBMODULE.md
new file mode 100644
index 0000000000..cb57c81e9f
--- /dev/null
+++ b/android/SUBMODULE.md
@@ -0,0 +1,12 @@
+# git submodule
+```
+# add git submodule (please execute under FedML/android folder)
+git submodule add https://github.com/FedML-AI/FedMLAndroidSDK.git fedmlsdk
+
+# git submodule init
+git submodule update --init --recursive
+
+# git submodule update
+git submodule update --remote --merge
+```
+for more `git submodule` related commands, please refer to https://devconnected.com/how-to-add-and-update-git-submodules/
\ No newline at end of file
diff --git a/android/TEST.md b/android/TEST.md
new file mode 100644
index 0000000000..ffd55383a8
--- /dev/null
+++ b/android/TEST.md
@@ -0,0 +1,112 @@
+# Train Flow
+
+## 1. onStartTrain
+
+**Received**
+topic: flserver_agent/1/start_train
+
+```json
+{
+ "groupid": "38",
+ "clientLearningRate": 0.001,
+ "partitionMethod": "homo",
+ "starttime": 1646068794775,
+ "trainBatchSize": 64,
+ "edgeids": [
+ 17,
+ 20,
+ 18,
+ 21,
+ 19
+ ],
+ "token": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6NjAsImFjY291bnQiOiJhbGV4LmxpYW5nIiwibG9naW5UaW1lIjoiMTY0NjA2NTY5MDAwNSIsImV4cCI6MH0.0OTXuMTfxqf2duhkBG1CQDj1UVgconnoSH0PASAEzM4",
+ "modelName": "resnet56",
+ "urls": [
+ "https://fedmls3.s3.amazonaws.com/025c28be-b464-457a-ab17-851ae60767a9"
+ ],
+ "clientOptimizer": "adam",
+ "userids": [
+ "60"
+ ],
+ "clientNumPerRound": 3,
+ "name": "1646068810",
+ "commRound": 3,
+ "localEpoch": 1,
+ "runId": 168,
+ "id": 169,
+ "projectid": "56",
+ "dataset": "cifar10",
+ "communicationBackend": "MQTT_S3",
+ "timestamp": "1646068794778"
+}
+```
+
+**Send**
+Topic: fedml_168_1
+
+```json
+{
+ "client_status": "ONLINE",
+ "msg_type": 5,
+ "receiver": 0,
+ "sender": 1
+}
+```
+
+## 2. init Config
+
+**Received**
+Topic: fedml_168_0_1
+
+```json
+{
+ "msg_type": 1,
+ "sender": 0,
+ "receiver": 1,
+ "model_params": "fedml_111_0_39d756ca2-1ce1-44bc-b232-59f0ae054f0e",
+ "client_idx": "0"
+}
+```
+
+**Send**
+Topic: fedml_168_1
+
+```json
+ {
+ "client_idx": "0",
+ "model_params": "fedml_111_0_39d756ca2-1ce1-44bc-b232-59f0ae054f0e",
+ "num_samples": 5,
+ "msg_type": 3,
+ "receiver": 0,
+ "sender": 1
+}
+```
+
+## 2. Sync Config
+
+**Received**
+Topic: fedml_168_1
+
+```json
+{
+ "msg_type": 2,
+ "sender": 0,
+ "receiver": 1,
+ "model_params": "fedml_111_0_39d756ca2-1ce1-44bc-b232-59f0ae054f0e",
+ "client_idx": "0"
+}
+```
+
+**Send**
+Topic: fedml_168_1
+
+```json
+{
+ "client_idx": "0",
+ "model_params": "fedml_111_0_39d756ca2-1ce1-44bc-b232-59f0ae054f0e",
+ "num_samples": 5,
+ "msg_type": 3,
+ "receiver": 0,
+ "sender": 1
+}
+```
\ No newline at end of file
diff --git a/android/app/.gitignore b/android/app/.gitignore
new file mode 100644
index 0000000000..42afabfd2a
--- /dev/null
+++ b/android/app/.gitignore
@@ -0,0 +1 @@
+/build
\ No newline at end of file
diff --git a/android/app/assets/aria_config.xml b/android/app/assets/aria_config.xml
new file mode 100644
index 0000000000..a0f603c1fc
--- /dev/null
+++ b/android/app/assets/aria_config.xml
@@ -0,0 +1,167 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/android/app/build.gradle b/android/app/build.gradle
new file mode 100644
index 0000000000..9fabeb77ec
--- /dev/null
+++ b/android/app/build.gradle
@@ -0,0 +1,86 @@
+apply plugin: 'com.android.application'
+
+android {
+ signingConfigs {
+ release {
+ storeFile file('fedml.jks')
+ storePassword 'fedml0'
+ keyAlias 'fedml'
+ keyPassword 'fedml0'
+ }
+ }
+ compileSdkVersion 32
+ buildToolsVersion '32.0.0'
+ ndkVersion '23.1.7779620'
+
+ defaultConfig {
+ applicationId "ai.fedml"
+ minSdkVersion 21
+ targetSdkVersion 32
+ versionCode 1
+ versionName "1.0"
+
+ testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
+ multiDexEnabled true
+ signingConfig signingConfigs.release
+ }
+
+ buildTypes {
+ debug {
+ }
+ release {
+ minifyEnabled false
+ proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
+ }
+ }
+
+ compileOptions {
+ sourceCompatibility 1.8
+ targetCompatibility 1.8
+ }
+
+ sourceSets {
+ main {
+ jniLibs.srcDirs = ['libs']
+ }
+ }
+
+ dataBinding{
+ enabled = true
+ }
+
+}
+
+dependencies {
+ implementation project(':fedmlsdk')
+ testImplementation 'junit:junit:4.13.2'
+ androidTestImplementation 'androidx.test.ext:junit:1.1.3'
+ androidTestImplementation 'androidx.test.espresso:espresso-core:3.4.0'
+ implementation fileTree(dir: "libs", include: ["*.jar"])
+ implementation 'androidx.appcompat:appcompat:1.4.2'
+ implementation 'androidx.constraintlayout:constraintlayout:2.1.4'
+ implementation 'com.google.code.gson:gson:2.9.0'
+ implementation 'androidx.legacy:legacy-support-v4:1.0.0'
+ implementation 'com.google.android.material:material:1.6.1'
+ implementation 'androidx.cardview:cardview:1.0.0'
+ implementation('com.squareup.okhttp3:okhttp:5.0.0-alpha.7')
+
+ implementation "com.squareup.retrofit2:retrofit:2.9.0"
+ implementation "com.squareup.retrofit2:converter-gson:2.9.0"
+ implementation "com.squareup.okio:okio:3.1.0"
+ implementation 'com.squareup.retrofit2:converter-scalars:2.9.0'
+ implementation 'com.squareup.okhttp3:logging-interceptor:5.0.0-alpha.7'
+
+ // zxing
+ implementation 'com.github.yuzhiqiang1993:zxing:2.2.5'
+ // aria file download
+ implementation 'me.laoyuyu.aria:core:3.8.16'
+ annotationProcessor 'me.laoyuyu.aria:compiler:3.8.16'
+
+ annotationProcessor 'org.projectlombok:lombok:1.18.24'
+ compileOnly 'org.projectlombok:lombok:1.18.24'
+
+ implementation 'com.amazonaws:aws-android-sdk-s3:2.45.0'
+ implementation 'com.github.bumptech.glide:glide:4.13.2'
+ annotationProcessor 'com.github.bumptech.glide:compiler:4.13.2'
+}
\ No newline at end of file
diff --git a/android/app/fedml.jks b/android/app/fedml.jks
new file mode 100644
index 0000000000..a8ca85315a
Binary files /dev/null and b/android/app/fedml.jks differ
diff --git a/android/app/proguard-rules.pro b/android/app/proguard-rules.pro
new file mode 100644
index 0000000000..be1ce06284
--- /dev/null
+++ b/android/app/proguard-rules.pro
@@ -0,0 +1,109 @@
+# Add project specific ProGuard rules here.
+# You can control the set of applied configuration files using the
+# proguardFiles setting in build.gradle.
+#
+# For more details, see
+# http://developer.android.com/guide/developing/tools/proguard.html
+
+# If your project uses WebView with JS, uncomment the following
+# and specify the fully qualified class name to the JavaScript interface
+# class:
+#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
+# public *;
+#}
+
+# Uncomment this to preserve the line number information for
+# debugging stack traces.
+#-keepattributes SourceFile,LineNumberTable
+
+# If you keep the line number information, uncomment this to
+# hide the original source file name.
+#-renamesourcefileattribute SourceFile
+-ignorewarnings
+-keepattributes Exceptions
+-keepattributes InnerClasses
+-keepattributes SourceFile,LineNumberTable
+-keepclasseswithmembernames class * {
+ native ;
+}
+-keep public class * extends android.app.Activity
+-keep public class * extends android.app.Application
+-keep public class * extends android.app.Service
+-keep public class * extends android.content.BroadcastReceiver
+-keep public class * extends android.content.ContentProvider
+-keep public class * extends android.app.backup.BackupAgent
+-keep public class * extends android.preference.Preference
+-keep public class * extends android.app.Fragment
+-keep class androidx.** {*;}
+-keep public class * extends androidx.**
+-keep interface androidx.** {*;}
+-dontwarn androidx.**
+-keepclassmembers class * extends android.app.Activity{
+ public void *(android.view.View);
+}
+-keep public class * extends android.view.View{
+ *** get*();
+ void set*(***);
+ public (android.content.Context);
+ public (android.content.Context,android.util.AttributeSet);
+ public (android.content.Context,android.util.AttributeSet,int);
+}
+-keepclasseswithmembers class * {
+ public (android.content.Context, android.util.AttributeSet);
+ public (android.content.Context, android.util.AttributeSet, int);
+}
+-dontwarn android.annotation
+-keepattributes *Annotation*
+# ==================okhttp start===================
+-dontwarn okhttp3.**
+-dontwarn okio.**
+-dontwarn javax.annotation.**
+-dontwarn org.conscrypt.**
+# A resource is loaded with a relative path so the package of this class must be preserved.
+-keepnames class okhttp3.internal.publicsuffix.PublicSuffixDatabase
+# Animal Sniffer compileOnly dependency to ensure APIs are compatible with older versions of Java.
+-dontwarn org.codehaus.mojo.animal_sniffer.*
+# OkHttp platform used only on JVM and when Conscrypt dependency is available.
+-dontwarn okhttp3.internal.platform.ConscryptPlatform
+# ==================okhttp end=====================
+
+# ==================retrofit2 start===================
+# Retain generic type information for use by reflection by converters and adapters.
+-keepattributes Signature
+# Retain service method parameters.
+-keepclassmembernames,allowobfuscation interface * {
+ @retrofit2.http.* ;
+}
+# Ignore annotation used for build tooling.
+-dontwarn org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement
+
+# ==================retrofit2 end=====================
+
+# ==================gson start=====================
+-dontwarn com.google.gson.**
+-keep class com.google.gson.**{*;}
+-keep interface com.google.gson.**{*;}
+-dontwarn sun.misc.**
+-keepclassmembers,allowobfuscation class * {
+ @com.google.gson.annotations.SerializedName ;
+}
+# keep gson entity
+-keep class ai.fedml.edge.service.communicator.message.**{*;}
+-keep class ai.fedml.edge.request.parameter.**{*;}
+-keep class ai.fedml.edge.request.response.**{*;}
+# ==================gson end=====================
+-keep public class * implements com.bumptech.glide.module.GlideModule
+-keep class * extends com.bumptech.glide.module.AppGlideModule {
+ (...);
+}
+-keep public enum com.bumptech.glide.load.ImageHeaderParser$** {
+ **[] $VALUES;
+ public *;
+}
+-keep class com.bumptech.glide.load.data.ParcelFileDescriptorRewinder$InternalRewinder {
+ *** rewind();
+}
+
+# for DexGuard only
+-keepresourcexmlelements manifest/application/meta-data@value=GlideModule
+# ==================Glide end=====================
\ No newline at end of file
diff --git a/android/app/src/androidTest/java/ai/fedml/edge/ExampleInstrumentedTest.java b/android/app/src/androidTest/java/ai/fedml/edge/ExampleInstrumentedTest.java
new file mode 100644
index 0000000000..0aad4e9bca
--- /dev/null
+++ b/android/app/src/androidTest/java/ai/fedml/edge/ExampleInstrumentedTest.java
@@ -0,0 +1,26 @@
+package ai.fedml.edge;
+
+import android.content.Context;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.platform.app.InstrumentationRegistry;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Instrumented test, which will execute on an Android device.
+ *
+ * @see Testing documentation
+ */
+@RunWith(AndroidJUnit4.class)
+public class ExampleInstrumentedTest {
+ @Test
+ public void useAppContext() {
+ // Context of the app under test.
+ Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext();
+ assertEquals("ai.fedml.ai.fedml.edge.test", appContext.getPackageName());
+ }
+}
\ No newline at end of file
diff --git a/android/app/src/main/AndroidManifest.xml b/android/app/src/main/AndroidManifest.xml
new file mode 100644
index 0000000000..4704ef5b7f
--- /dev/null
+++ b/android/app/src/main/AndroidManifest.xml
@@ -0,0 +1,65 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/android/app/src/main/java/ai/fedml/App.java b/android/app/src/main/java/ai/fedml/App.java
new file mode 100644
index 0000000000..6cfa94865d
--- /dev/null
+++ b/android/app/src/main/java/ai/fedml/App.java
@@ -0,0 +1,29 @@
+package ai.fedml;
+
+import android.app.Application;
+
+import ai.fedml.edge.FedEdgeManager;
+
+
+public class App extends Application {
+
+
+ private static App app;
+
+ @Override
+ public void onCreate() {
+ super.onCreate();
+ app = this;
+ FedEdgeManager.getFedEdgeApi().init(this);
+ }
+
+
+ public static App getApp() {
+ return app;
+ }
+
+ @Override
+ public void onTerminate() {
+ super.onTerminate();
+ }
+}
diff --git a/android/app/src/main/java/ai/fedml/CustomGlideModule.java b/android/app/src/main/java/ai/fedml/CustomGlideModule.java
new file mode 100644
index 0000000000..ed00a0839c
--- /dev/null
+++ b/android/app/src/main/java/ai/fedml/CustomGlideModule.java
@@ -0,0 +1,9 @@
+package ai.fedml;
+
+import com.bumptech.glide.annotation.GlideModule;
+import com.bumptech.glide.module.AppGlideModule;
+
+@GlideModule
+public final class CustomGlideModule extends AppGlideModule {
+
+}
\ No newline at end of file
diff --git a/android/app/src/main/java/ai/fedml/base/AppManager.java b/android/app/src/main/java/ai/fedml/base/AppManager.java
new file mode 100644
index 0000000000..1e8a4575e3
--- /dev/null
+++ b/android/app/src/main/java/ai/fedml/base/AppManager.java
@@ -0,0 +1,132 @@
+package ai.fedml.base;
+
+import android.app.Activity;
+import android.content.Context;
+
+import java.util.Stack;
+
+public class AppManager {
+
+ private static Stack activityStack;
+ private static AppManager instance;
+
+ private AppManager() {
+ }
+
+ /**
+ * Single instance, UI does not need to consider multi-threaded synchronization issues
+ */
+ public static AppManager getAppManager() {
+ if (instance == null) {
+ instance = new AppManager();
+ }
+ return instance;
+ }
+
+ /**
+ * Add Activity to the stack
+ *
+ * @param activity
+ */
+ public void addActivity(Activity activity) {
+ if (activityStack == null) {
+ activityStack = new Stack();
+ }
+ activityStack.add(activity);
+
+ }
+
+ /**
+ * Get the current Activity (the Activity at the top of the stack)
+ */
+ public Activity currentActivity() {
+ if (activityStack == null || activityStack.isEmpty()) {
+ return null;
+ }
+ Activity activity = activityStack.lastElement();
+ return activity;
+ }
+
+ /**
+ * Get the current Activity (the Activity at the top of the stack) return null when not found
+ */
+ public Activity findActivity(Class> cls) {
+ Activity activity = null;
+ for (Activity aty : activityStack) {
+ if (aty.getClass().equals(cls)) {
+ activity = aty;
+ break;
+ }
+ }
+ return activity;
+ }
+
+ /**
+ * End the current Activity (the Activity at the top of the stack)
+ */
+ public void finishActivity() {
+ if (activityStack.size() <= 0) {
+ return;
+ }
+ Activity activity = activityStack.lastElement();
+ finishActivity(activity);
+ }
+
+ /**
+ * End the specified Activity (overload)
+ */
+ public void finishActivity(Activity activity) {
+ if (activity != null) {
+ activityStack.remove(activity);
+ activity.finish();
+ activity = null;
+ }
+ }
+
+ /**
+ * Remove the specified Activity and call the finish method elsewhere
+ */
+ public void removeActivity(Activity activity) {
+ if (activity != null) {
+ activityStack.remove(activity);
+ activity = null;
+ }
+ }
+
+ /**
+ * End the specified Activity (overload)
+ */
+ public void finishActivity(Class> cls) {
+ for (Activity activity : activityStack) {
+ if (activity.getClass().equals(cls)) {
+ finishActivity(activity);
+ }
+ }
+ }
+
+ /**
+ * End all the Activities
+ */
+ public void finishAllActivity() {
+ for (int i = 0, size = activityStack.size(); i < size; i++) {
+ if (null != activityStack.get(i)) {
+ activityStack.get(i).finish();
+ }
+ }
+ activityStack.clear();
+ }
+
+ /**
+ * The application exits, call this method!
+ */
+ public void AppExit(Context context) {
+ try {
+ finishAllActivity();
+ System.exit(0);
+
+ } catch (Exception e) {
+ System.exit(0);
+ }
+ }
+
+}
diff --git a/android/app/src/main/java/ai/fedml/base/BaseActivity.java b/android/app/src/main/java/ai/fedml/base/BaseActivity.java
new file mode 100644
index 0000000000..4ec11ed23b
--- /dev/null
+++ b/android/app/src/main/java/ai/fedml/base/BaseActivity.java
@@ -0,0 +1,35 @@
+package ai.fedml.base;
+
+import android.app.Activity;
+import android.os.Bundle;
+
+import ai.fedml.utils.StatusBarUtil;
+import androidx.annotation.Nullable;
+
+public class BaseActivity extends Activity {
+
+ @Override
+ protected void onCreate(@Nullable Bundle savedInstanceState) {
+ super.onCreate(savedInstanceState);
+ AppManager.getAppManager().addActivity(this);
+ setStatusBar();
+
+ }
+
+ protected void setStatusBar() {
+ //Two things are done here, so that the immersive status bar of the second case mentioned at the beginning can be realized.
+ // 1.Make the status bar transparent and fill the contentView to the status bar
+ // 2.Reserve the position of the status bar to prevent the controls on the interface from being too close to the top.
+ StatusBarUtil.setTransparent(this);
+ }
+
+
+
+
+
+ @Override
+ protected void onDestroy() {
+ super.onDestroy();
+ AppManager.getAppManager().removeActivity(this);
+ }
+}
diff --git a/android/app/src/main/java/ai/fedml/client/RetrofitManager.java b/android/app/src/main/java/ai/fedml/client/RetrofitManager.java
new file mode 100644
index 0000000000..b4e0a244ae
--- /dev/null
+++ b/android/app/src/main/java/ai/fedml/client/RetrofitManager.java
@@ -0,0 +1,52 @@
+package ai.fedml.client;
+
+import java.util.concurrent.TimeUnit;
+
+import ai.fedml.edge.BuildConfig;
+import ai.fedml.edge.request.UserManagerService;
+import okhttp3.OkHttpClient;
+import okhttp3.logging.HttpLoggingInterceptor;
+import retrofit2.Retrofit;
+import retrofit2.converter.gson.GsonConverterFactory;
+
+public final class RetrofitManager {
+ private static final String BASE_API_SERVER_URL = BuildConfig.MLOPS_SVR;
+ private static Retrofit retrofit;
+ private static UserManagerService userManagerService;
+ private static VersionUpdater versionUpdater;
+
+ private static Retrofit retrofit() {
+ if (retrofit == null) {
+ HttpLoggingInterceptor loggingInterceptor = new HttpLoggingInterceptor()
+ .setLevel(HttpLoggingInterceptor.Level.BASIC);
+
+ OkHttpClient okHttpClient = new OkHttpClient.Builder()
+ .writeTimeout(30_1000, TimeUnit.MILLISECONDS)
+ .readTimeout(20_1000, TimeUnit.MILLISECONDS)
+ .connectTimeout(15_1000, TimeUnit.MILLISECONDS)
+ .addInterceptor(loggingInterceptor)
+ .build();
+
+ retrofit = new Retrofit.Builder()
+ .baseUrl(BASE_API_SERVER_URL)
+ .addConverterFactory(GsonConverterFactory.create())
+ .client(okHttpClient)
+ .build();
+ }
+ return retrofit;
+ }
+
+ public static UserManagerService getOpsUserManager() {
+ if (userManagerService == null) {
+ userManagerService = retrofit().create(UserManagerService.class);
+ }
+ return userManagerService;
+ }
+
+ public static VersionUpdater getVersionUpdater() {
+ if (versionUpdater == null) {
+ versionUpdater = retrofit().create(VersionUpdater.class);
+ }
+ return versionUpdater;
+ }
+}
diff --git a/android/app/src/main/java/ai/fedml/client/VersionUpdater.java b/android/app/src/main/java/ai/fedml/client/VersionUpdater.java
new file mode 100644
index 0000000000..1aeb733760
--- /dev/null
+++ b/android/app/src/main/java/ai/fedml/client/VersionUpdater.java
@@ -0,0 +1,10 @@
+package ai.fedml.client;
+
+import ai.fedml.client.entity.VersionUpdateResponse;
+import retrofit2.Call;
+import retrofit2.http.GET;
+
+public interface VersionUpdater {
+ @GET("/fedmlOpsServer/apk/latestVersion")
+ Call getLatestVersion();
+}
diff --git a/android/app/src/main/java/ai/fedml/client/entity/VersionUpdateResponse.java b/android/app/src/main/java/ai/fedml/client/entity/VersionUpdateResponse.java
new file mode 100644
index 0000000000..bc5f898ea0
--- /dev/null
+++ b/android/app/src/main/java/ai/fedml/client/entity/VersionUpdateResponse.java
@@ -0,0 +1,33 @@
+package ai.fedml.client.entity;
+
+import com.google.gson.annotations.SerializedName;
+
+import java.util.List;
+
+import ai.fedml.edge.request.response.BaseResponse;
+import lombok.Data;
+import lombok.EqualsAndHashCode;
+import lombok.ToString;
+
+
+@Data
+@EqualsAndHashCode(callSuper = true)
+@ToString(callSuper = true)
+public class VersionUpdateResponse extends BaseResponse {
+ @SerializedName("data")
+ private List versionList;
+
+ @Data
+ public static class VersionUpdateInfo {
+ @SerializedName("createtime")
+ private String createTime;
+ @SerializedName("code")
+ private Integer code;
+ @SerializedName("name")
+ private String name;
+ @SerializedName("id")
+ private Integer id;
+ @SerializedName("url")
+ private String url;
+ }
+}
diff --git a/android/app/src/main/java/ai/fedml/ui/GuideActivity.java b/android/app/src/main/java/ai/fedml/ui/GuideActivity.java
new file mode 100644
index 0000000000..400c47a04e
--- /dev/null
+++ b/android/app/src/main/java/ai/fedml/ui/GuideActivity.java
@@ -0,0 +1,88 @@
+package ai.fedml.ui;
+
+import android.Manifest;
+import android.content.Intent;
+import android.content.pm.PackageManager;
+import android.os.Build;
+import android.os.Bundle;
+import android.os.Handler;
+import android.os.Looper;
+import android.text.TextUtils;
+import android.util.Log;
+import android.view.View;
+
+import java.util.Arrays;
+
+import androidx.annotation.Nullable;
+
+import java.io.File;
+
+import ai.fedml.R;
+import ai.fedml.base.AppManager;
+import ai.fedml.base.BaseActivity;
+import ai.fedml.edge.FedEdgeManager;
+import ai.fedml.edge.utils.LogHelper;
+import ai.fedml.edge.utils.StorageUtils;
+
+/**
+ * Guideline pages
+ */
+public class GuideActivity extends BaseActivity {
+ private static final String TAG = "GuideActivity";
+ public static final String TRAIN_MODEL_FILE_PATH = StorageUtils.getSdCardPath() + "/ai.fedml/lenet_mnist.mnn";
+ public static final String TRAIN_DATA_FILE_PATH = StorageUtils.getSdCardPath() + "/ai.fedml/mnist";
+ private final Handler mHandler = new Handler(Looper.getMainLooper());
+
+ @Override
+ protected void onCreate(@Nullable Bundle savedInstanceState) {
+ super.onCreate(savedInstanceState);
+ setContentView(R.layout.activity_guide);
+ initView();
+ loadData();
+ }
+
+ private void initView() {
+ Log.d(TAG, "guide ModelPath:" + StorageUtils.getModelPath() + "Dataset:" + StorageUtils.getDatasetPath());
+ View guideView = findViewById(R.id.iv_guide);
+ guideView.setOnClickListener(view -> {
+ Log.d(TAG, "OnClick guide ModelPath:" + StorageUtils.getModelPath() +
+ ",Dataset:" + StorageUtils.getDatasetPath());
+ Log.d(TAG, "TRAIN_MODEL_FILE_PATH is " + new File(TRAIN_MODEL_FILE_PATH).exists());
+ Log.d(TAG, "TRAIN_DATA_FILE_PATH is " + new File(TRAIN_DATA_FILE_PATH).isDirectory());
+ });
+ }
+
+ private void loadData() {
+ mHandler.postDelayed(() -> {
+ final String bindingId = FedEdgeManager.getFedEdgeApi().getBoundEdgeId();
+ LogHelper.d("BindingId:%s", bindingId);
+ if (TextUtils.isEmpty(bindingId)) {
+ Intent intent = new Intent();
+ intent.setClass(GuideActivity.this, ScanCodeActivity.class);
+ startActivity(intent);
+ } else {
+ Intent intent = new Intent();
+ intent.setClass(GuideActivity.this, HomeActivity.class);
+ startActivity(intent);
+ }
+ AppManager.getAppManager().finishActivity();
+ }, 200);
+ }
+
+ /**
+ * Get permission
+ */
+ private void getPermission() {
+ if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
+ int REQUEST_CODE_CONTACT = 101;
+ String[] permissions = {Manifest.permission.WRITE_EXTERNAL_STORAGE, Manifest.permission.READ_EXTERNAL_STORAGE};
+ //Verify permission
+ for (String str : permissions) {
+ if (this.checkSelfPermission(str) != PackageManager.PERMISSION_GRANTED) {
+ //Request permission
+ this.requestPermissions(permissions, REQUEST_CODE_CONTACT);
+ }
+ }
+ }
+ }
+}
diff --git a/android/app/src/main/java/ai/fedml/ui/HomeActivity.java b/android/app/src/main/java/ai/fedml/ui/HomeActivity.java
new file mode 100644
index 0000000000..d977f296f5
--- /dev/null
+++ b/android/app/src/main/java/ai/fedml/ui/HomeActivity.java
@@ -0,0 +1,224 @@
+package ai.fedml.ui;
+
+import android.Manifest;
+import android.annotation.SuppressLint;
+import android.content.Intent;
+import android.content.pm.PackageManager;
+import android.net.Uri;
+import android.os.Build;
+import android.os.Bundle;
+import android.os.Environment;
+import android.provider.Settings;
+import android.text.TextUtils;
+import android.view.View;
+import android.widget.Button;
+import android.widget.ImageView;
+import android.widget.TextView;
+
+import ai.fedml.GlideApp;
+import ai.fedml.R;
+import ai.fedml.base.BaseActivity;
+import ai.fedml.edge.FedEdgeManager;
+import ai.fedml.edge.OnTrainProgressListener;
+import ai.fedml.edge.request.RequestManager;
+import ai.fedml.edge.service.communicator.message.MessageDefine;
+import ai.fedml.edge.utils.LogHelper;
+import ai.fedml.utils.ToastUtils;
+import ai.fedml.widget.CompletedProgressView;
+import androidx.annotation.NonNull;
+import androidx.annotation.Nullable;
+import androidx.core.app.ActivityCompat;
+import androidx.core.content.ContextCompat;
+
+/**
+ * HomeActivity
+ */
+public class HomeActivity extends BaseActivity implements View.OnClickListener {
+
+ private static final String TAG = "HomeActivity";
+ private Button btn_set_path;
+ private static final int REQUEST_CODE = 1024;
+ private TextView mStatusTextView;
+ private TextView mAccLossTextView;
+ private CompletedProgressView mProgressView;
+ private TextView mHyperTextView;
+ private TextView mNameTextView;
+ private TextView mEmailTextView;
+ private TextView mGroupTextView;
+ private ImageView mAvatarImageView;
+ private TextView mDeviceAccountInfoTextView;
+
+ @Override
+ protected void onCreate(@Nullable Bundle savedInstanceState) {
+ super.onCreate(savedInstanceState);
+ setContentView(R.layout.activity_home);
+ initView();
+ loadDate();
+
+ }
+
+ @Override
+ protected void onResume() {
+ super.onResume();
+ String path = FedEdgeManager.getFedEdgeApi().getPrivatePath();
+ if (!TextUtils.isEmpty(path)) {
+ btn_set_path.setText(path);
+ }
+ }
+
+ private void initView() {
+ btn_set_path = findViewById(R.id.btn_set_path);
+ Button btn_unbind = findViewById(R.id.btn_unbind);
+
+ btn_set_path.setOnClickListener(this);
+ btn_unbind.setOnClickListener(this);
+
+ mStatusTextView = findViewById(R.id.tv_status);
+ mAccLossTextView = findViewById(R.id.tv_acc_loss);
+ mProgressView = findViewById(R.id.progress_view);
+ mHyperTextView = findViewById(R.id.tv_hyper_parameter);
+ mDeviceAccountInfoTextView = findViewById(R.id.tv_account_info);
+ mNameTextView = findViewById(R.id.tv_name);
+ mEmailTextView = findViewById(R.id.tv_email);
+ mGroupTextView = findViewById(R.id.tv_group);
+ mAvatarImageView = findViewById(R.id.iv_avatar);
+ }
+
+ private void loadDate() {
+ requestPermission();
+ getUserInfo();
+// VersionUpdate();
+ mDeviceAccountInfoTextView.setText(getString(R.string.account_information, FedEdgeManager.getFedEdgeApi().getBoundEdgeId()));
+ mProgressView.setProgress(0);
+ FedEdgeManager.getFedEdgeApi().setEpochLossListener(new OnTrainProgressListener() {
+ private int mRound = 0;
+ private int mEpoch = 0;
+ private float mLoss = 0f;
+ private float mAccuracy = 0f;
+
+ @Override
+ public void onEpochLoss(int round, int epoch, float loss) {
+ mRound = round;
+ mEpoch = epoch;
+ mLoss = loss;
+ runOnUiThread(() ->
+ mAccLossTextView.setText(getString(R.string.acc_loss_txt, mRound, mEpoch, mAccuracy, mLoss)));
+ }
+
+ @Override
+ public void onEpochAccuracy(int round, int epoch, float accuracy) {
+ mRound = round;
+ mEpoch = epoch;
+ mAccuracy = accuracy;
+ runOnUiThread(() ->
+ mAccLossTextView.setText(getString(R.string.acc_loss_txt, mRound, mEpoch, mAccuracy, mLoss)));
+ }
+
+ @Override
+ public void onProgressChanged(int round, float progress) {
+ runOnUiThread(() ->
+ mProgressView.setProgress(Math.round(progress)));
+ }
+ });
+ FedEdgeManager.getFedEdgeApi().setTrainingStatusListener((status) ->
+ runOnUiThread(() -> {
+ if (status == MessageDefine.KEY_CLIENT_STATUS_INITIALIZING) {
+ mHyperTextView.setText(FedEdgeManager.getFedEdgeApi().getHyperParameters());
+ }
+ mStatusTextView.setText(MessageDefine.CLIENT_STATUS_MAP.get(status));
+ }));
+ }
+
+
+ @SuppressLint("NonConstantResourceId")
+ @Override
+ public void onClick(View view) {
+ switch (view.getId()) {
+ case R.id.btn_set_path:
+ Intent intent = new Intent();
+ intent.setClass(HomeActivity.this, SetFilePathActivity.class);
+ startActivity(intent);
+ break;
+ case R.id.btn_unbind:
+ unbound();
+ break;
+ }
+ }
+
+ private void requestPermission() {
+ if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.R) {
+ // First determine whether you have permission
+ if (!Environment.isExternalStorageManager()) {
+ Intent intent = new Intent(Settings.ACTION_MANAGE_APP_ALL_FILES_ACCESS_PERMISSION);
+ intent.setData(Uri.parse("package:" + this.getPackageName()));
+ startActivityForResult(intent, REQUEST_CODE);
+ }
+ } else if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
+ // First determine whether you have permission
+ if (ActivityCompat.checkSelfPermission(this, Manifest.permission.READ_EXTERNAL_STORAGE) != PackageManager.PERMISSION_GRANTED ||
+ ContextCompat.checkSelfPermission(this, Manifest.permission.WRITE_EXTERNAL_STORAGE) != PackageManager.PERMISSION_GRANTED) {
+ ActivityCompat.requestPermissions(this, new String[]{Manifest.permission.READ_EXTERNAL_STORAGE, Manifest.permission.WRITE_EXTERNAL_STORAGE}, REQUEST_CODE);
+ }
+ }
+ }
+
+ @Override
+ public void onRequestPermissionsResult(int requestCode, @NonNull String[] permissions, @NonNull int[] grantResults) {
+ super.onRequestPermissionsResult(requestCode, permissions, grantResults);
+ if (requestCode == REQUEST_CODE) {
+ if (ActivityCompat.checkSelfPermission(this, Manifest.permission.READ_EXTERNAL_STORAGE) == PackageManager.PERMISSION_GRANTED &&
+ ContextCompat.checkSelfPermission(this, Manifest.permission.WRITE_EXTERNAL_STORAGE) == PackageManager.PERMISSION_GRANTED) {
+ } else {
+ ToastUtils.show("EXTERNAL STORAGE Permissions failed");
+ }
+ }
+ }
+
+ @Override
+ protected void onActivityResult(int requestCode, int resultCode, @Nullable Intent data) {
+ super.onActivityResult(requestCode, resultCode, data);
+ if (requestCode == REQUEST_CODE && Build.VERSION.SDK_INT >= Build.VERSION_CODES.R) {
+ if (Environment.isExternalStorageManager()) {
+ } else {
+ ToastUtils.show("Failed to obtain storage permission");
+ }
+ }
+ }
+
+ private void getUserInfo() {
+ RequestManager.getUserInfo(userInfo -> {
+ if (userInfo != null) {
+ runOnUiThread(() -> {
+ mNameTextView.setText(String.format("%s %s", userInfo.getLastname(), userInfo.getFirstName()));
+ mEmailTextView.setText(userInfo.getEmail());
+ mGroupTextView.setText(userInfo.getCompany());
+ GlideApp.with(HomeActivity.this)
+ .load(userInfo.getAvatar())
+ .circleCrop()
+ .placeholder(R.mipmap.ic_shijiali)
+ .into(mAvatarImageView);
+ });
+ }
+ });
+ }
+
+
+ private void unbound() {
+ String bindingId = FedEdgeManager.getFedEdgeApi().getBoundEdgeId();
+ LogHelper.d("unbound bindingId:%s", bindingId);
+ RequestManager.unboundAccount(bindingId, isSuccess -> runOnUiThread(() -> {
+ if (isSuccess) {
+ // Jump to scanning page
+ Intent intent = new Intent();
+ intent.setClass(HomeActivity.this, ScanCodeActivity.class);
+ startActivity(intent);
+ finish();
+ }
+ }));
+ }
+
+ @Override
+ protected void onDestroy() {
+ super.onDestroy();
+ }
+}
diff --git a/android/app/src/main/java/ai/fedml/ui/ScanCodeActivity.java b/android/app/src/main/java/ai/fedml/ui/ScanCodeActivity.java
new file mode 100644
index 0000000000..3138dcfa51
--- /dev/null
+++ b/android/app/src/main/java/ai/fedml/ui/ScanCodeActivity.java
@@ -0,0 +1 @@
+package ai.fedml.ui;
import android.Manifest;
import android.content.Intent;
import android.content.pm.PackageManager;
import android.os.Bundle;
import android.os.Handler;
import android.os.Looper;
import android.text.TextUtils;
import android.view.View;
import android.widget.Button;
import android.widget.EditText;
import android.widget.ImageView;
import android.widget.LinearLayout;
import androidx.annotation.Nullable;
import androidx.core.app.ActivityCompat;
import androidx.core.content.ContextCompat;
import com.yzq.zxinglibrary.android.CaptureActivity;
import com.yzq.zxinglibrary.common.Constant;
import ai.fedml.R;
import ai.fedml.base.AppManager;
import ai.fedml.base.BaseActivity;
import ai.fedml.edge.FedEdgeManager;
import ai.fedml.edge.request.RequestManager;
import ai.fedml.edge.request.parameter.BindingAccountReq;
import ai.fedml.edge.utils.DeviceUtils;
import ai.fedml.edge.utils.LogHelper;
import ai.fedml.utils.ToastUtils;
public class ScanCodeActivity extends BaseActivity implements View.OnClickListener {
private static final int REQUEST_CODE_SCAN = 0x0000;//Scan ID
private EditText edt_account_id;
private ImageView img_privacy;
private boolean isSelect = false;
private final Handler mHandler = new Handler(Looper.getMainLooper());
@Override
protected void onCreate(@Nullable Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_scan_code);
initView();
loadData();
}
private void initView() {
LinearLayout scanCodeLayout = findViewById(R.id.line_scan_code);
edt_account_id = findViewById(R.id.edt_account_id);
img_privacy = findViewById(R.id.img_privacy);
Button okButton = findViewById(R.id.btn_ok);
scanCodeLayout.setOnClickListener(this);
img_privacy.setOnClickListener(this);
okButton.setOnClickListener(this);
}
private void loadData() {
// meta-data binding
Runnable runnable = new Runnable() {
@Override
public void run() {
final String bindingId = FedEdgeManager.getFedEdgeApi().getBoundEdgeId();
if (TextUtils.isEmpty(bindingId)) {
mHandler.postDelayed(this, 500);
} else {
Intent intent = new Intent();
intent.setClass(ScanCodeActivity.this, HomeActivity.class);
startActivity(intent);
AppManager.getAppManager().finishActivity();
}
}
};
mHandler.postDelayed(runnable, 500);
}
@Override
public void onClick(View view) {
switch (view.getId()) {
case R.id.line_scan_code:
//Dynamic permission application
if (ContextCompat.checkSelfPermission(ScanCodeActivity.this, Manifest.permission.CAMERA) != PackageManager.PERMISSION_GRANTED) {
ActivityCompat.requestPermissions(ScanCodeActivity.this, new String[]{Manifest.permission.CAMERA}, 1);
} else {
goScan();
}
break;
case R.id.img_privacy:
if (isSelect) {
img_privacy.setBackgroundResource(R.mipmap.ic_privacy_pressed);
isSelect = false;
} else {
img_privacy.setBackgroundResource(R.mipmap.ic_privacy_normal);
isSelect = true;
}
break;
case R.id.btn_ok:
String id = FedEdgeManager.getFedEdgeApi().getBoundEdgeId();
if (TextUtils.isEmpty(id)) {
postBinding();
} else {
Intent intent = new Intent();
intent.setClass(ScanCodeActivity.this, HomeActivity.class);
startActivity(intent);
AppManager.getAppManager().finishActivity();
}
break;
}
}
/**
* Jump to the scanning code interface to scan the code
*/
private void goScan() {
Intent intent = new Intent(ScanCodeActivity.this, CaptureActivity.class);
startActivityForResult(intent, REQUEST_CODE_SCAN);
}
@Override
protected void onActivityResult(int requestCode, int resultCode, Intent data) {
super.onActivityResult(requestCode, resultCode, data);
// Scan the QR code/barcode and send it back
if (requestCode == REQUEST_CODE_SCAN && resultCode == RESULT_OK) {
if (data != null) {
String content = data.getStringExtra(Constant.CODED_CONTENT);
edt_account_id.setText(content);
}
}
}
public void postBinding() {
String accountId = edt_account_id.getText().toString().trim();
if (TextUtils.isEmpty(accountId)) {
ToastUtils.show(R.string.account_input_tip);
return;
}
BindingAccountReq req = BindingAccountReq.builder()
.accountId(accountId).deviceId(DeviceUtils.getDeviceId()).build();
RequestManager.bindingAccount(req, data -> runOnUiThread(() -> {
LogHelper.i("bindingData.getBindingId() = %s", data);
if (data == null) {
ToastUtils.show(R.string.retry_tip);
return;
}
FedEdgeManager.getFedEdgeApi().bindEdge(data.getBindingId());
Intent intent = new Intent();
intent.setClass(ScanCodeActivity.this, HomeActivity.class);
startActivity(intent);
AppManager.getAppManager().finishActivity();
}));
}
}
\ No newline at end of file
diff --git a/android/app/src/main/java/ai/fedml/ui/SetFilePathActivity.java b/android/app/src/main/java/ai/fedml/ui/SetFilePathActivity.java
new file mode 100644
index 0000000000..e658bca9e5
--- /dev/null
+++ b/android/app/src/main/java/ai/fedml/ui/SetFilePathActivity.java
@@ -0,0 +1,133 @@
+package ai.fedml.ui;
+
+import android.annotation.SuppressLint;
+import android.os.Bundle;
+import android.os.Environment;
+import android.view.View;
+import android.widget.Button;
+import android.widget.TextView;
+
+import androidx.annotation.Nullable;
+import androidx.recyclerview.widget.LinearLayoutManager;
+import androidx.recyclerview.widget.RecyclerView;
+
+import java.io.File;
+import java.util.ArrayList;
+
+import ai.fedml.R;
+import ai.fedml.base.AppManager;
+import ai.fedml.base.BaseActivity;
+import ai.fedml.edge.FedEdgeManager;
+import ai.fedml.ui.adapter.FileItem;
+import ai.fedml.ui.adapter.RvFilePathAdapter;
+import ai.fedml.utils.FileFilter;
+import ai.fedml.utils.FileOpenUtils;
+import ai.fedml.utils.FormatUtils;
+
+/**
+ * SetFilePath
+ */
+public class SetFilePathActivity extends BaseActivity implements View.OnClickListener, RvFilePathAdapter.OnItemClickListener {
+
+ private TextView tv_path;
+ private RecyclerView rv_file_path;
+ private RvFilePathAdapter rvFilePathAdapter;
+
+ private File[] files;// get everything in the directory
+ private File currentPath;
+
+ @Override
+ protected void onCreate(@Nullable Bundle savedInstanceState) {
+ super.onCreate(savedInstanceState);
+ setContentView(R.layout.activity_set_file_path);
+ initView();
+ loadData();
+
+ }
+
+ private void initView() {
+ Button btn_back = findViewById(R.id.btn_back);
+ Button btn_save_path = findViewById(R.id.btn_save_path);
+ tv_path = findViewById(R.id.tv_path);
+ rv_file_path = findViewById(R.id.rv_file_path);
+
+ btn_back.setOnClickListener(this);
+ btn_save_path.setOnClickListener(this);
+ }
+
+ private void loadData() {
+ ArrayList dataset = new ArrayList<>();
+ LinearLayoutManager layoutManager = new LinearLayoutManager(this);
+ rv_file_path.setLayoutManager(layoutManager);
+ rvFilePathAdapter = new RvFilePathAdapter(this, dataset);
+ rv_file_path.setAdapter(rvFilePathAdapter);
+ rvFilePathAdapter.setOnItemClickListener(this);
+
+ // Get the directory of the sd card
+ if (Environment.MEDIA_MOUNTED.equals(Environment.getExternalStorageState())) {
+ File sd = Environment.getExternalStorageDirectory();// Get the directory of the sd card
+ // get the contents of the directory
+ showDir(sd);
+ }
+ }
+
+ @SuppressLint("NonConstantResourceId")
+ @Override
+ public void onClick(View view) {
+ switch (view.getId()) {
+ case R.id.btn_save_path:
+ FedEdgeManager.getFedEdgeApi().setPrivatePath(currentPath.toString());
+ AppManager.getAppManager().finishActivity();
+ break;
+ case R.id.btn_back:
+ // Load parent directory ParentFile: parent directory
+ File path = currentPath.getParentFile();
+ if (path == null || path.toString().equals("/storage/emulated")) {
+ AppManager.getAppManager().finishActivity();
+ } else {
+ showDir(path);
+ }
+ break;
+ }
+ }
+
+
+ /**
+ * Load all folders and files and update the interface
+ */
+ @SuppressLint("NotifyDataSetChanged")
+ private void showDir(File dir) {
+ // save current location
+ currentPath = dir;
+ // Get the contents of the directory (listFiles: get all the contents), and filter the files and folders starting with "." by FileFilter() function
+ files = dir.listFiles(new FileFilter());
+ if (files == null) {
+ return;
+ }
+ rvFilePathAdapter.mSetData.clear();
+ for (File file : files) {
+ FileItem item = FileItem.builder().fileIcon(file.isFile() ? R.mipmap.ic_file : R.mipmap.ic_dir)
+ .fileName(file.getName())
+ .fileSize(FormatUtils.unitConversion(file.length()))
+ .fileLastModifiedTime(FormatUtils.longToString(file.lastModified()))
+ .build();
+ rvFilePathAdapter.mSetData.add(item);
+ }
+ tv_path.setText(currentPath.toString());
+ rvFilePathAdapter.notifyDataSetChanged();
+
+ }
+
+
+ @Override
+ public void onItemClick(View view, int position) {
+ if (files[position].isFile()) {
+ // open the file
+ FileOpenUtils.openFile(this, files[position]);
+ } else {
+ // Open the contents of the file directory
+ // load new data
+ showDir(files[position]);
+ }
+ }
+}
diff --git a/android/app/src/main/java/ai/fedml/ui/adapter/FileItem.java b/android/app/src/main/java/ai/fedml/ui/adapter/FileItem.java
new file mode 100644
index 0000000000..ae00b77b7a
--- /dev/null
+++ b/android/app/src/main/java/ai/fedml/ui/adapter/FileItem.java
@@ -0,0 +1,13 @@
+package ai.fedml.ui.adapter;
+
+import lombok.Builder;
+import lombok.Data;
+
+@Data
+@Builder
+public class FileItem {
+ private int fileIcon;
+ private String fileName;
+ private String fileSize;
+ private String fileLastModifiedTime;
+}
diff --git a/android/app/src/main/java/ai/fedml/ui/adapter/RvFilePathAdapter.java b/android/app/src/main/java/ai/fedml/ui/adapter/RvFilePathAdapter.java
new file mode 100644
index 0000000000..dec577dd0a
--- /dev/null
+++ b/android/app/src/main/java/ai/fedml/ui/adapter/RvFilePathAdapter.java
@@ -0,0 +1,128 @@
+package ai.fedml.ui.adapter;
+
+import android.content.Context;
+import android.view.LayoutInflater;
+import android.view.View;
+import android.view.ViewGroup;
+import android.widget.ImageView;
+import android.widget.TextView;
+
+import java.util.List;
+
+import ai.fedml.R;
+import ai.fedml.utils.AppUtils;
+import ai.fedml.utils.FormatUtils;
+import androidx.annotation.NonNull;
+import androidx.recyclerview.widget.RecyclerView;
+
+/**
+ * file list adapter
+ */
+public class RvFilePathAdapter extends RecyclerView.Adapter {
+
+ //Type, use this to determine which layout the recyclerview should use to display
+ public final int TYPE_EMPTY = 0;
+ public final int TYPE_NORMAL = 1;
+
+ private final Context mContext;
+ public List mSetData;
+
+
+ public RvFilePathAdapter(Context mContext, List mSetData) {
+ this.mContext = mContext;
+ this.mSetData = mSetData;
+ }
+
+ @Override
+ public int getItemViewType(int position) {
+ if (mSetData == null || mSetData.size() <= 0) {
+ return TYPE_EMPTY;
+ }
+ return TYPE_NORMAL;
+ }
+
+ @NonNull
+ @Override
+ public RecyclerView.ViewHolder onCreateViewHolder(@NonNull ViewGroup parent, int viewType) {
+ View view;
+ //If it is an empty layout type, it will directly return null
+ if (viewType == TYPE_EMPTY) {
+ view = LayoutInflater.from(mContext).inflate(R.layout.rv_item_empty, parent, false);
+ return new EmptyViewHolder(view);
+ } else {
+ view = LayoutInflater.from(mContext).inflate(R.layout.rv_item_file, parent, false);
+ return new BodyViewHolder(view);
+ }
+ }
+
+ @Override
+ public void onBindViewHolder(@NonNull RecyclerView.ViewHolder holder, int position) {
+ //First determine whether the holder is a custom holder
+ if (holder instanceof BodyViewHolder) {
+ holder.itemView.setOnClickListener((View v) -> {
+ int pos = holder.getLayoutPosition();
+ onItemClickListener.onItemClick(holder.itemView, pos);
+ });
+ if (mSetData == null) {
+ return;
+ }
+ FileItem item = mSetData.get(position);
+ if (item == null) {
+ return;
+ }
+ ((BodyViewHolder) holder).img_icon.setImageResource(item.getFileIcon());
+ ((BodyViewHolder) holder).tv_name.setText(item.getFileName());
+ ((BodyViewHolder) holder).tv_time.setText(item.getFileLastModifiedTime());
+ if (item.getFileIcon() == R.mipmap.ic_dir) {
+ ((BodyViewHolder) holder).tv_size.setVisibility(View.GONE);
+ } else {
+ ((BodyViewHolder) holder).tv_size.setVisibility(View.VISIBLE);
+ ((BodyViewHolder) holder).tv_size.setText(item.getFileSize());
+ }
+ }
+ }
+
+ @Override
+ public int getItemCount() {
+ if (mSetData == null || mSetData.size() <= 0) {
+ return 1;
+ }
+ return mSetData.size();
+ }
+
+ public static class BodyViewHolder extends RecyclerView.ViewHolder {
+
+ ImageView img_icon;
+ TextView tv_name;
+ TextView tv_size;
+ TextView tv_time;
+
+ public BodyViewHolder(@NonNull View itemView) {
+ super(itemView);
+ img_icon = itemView.findViewById(R.id.img_icon);
+ tv_name = itemView.findViewById(R.id.tv_name);
+ tv_size = itemView.findViewById(R.id.tv_size);
+ tv_time = itemView.findViewById(R.id.tv_time);
+
+ }
+ }
+
+ /**
+ * empty layout
+ */
+ public static class EmptyViewHolder extends RecyclerView.ViewHolder {
+ public EmptyViewHolder(@NonNull View itemView) {
+ super(itemView);
+ }
+ }
+
+ public interface OnItemClickListener {
+ void onItemClick(View view, int position);
+ }
+
+ private OnItemClickListener onItemClickListener;
+
+ public void setOnItemClickListener(OnItemClickListener onItemClickListener) {
+ this.onItemClickListener = onItemClickListener;
+ }
+}
diff --git a/android/app/src/main/java/ai/fedml/utils/AppUtils.java b/android/app/src/main/java/ai/fedml/utils/AppUtils.java
new file mode 100644
index 0000000000..72187c3eb4
--- /dev/null
+++ b/android/app/src/main/java/ai/fedml/utils/AppUtils.java
@@ -0,0 +1,33 @@
+package ai.fedml.utils;
+
+import android.content.Context;
+import android.content.pm.PackageManager;
+import android.os.Looper;
+import android.view.View;
+import android.widget.Toast;
+
+import java.math.BigDecimal;
+
+
+public class AppUtils {
+
+
+
+
+
+ /**
+ * Get local version number
+ */
+ public static int getVersionCode(Context mContext) {
+ int versionCode = 0;
+ try {
+ //Get the software version number, corresponding to android:versionCode under AndroidManifest.xml
+ versionCode = mContext.getPackageManager().
+ getPackageInfo(mContext.getPackageName(), 0).versionCode;
+ } catch (PackageManager.NameNotFoundException e) {
+ e.printStackTrace();
+ }
+ return versionCode;
+ }
+
+}
diff --git a/android/app/src/main/java/ai/fedml/utils/FileFilter.java b/android/app/src/main/java/ai/fedml/utils/FileFilter.java
new file mode 100644
index 0000000000..71ec1d9c15
--- /dev/null
+++ b/android/app/src/main/java/ai/fedml/utils/FileFilter.java
@@ -0,0 +1,18 @@
+package ai.fedml.utils;
+
+import java.io.File;
+
+/**
+ * FileFilter
+ */
+public class FileFilter implements java.io.FileFilter {
+
+ @Override
+ public boolean accept(File pathname) {
+ // Filter files and folders starting with ".", get the file name, and the prefix is ".", if it is, return false, if not return true
+ if (pathname.getName().startsWith(".")) {
+ return false;
+ }
+ return true;
+ }
+}
diff --git a/android/app/src/main/java/ai/fedml/utils/FileOpenUtils.java b/android/app/src/main/java/ai/fedml/utils/FileOpenUtils.java
new file mode 100644
index 0000000000..dae1346aee
--- /dev/null
+++ b/android/app/src/main/java/ai/fedml/utils/FileOpenUtils.java
@@ -0,0 +1,164 @@
+package ai.fedml.utils;
+
+import android.content.Context;
+import android.content.Intent;
+import android.net.Uri;
+import android.os.Build;
+import android.os.StrictMode;
+
+import java.io.File;
+
+public class FileOpenUtils {
+ private static final String[][] MIME_MapTable={
+ //{suffix name, MIME type}
+ {".3gp", "video/3gpp"},
+ {".apk", "application/vnd.android.package-archive"},
+ {".asf", "video/x-ms-asf"},
+ {".avi", "video/x-msvideo"},
+ {".bin", "application/octet-stream"},
+ {".bmp", "image/bmp"},
+ {".c", "text/plain"},
+ {".class", "application/octet-stream"},
+ {".conf", "text/plain"},
+ {".cpp", "text/plain"},
+ {".doc", "application/msword"},
+ {".docx", "application/vnd.openxmlformats-officedocument.wordprocessingml.document"},
+ {".xls", "application/vnd.ms-excel"},
+ {".xlsx", "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"},
+ {".exe", "application/octet-stream"},
+ {".gif", "image/gif"},
+ {".gtar", "application/x-gtar"},
+ {".gz", "application/x-gzip"},
+ {".h", "text/plain"},
+ {".htm", "text/html"},
+ {".html", "text/html"},
+ {".jar", "application/java-archive"},
+ {".java", "text/plain"},
+ {".jpeg", "image/jpeg"},
+ {".jpg", "image/jpeg"},
+ {".js", "application/x-javascript"},
+ {".log", "text/plain"},
+ {".m3u", "audio/x-mpegurl"},
+ {".m4a", "audio/mp4a-latm"},
+ {".m4b", "audio/mp4a-latm"},
+ {".m4p", "audio/mp4a-latm"},
+ {".m4u", "video/vnd.mpegurl"},
+ {".m4v", "video/x-m4v"},
+ {".mov", "video/quicktime"},
+ {".mp2", "audio/x-mpeg"},
+ {".mp3", "audio/x-mpeg"},
+ {".mp4", "video/mp4"},
+ {".mpc", "application/vnd.mpohun.certificate"},
+ {".mpe", "video/mpeg"},
+ {".mpeg", "video/mpeg"},
+ {".mpg", "video/mpeg"},
+ {".mpg4", "video/mp4"},
+ {".mpga", "audio/mpeg"},
+ {".msg", "application/vnd.ms-outlook"},
+ {".ogg", "audio/ogg"},
+ {".pdf", "application/pdf"},
+ {".png", "image/png"},
+ {".pps", "application/vnd.ms-powerpoint"},
+ {".ppt", "application/vnd.ms-powerpoint"},
+ {".pptx", "application/vnd.openxmlformats-officedocument.presentationml.presentation"},
+ {".prop", "text/plain"},
+ {".rc", "text/plain"},
+ {".rmvb", "audio/x-pn-realaudio"},
+ {".rtf", "application/rtf"},
+ {".sh", "text/plain"},
+ {".tar", "application/x-tar"},
+ {".tgz", "application/x-compressed"},
+ {".txt", "text/plain"},
+ {".wav", "audio/x-wav"},
+ {".wma", "audio/x-ms-wma"},
+ {".wmv", "audio/x-ms-wmv"},
+ {".wps", "application/vnd.ms-works"},
+ {".xml", "text/plain"},
+ {".z", "application/x-compress"},
+ {".zip", "application/x-zip-compressed"},
+ {"", "*/*"}
+ };
+
+ /**
+ * Obtain the corresponding MIME type according to the file suffix.
+ * @param file
+ */
+ public static String getMIMEType(File file) {
+
+ String type="*/*";
+ String fName = file.getName();
+ //Get the position of the separator "." before the suffix name in fName.
+ int dotIndex = fName.lastIndexOf(".");
+ if(dotIndex < 0){
+ return type;
+ }
+ /* Get file extension */
+ String end=fName.substring(dotIndex,fName.length()).toLowerCase();
+ if(end=="")return type;
+ //Find the corresponding MIME type in the MIME and file type match table.
+ for(int i=0;i=Build.VERSION_CODES.N)
+ {
+ StrictMode.VmPolicy.Builder builder = new StrictMode.VmPolicy.Builder();
+ StrictMode.setVmPolicy(builder.build());
+ }
+
+// intent.addFlags(Intent.FLAG_ACTIVITY_NEW_TASK);
+ //Set the Action property of the intent
+ intent.setAction(Intent.ACTION_VIEW);
+ //Get the MIME type of the file file
+ String type = getMIMEType(file);
+ //Set the data and Type properties of the intent.
+ intent.setDataAndType(/*uri*/Uri.fromFile(file), type);
+ //Jump
+ context.startActivity(intent);
+
+ }
+
+
+ public static boolean isImage(File file)
+ {
+ String type = getType(file);
+ if (".jpg".equals(type)||".png".equals(type))
+ {
+ return true;
+ }
+ return false;
+
+ }
+
+}
diff --git a/android/app/src/main/java/ai/fedml/utils/FormatUtils.java b/android/app/src/main/java/ai/fedml/utils/FormatUtils.java
new file mode 100644
index 0000000000..1e7d6bf724
--- /dev/null
+++ b/android/app/src/main/java/ai/fedml/utils/FormatUtils.java
@@ -0,0 +1,61 @@
+package ai.fedml.utils;
+
+import android.content.Context;
+import android.text.format.DateFormat;
+
+import java.math.BigDecimal;
+import java.util.Date;
+
+import ai.fedml.edge.service.ContextHolder;
+
+public class FormatUtils {
+ private FormatUtils() {
+ }
+
+ /**
+ * Convert long type to String type
+ *
+ * @param milSecond currentTime time of type long to convert
+ * @return LongDateFormat
+ */
+ public static String longToString(long milSecond) {
+ Context context = ContextHolder.getAppContext();
+ return DateFormat.getLongDateFormat(context).format(new Date(milSecond));
+ }
+
+ /**
+ * byte convert to kbãmbãgbãtb
+ *
+ * @param average size in bytes
+ * @return File size
+ */
+ public static String unitConversion(long average) {
+ double temp = average;
+ if (temp < 1024) {
+ BigDecimal result1 = new BigDecimal(temp);
+ return result1.setScale(2, BigDecimal.ROUND_HALF_UP).doubleValue() + "B";
+ }
+ temp = temp / 1024;
+ if (temp < 1024) {
+ BigDecimal result1 = new BigDecimal(temp);
+ return result1.setScale(2, BigDecimal.ROUND_HALF_UP).doubleValue() + "KB";
+ }
+
+ temp = temp / 1024;
+ if (temp < 1024) {
+ BigDecimal result1 = new BigDecimal(temp);
+ return result1.setScale(2, BigDecimal.ROUND_HALF_UP).doubleValue() + "MB";
+ }
+ temp = temp / 1024;
+ if (temp < 1024) {
+ BigDecimal result1 = new BigDecimal(temp);
+ return result1.setScale(2, BigDecimal.ROUND_HALF_UP).doubleValue() + "GB";
+ }
+ temp = temp / 1024;
+ if (temp < 1024) {
+ BigDecimal result1 = new BigDecimal(temp);
+ return result1.setScale(2, BigDecimal.ROUND_HALF_UP).doubleValue() + "TB";
+ }
+ return "0";
+ }
+}
diff --git a/android/app/src/main/java/ai/fedml/utils/StatusBarUtil.java b/android/app/src/main/java/ai/fedml/utils/StatusBarUtil.java
new file mode 100644
index 0000000000..610b504936
--- /dev/null
+++ b/android/app/src/main/java/ai/fedml/utils/StatusBarUtil.java
@@ -0,0 +1,69 @@
+package ai.fedml.utils;
+
+import android.annotation.TargetApi;
+import android.app.Activity;
+import android.graphics.Color;
+import android.os.Build;
+import android.view.View;
+import android.view.ViewGroup;
+import android.view.WindowManager;
+
+/**
+ * Adapt to full screen status bar
+ */
+public class StatusBarUtil {
+
+
+ /**
+ * Make the status bar fully transparent
+ *
+ * @param activity needs to be set up
+ */
+ public static void setTransparent(Activity activity) {
+ if (Build.VERSION.SDK_INT < Build.VERSION_CODES.KITKAT) {
+ return;
+ }
+ transparentStatusBar(activity);
+ setRootView(activity);
+ }
+
+ /**
+ * make the status bar transparent
+ */
+ @TargetApi(Build.VERSION_CODES.KITKAT)
+ private static void transparentStatusBar(Activity activity) {
+ // Set Android 6.0 + to achieve status bar text color and icon light black
+ if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
+ activity.getWindow().getDecorView().setSystemUiVisibility(
+ View.SYSTEM_UI_FLAG_LAYOUT_FULLSCREEN | View.SYSTEM_UI_FLAG_LIGHT_STATUS_BAR);
+ }
+ if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.LOLLIPOP) {
+ activity.getWindow().addFlags(WindowManager.LayoutParams.FLAG_DRAWS_SYSTEM_BAR_BACKGROUNDS);
+ activity.getWindow().clearFlags(WindowManager.LayoutParams.FLAG_TRANSLUCENT_STATUS);
+ //You need to set this flag contentView to extend to the status bar, and the bottom will also extend past
+// activity.getWindow().addFlags(WindowManager.LayoutParams.FLAG_TRANSLUCENT_NAVIGATION);
+ //The status bar is overlaid on the contentView, and the transparency is set to make the background of the contentView show through
+ activity.getWindow().setStatusBarColor(Color.TRANSPARENT);
+ } else {
+ //Let the contentView extend to the status bar and set the status bar color to be transparent
+ activity.getWindow().addFlags(WindowManager.LayoutParams.FLAG_TRANSLUCENT_STATUS);
+ }
+ }
+
+ /**
+ * Set root layout parameters
+ */
+ private static void setRootView(Activity activity) {
+ ViewGroup parent = (ViewGroup) activity.findViewById(android.R.id.content);
+ for (int i = 0, count = parent.getChildCount(); i < count; i++) {
+ View childView = parent.getChildAt(i);
+ if (childView instanceof ViewGroup) {
+ childView.setFitsSystemWindows(true);
+ ((ViewGroup) childView).setClipToPadding(true);
+ }
+ }
+ }
+
+
+
+}
diff --git a/android/app/src/main/java/ai/fedml/utils/ToastUtils.java b/android/app/src/main/java/ai/fedml/utils/ToastUtils.java
new file mode 100644
index 0000000000..50b9a2a8a7
--- /dev/null
+++ b/android/app/src/main/java/ai/fedml/utils/ToastUtils.java
@@ -0,0 +1,52 @@
+package ai.fedml.utils;
+
+import android.content.Context;
+import android.content.res.Resources;
+import android.os.Handler;
+import android.os.Looper;
+import android.widget.Toast;
+
+import ai.fedml.edge.service.ContextHolder;
+
+public class ToastUtils {
+ private static final Handler sMainHandler = new Handler(Looper.getMainLooper());
+
+ /**
+ * display Toast
+ *
+ * @param text
+ */
+ public static void show(CharSequence text) {
+ Runnable toastRunnable = () -> {
+ Context context = ContextHolder.getAppContext();
+ if (text == null || text.equals("")) return;
+ // If the displayed text exceeds 10, display the long toast, otherwise display the short toast
+ int duration = Toast.LENGTH_SHORT;
+ if (text.length() > 20) {
+ duration = Toast.LENGTH_LONG;
+ }
+ Toast.makeText(context, text, duration).show();
+ };
+ if (Looper.getMainLooper() == Looper.myLooper()) {
+ toastRunnable.run();
+ } else {
+ sMainHandler.post(toastRunnable);
+ }
+
+ }
+
+ /**
+ * Display Toast
+ *
+ * @param id If the correct string id is passed in, the corresponding string will be displayed, if not, an integer string will be displayed
+ */
+ public static void show(int id) {
+ try {
+ // if this is a resource id
+ show(ContextHolder.getAppContext().getResources().getText(id));
+ } catch (Resources.NotFoundException ignored) {
+ // if this is an int type
+ show(String.valueOf(id));
+ }
+ }
+}
diff --git a/android/app/src/main/java/ai/fedml/widget/CircleImageView.java b/android/app/src/main/java/ai/fedml/widget/CircleImageView.java
new file mode 100644
index 0000000000..98e4394f83
--- /dev/null
+++ b/android/app/src/main/java/ai/fedml/widget/CircleImageView.java
@@ -0,0 +1,84 @@
+package ai.fedml.widget;
+
+import android.content.Context;
+import android.graphics.Bitmap;
+import android.graphics.BitmapShader;
+import android.graphics.Canvas;
+import android.graphics.Matrix;
+import android.graphics.Paint;
+import android.graphics.Shader;
+import android.graphics.drawable.BitmapDrawable;
+import android.graphics.drawable.Drawable;
+import android.util.AttributeSet;
+
+/**
+ * custom circular image
+ */
+public class CircleImageView extends androidx.appcompat.widget.AppCompatImageView {
+
+ private Paint mPaint = new Paint(); //brush
+
+ private int mRadius; //The radius of the circular image
+
+ private float mScale; //The zoom ratio of the image
+
+ public CircleImageView(Context context) {
+ super(context);
+ }
+
+ public CircleImageView(Context context, AttributeSet attrs) {
+ super(context, attrs);
+ }
+
+ public CircleImageView(Context context, AttributeSet attrs, int defStyleAttr) {
+ super(context, attrs, defStyleAttr);
+ }
+
+ @Override
+ protected void onMeasure(int widthMeasureSpec, int heightMeasureSpec) {
+ super.onMeasure(widthMeasureSpec, heightMeasureSpec);
+ //Because it is a circular image, the width and height should be consistent
+ int size = Math.min(getMeasuredWidth(), getMeasuredHeight());
+ mRadius = size / 2;
+
+ setMeasuredDimension(size, size);
+ }
+
+ @Override
+ protected void onDraw(Canvas canvas) {
+
+ Bitmap bitmap = drawableToBitmap(getDrawable());
+
+ //Initialize BitmapShader, pass in the bitmap object
+ BitmapShader bitmapShader = new BitmapShader(bitmap, Shader.TileMode.CLAMP, Shader.TileMode.CLAMP);
+
+ //Calculate scaling
+ mScale = (mRadius * 2.0f) / Math.min(bitmap.getHeight(), bitmap.getWidth());
+
+ Matrix matrix = new Matrix();
+ matrix.setScale(mScale, mScale);
+ bitmapShader.setLocalMatrix(matrix);
+
+
+ mPaint.setShader(bitmapShader);
+
+ //Draw a circle, specify the center point coordinates, radius, brush
+ canvas.drawCircle(mRadius, mRadius, mRadius, mPaint);
+ }
+
+ //Write a drawble to BitMap method
+ private Bitmap drawableToBitmap(Drawable drawable) {
+ if (drawable instanceof BitmapDrawable) {
+ BitmapDrawable bd = (BitmapDrawable) drawable;
+ return bd.getBitmap();
+ }
+ int w = drawable.getIntrinsicWidth();
+ int h = drawable.getIntrinsicHeight();
+ Bitmap bitmap = Bitmap.createBitmap(w, h, Bitmap.Config.ARGB_8888);
+ Canvas canvas = new Canvas(bitmap);
+ drawable.setBounds(0, 0, w, h);
+ drawable.draw(canvas);
+ return bitmap;
+ }
+
+}
diff --git a/android/app/src/main/java/ai/fedml/widget/CompletedProgressView.java b/android/app/src/main/java/ai/fedml/widget/CompletedProgressView.java
new file mode 100644
index 0000000000..461eff9c95
--- /dev/null
+++ b/android/app/src/main/java/ai/fedml/widget/CompletedProgressView.java
@@ -0,0 +1,165 @@
+package ai.fedml.widget;
+
+import android.content.Context;
+import android.content.res.TypedArray;
+import android.graphics.Canvas;
+import android.graphics.Paint;
+import android.graphics.RectF;
+import android.text.TextUtils;
+import android.util.AttributeSet;
+import android.view.View;
+
+import ai.fedml.R;
+
+/**
+ * Custom circular progress bar
+ */
+public class CompletedProgressView extends View {
+
+ // total progress
+ private static final int TOTAL_PROGRESS = 100;
+ // Paintbrush for drawing a filled circle
+ private Paint mCirclePaint;
+ // brush for drawing circles
+ private Paint mRingPaint;
+ // The background color of the brush for drawing the ring
+ private Paint mRingPaintBg;
+ // brush for drawing fonts
+ private Paint mTextPaint;
+ // circle color
+ private int mCircleColor;
+ // ring color
+ private int mRingColor;
+ // Ring background color
+ private int mRingBgColor;
+ // radius
+ private float mRadius;
+ // Ring radius
+ private float mRingRadius;
+ // Ring width
+ private float mStrokeWidth;
+ // word height
+ private float mTxtHeight;
+ // current progress
+ private int mProgress;
+ private RectF mOuterRect;
+ private String mStatus;
+
+ public CompletedProgressView(Context context, AttributeSet attrs) {
+ super(context, attrs);
+ // Get custom properties
+ initAttrs(context, attrs);
+ initVariable();
+ }
+
+ //properties
+ private void initAttrs(Context context, AttributeSet attrs) {
+ TypedArray typeArray = context.getTheme().obtainStyledAttributes(attrs,
+ R.styleable.TasksCompletedView, 0, 0);
+ mRadius = typeArray.getDimension(R.styleable.TasksCompletedView_radius, 80);
+ mStrokeWidth = typeArray.getDimension(R.styleable.TasksCompletedView_strokeWidth, 10);
+ mCircleColor = typeArray.getColor(R.styleable.TasksCompletedView_circleColor, 0xFFFFFFFF);
+ mRingColor = typeArray.getColor(R.styleable.TasksCompletedView_ringColor, 0xFFFFFFFF);
+ mRingBgColor = typeArray.getColor(R.styleable.TasksCompletedView_ringBgColor, 0xFFFFFFFF);
+ mProgress = typeArray.getInteger(R.styleable.TasksCompletedView_progress, 0);
+
+ mRingRadius = mRadius + mStrokeWidth / 2;
+ }
+
+ //Initialize brush
+ private void initVariable() {
+ //inner circle
+ mCirclePaint = new Paint();
+ mCirclePaint.setAntiAlias(true);
+ mCirclePaint.setColor(mCircleColor);
+ mCirclePaint.setStyle(Paint.Style.FILL);
+
+ //Outer arc background
+ mRingPaintBg = new Paint();
+ mRingPaintBg.setAntiAlias(true);
+ mRingPaintBg.setColor(mRingBgColor);
+ mRingPaintBg.setStyle(Paint.Style.STROKE);
+ mRingPaintBg.setStrokeWidth(mStrokeWidth);
+
+
+ //Outer arc
+ mRingPaint = new Paint();
+ mRingPaint.setAntiAlias(true);
+ mRingPaint.setColor(mRingColor);
+ mRingPaint.setStyle(Paint.Style.STROKE);
+ mRingPaint.setStrokeWidth(mStrokeWidth);
+ //mRingPaint.setStrokeCap(Paint.Cap.ROUND);//Set the line style, there are circles and squares
+
+ //middle word
+ mTextPaint = new Paint();
+ mTextPaint.setAntiAlias(true);
+ mTextPaint.setStyle(Paint.Style.FILL);
+ mTextPaint.setColor(getResources().getColor(R.color.color_3C4043));
+ mTextPaint.setTextSize(mRadius / 2);
+
+ Paint.FontMetrics fm = mTextPaint.getFontMetrics();
+ mTxtHeight = (int) Math.ceil(fm.descent - fm.ascent);
+ }
+
+ @Override
+ protected void onMeasure(int widthMeasureSpec, int heightMeasureSpec) {
+ super.onMeasure(widthMeasureSpec, heightMeasureSpec);
+ }
+
+ //draw
+ @Override
+ protected void onDraw(Canvas canvas) {
+ // The x-coordinate of the center of the circle
+ int mXCenter = getWidth() / 2;
+ // The y coordinate of the center of the circle
+ int mYCenter = getHeight() / 2;
+
+ //inner circle
+ canvas.drawCircle(mXCenter, mYCenter, mRadius, mCirclePaint);
+
+ //Outer arc background
+ if (mOuterRect == null) {
+ mOuterRect = new RectF();
+ mOuterRect.left = (mXCenter - mRingRadius);
+ mOuterRect.top = (mYCenter - mRingRadius);
+ mOuterRect.right = mRingRadius * 2 + (mXCenter - mRingRadius);
+ mOuterRect.bottom = mRingRadius * 2 + (mYCenter - mRingRadius);
+ }
+ canvas.drawArc(mOuterRect, 0, 360, false, mRingPaintBg);
+
+ //The ellipse object where the arc is located, the starting angle of the arc, the angle of the arc, whether to display the radius connection
+
+ //Outer arc
+ if (mProgress > 0) {
+ canvas.drawArc(mOuterRect, -90, ((float) mProgress / TOTAL_PROGRESS) * 360, false, mRingPaint); //
+ }
+
+ //fonts
+ String txt = mStatus;
+ if (TextUtils.isEmpty(mStatus) && mProgress > 0) {
+ txt = mProgress + "%";
+ }
+ if (!TextUtils.isEmpty(txt)) {
+ // word length
+ float mTxtWidth = mTextPaint.measureText(txt, 0, txt.length());
+ canvas.drawText(txt, mXCenter - mTxtWidth / 2, mYCenter + mTxtHeight / 4, mTextPaint);
+ }
+ }
+
+ //set the progress
+ public void setProgress(int progress) {
+ mProgress = progress;
+ mStatus = null;
+ postInvalidate();
+ }
+
+ /**
+ * Set text status
+ *
+ * @param status status
+ */
+ public void setStatus(String status) {
+ mStatus = status;
+ postInvalidate();
+ }
+}
diff --git a/android/app/src/main/java/ai/fedml/widget/CouponTextView.java b/android/app/src/main/java/ai/fedml/widget/CouponTextView.java
new file mode 100644
index 0000000000..ef9cc4df5d
--- /dev/null
+++ b/android/app/src/main/java/ai/fedml/widget/CouponTextView.java
@@ -0,0 +1,65 @@
+package ai.fedml.widget;
+
+import android.content.Context;
+import android.content.res.TypedArray;
+import android.graphics.Canvas;
+import android.graphics.Paint;
+import android.graphics.RectF;
+import android.util.AttributeSet;
+
+import ai.fedml.R;
+import androidx.core.content.ContextCompat;
+
+/**
+ * custom bump layout
+ */
+public class CouponTextView extends androidx.appcompat.widget.AppCompatTextView {
+
+ private Paint mPaint;
+ private Context mContext;
+ private int mColor;
+ private int mHeight;
+ private RectF mRectF;
+
+ public CouponTextView(Context context) {
+ this(context, null);
+
+ }
+
+ public CouponTextView(Context context, AttributeSet attrs) {
+ this(context, attrs, 0);
+
+ }
+
+ public CouponTextView(Context context, AttributeSet attrs, int defStyleAttr) {
+ super(context, attrs, defStyleAttr);
+ TypedArray array = context.obtainStyledAttributes(attrs, R.styleable.CouponTextView);
+ mColor = ContextCompat.getColor(context, R.color.white);
+ mColor = array.getColor(R.styleable.CouponTextView_bg_color, mColor);
+ mHeight = array.getDimensionPixelSize(R.styleable.CouponTextView_android_height, 30);
+ mContext = context;
+ initPaint();
+ array.recycle();
+ }
+
+ private void initPaint() {
+ mPaint =new Paint();
+ mPaint.setColor(mColor);
+ mPaint.setStrokeWidth(12f);
+ mPaint.setAntiAlias(true);
+ }
+
+ @Override
+ protected void onDraw(Canvas canvas) {
+
+ super.onDraw(canvas);
+ if(mRectF == null) {
+ mRectF = new RectF(0, 0, getMeasuredWidth(), mHeight);
+ }
+ canvas.drawRect(mRectF, mPaint);
+ mPaint.setColor(ContextCompat.getColor(mContext, R.color.color_F5F6FA));
+ canvas.drawCircle(getMeasuredWidth()/2, 0,50, mPaint);
+ }
+
+
+}
diff --git a/android/app/src/main/java/ai/fedml/widget/PopupwindNormal.java b/android/app/src/main/java/ai/fedml/widget/PopupwindNormal.java
new file mode 100644
index 0000000000..4c928df577
--- /dev/null
+++ b/android/app/src/main/java/ai/fedml/widget/PopupwindNormal.java
@@ -0,0 +1,74 @@
+package ai.fedml.widget;
+
+import android.content.Context;
+import android.graphics.drawable.ColorDrawable;
+import android.view.LayoutInflater;
+import android.view.View;
+import android.view.ViewGroup;
+import android.widget.Button;
+import android.widget.PopupWindow;
+import android.widget.TextView;
+
+import ai.fedml.R;
+
+
+public class PopupwindNormal extends PopupWindow {
+
+ private View mView;
+ private TextView tv_title;
+ private TextView tv_content;
+
+ private Button btn_cancel;
+ private Button btn_ok;
+
+
+ /**
+ *
+ * @param context
+ * @param onClickListener
+ * @param title
+ * @param content
+ */
+ public PopupwindNormal(Context context, View.OnClickListener onClickListener, String title, String content) {
+ super(context);
+
+ LayoutInflater inflater = (LayoutInflater) context.getSystemService(Context.LAYOUT_INFLATER_SERVICE);
+ mView = inflater.inflate(R.layout.popup_normal, null);
+
+ tv_title = mView.findViewById(R.id.tv_title);
+ tv_content = mView.findViewById(R.id.tv_content);
+ btn_cancel = mView.findViewById(R.id.btn_cancel);
+ btn_ok = mView.findViewById(R.id.btn_ok);
+
+ btn_ok.setOnClickListener(onClickListener);
+ btn_cancel.setOnClickListener(new View.OnClickListener() {
+ @Override
+ public void onClick(View v) {
+ dismiss();
+ }
+ });
+
+ tv_title.setText(title);
+ tv_content.setText(content);
+
+
+
+
+ // Set the View of PopupWindow
+ this.setContentView(mView);
+ // Set the width of the PopupWindow pop-up form
+ this.setWidth(ViewGroup.LayoutParams.MATCH_PARENT);
+ // Set the height of the PopupWindow pop-up form
+ this.setHeight(ViewGroup.LayoutParams.MATCH_PARENT);
+ // Set the PopupWindow pop-up form to be clickable
+ this.setFocusable(true);
+ // Set PopupWindow pop-up form animation effect
+ this.setAnimationStyle(R.style.AnimFadePopup);
+ // Instantiate a ColorDrawable with a color of black with 25% opacity
+ ColorDrawable dw = new ColorDrawable(0x40000000);
+ // Set the background of the PopupWindow popup form
+ this.setBackgroundDrawable(dw);
+
+ }
+
+}
diff --git a/android/app/src/main/res/anim/fade_in.xml b/android/app/src/main/res/anim/fade_in.xml
new file mode 100644
index 0000000000..d4193790cf
--- /dev/null
+++ b/android/app/src/main/res/anim/fade_in.xml
@@ -0,0 +1,8 @@
+
+
+
+
\ No newline at end of file
diff --git a/android/app/src/main/res/anim/fade_out.xml b/android/app/src/main/res/anim/fade_out.xml
new file mode 100644
index 0000000000..6b1b69a51a
--- /dev/null
+++ b/android/app/src/main/res/anim/fade_out.xml
@@ -0,0 +1,8 @@
+
+
+
+
\ No newline at end of file
diff --git a/android/app/src/main/res/drawable/border_edit.xml b/android/app/src/main/res/drawable/border_edit.xml
new file mode 100644
index 0000000000..c5569dde6e
--- /dev/null
+++ b/android/app/src/main/res/drawable/border_edit.xml
@@ -0,0 +1,13 @@
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/android/app/src/main/res/drawable/radius_whit_bg.xml b/android/app/src/main/res/drawable/radius_whit_bg.xml
new file mode 100644
index 0000000000..034ba9173f
--- /dev/null
+++ b/android/app/src/main/res/drawable/radius_whit_bg.xml
@@ -0,0 +1,15 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/android/app/src/main/res/layout/activity_guide.xml b/android/app/src/main/res/layout/activity_guide.xml
new file mode 100644
index 0000000000..699bc1c304
--- /dev/null
+++ b/android/app/src/main/res/layout/activity_guide.xml
@@ -0,0 +1,53 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/android/app/src/main/res/layout/activity_home.xml b/android/app/src/main/res/layout/activity_home.xml
new file mode 100644
index 0000000000..1698e18b09
--- /dev/null
+++ b/android/app/src/main/res/layout/activity_home.xml
@@ -0,0 +1,206 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/android/app/src/main/res/layout/activity_scan_code.xml b/android/app/src/main/res/layout/activity_scan_code.xml
new file mode 100644
index 0000000000..8ece9e4adc
--- /dev/null
+++ b/android/app/src/main/res/layout/activity_scan_code.xml
@@ -0,0 +1,125 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/android/app/src/main/res/layout/activity_set_file_path.xml b/android/app/src/main/res/layout/activity_set_file_path.xml
new file mode 100644
index 0000000000..3df60d0430
--- /dev/null
+++ b/android/app/src/main/res/layout/activity_set_file_path.xml
@@ -0,0 +1,67 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/android/app/src/main/res/layout/fragment_mnn_train.xml b/android/app/src/main/res/layout/fragment_mnn_train.xml
new file mode 100644
index 0000000000..f8b35e3c6b
--- /dev/null
+++ b/android/app/src/main/res/layout/fragment_mnn_train.xml
@@ -0,0 +1,127 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/android/app/src/main/res/layout/popup_normal.xml b/android/app/src/main/res/layout/popup_normal.xml
new file mode 100644
index 0000000000..29a7b16351
--- /dev/null
+++ b/android/app/src/main/res/layout/popup_normal.xml
@@ -0,0 +1,81 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/android/app/src/main/res/layout/rv_item_empty.xml b/android/app/src/main/res/layout/rv_item_empty.xml
new file mode 100644
index 0000000000..3b397be481
--- /dev/null
+++ b/android/app/src/main/res/layout/rv_item_empty.xml
@@ -0,0 +1,18 @@
+
+
+
+
+
\ No newline at end of file
diff --git a/android/app/src/main/res/layout/rv_item_file.xml b/android/app/src/main/res/layout/rv_item_file.xml
new file mode 100644
index 0000000000..8973915324
--- /dev/null
+++ b/android/app/src/main/res/layout/rv_item_file.xml
@@ -0,0 +1,69 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/android/app/src/main/res/mipmap-hdpi/ic_launcher.png b/android/app/src/main/res/mipmap-hdpi/ic_launcher.png
new file mode 100644
index 0000000000..78d460fe4c
Binary files /dev/null and b/android/app/src/main/res/mipmap-hdpi/ic_launcher.png differ
diff --git a/android/app/src/main/res/mipmap-hdpi/ic_launcher_round.png b/android/app/src/main/res/mipmap-hdpi/ic_launcher_round.png
new file mode 100644
index 0000000000..78d460fe4c
Binary files /dev/null and b/android/app/src/main/res/mipmap-hdpi/ic_launcher_round.png differ
diff --git a/android/app/src/main/res/mipmap-mdpi/ic_launcher_round.png b/android/app/src/main/res/mipmap-mdpi/ic_launcher_round.png
new file mode 100644
index 0000000000..78d460fe4c
Binary files /dev/null and b/android/app/src/main/res/mipmap-mdpi/ic_launcher_round.png differ
diff --git a/android/app/src/main/res/mipmap-xhdpi/ic_launcher_round.png b/android/app/src/main/res/mipmap-xhdpi/ic_launcher_round.png
new file mode 100644
index 0000000000..78d460fe4c
Binary files /dev/null and b/android/app/src/main/res/mipmap-xhdpi/ic_launcher_round.png differ
diff --git a/android/app/src/main/res/mipmap-xxhdpi/ic_dir.png b/android/app/src/main/res/mipmap-xxhdpi/ic_dir.png
new file mode 100644
index 0000000000..38503eab16
Binary files /dev/null and b/android/app/src/main/res/mipmap-xxhdpi/ic_dir.png differ
diff --git a/android/app/src/main/res/mipmap-xxhdpi/ic_file.png b/android/app/src/main/res/mipmap-xxhdpi/ic_file.png
new file mode 100644
index 0000000000..36a6edfa0c
Binary files /dev/null and b/android/app/src/main/res/mipmap-xxhdpi/ic_file.png differ
diff --git a/android/app/src/main/res/mipmap-xxhdpi/ic_guide.png b/android/app/src/main/res/mipmap-xxhdpi/ic_guide.png
new file mode 100644
index 0000000000..54b1dbcdfb
Binary files /dev/null and b/android/app/src/main/res/mipmap-xxhdpi/ic_guide.png differ
diff --git a/android/app/src/main/res/mipmap-xxhdpi/ic_launcher.png b/android/app/src/main/res/mipmap-xxhdpi/ic_launcher.png
new file mode 100644
index 0000000000..78d460fe4c
Binary files /dev/null and b/android/app/src/main/res/mipmap-xxhdpi/ic_launcher.png differ
diff --git a/android/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.png b/android/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.png
new file mode 100644
index 0000000000..78d460fe4c
Binary files /dev/null and b/android/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.png differ
diff --git a/android/app/src/main/res/mipmap-xxhdpi/ic_or.png b/android/app/src/main/res/mipmap-xxhdpi/ic_or.png
new file mode 100644
index 0000000000..47b24f5917
Binary files /dev/null and b/android/app/src/main/res/mipmap-xxhdpi/ic_or.png differ
diff --git a/android/app/src/main/res/mipmap-xxhdpi/ic_privacy_normal.png b/android/app/src/main/res/mipmap-xxhdpi/ic_privacy_normal.png
new file mode 100644
index 0000000000..da6124c5fc
Binary files /dev/null and b/android/app/src/main/res/mipmap-xxhdpi/ic_privacy_normal.png differ
diff --git a/android/app/src/main/res/mipmap-xxhdpi/ic_privacy_pressed.png b/android/app/src/main/res/mipmap-xxhdpi/ic_privacy_pressed.png
new file mode 100644
index 0000000000..524e7d3f64
Binary files /dev/null and b/android/app/src/main/res/mipmap-xxhdpi/ic_privacy_pressed.png differ
diff --git a/android/app/src/main/res/mipmap-xxhdpi/ic_scan_code.png b/android/app/src/main/res/mipmap-xxhdpi/ic_scan_code.png
new file mode 100644
index 0000000000..5e034a4fd6
Binary files /dev/null and b/android/app/src/main/res/mipmap-xxhdpi/ic_scan_code.png differ
diff --git a/android/app/src/main/res/mipmap-xxhdpi/ic_shijiali.webp b/android/app/src/main/res/mipmap-xxhdpi/ic_shijiali.webp
new file mode 100644
index 0000000000..650320d19e
Binary files /dev/null and b/android/app/src/main/res/mipmap-xxhdpi/ic_shijiali.webp differ
diff --git a/android/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.png b/android/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.png
new file mode 100644
index 0000000000..78d460fe4c
Binary files /dev/null and b/android/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.png differ
diff --git a/android/app/src/main/res/values/attrs.xml b/android/app/src/main/res/values/attrs.xml
new file mode 100644
index 0000000000..a3d5ec59e3
--- /dev/null
+++ b/android/app/src/main/res/values/attrs.xml
@@ -0,0 +1,21 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/android/app/src/main/res/values/colors.xml b/android/app/src/main/res/values/colors.xml
new file mode 100644
index 0000000000..e994c12a47
--- /dev/null
+++ b/android/app/src/main/res/values/colors.xml
@@ -0,0 +1,29 @@
+
+
+ #6200EE
+ #3700B3
+ #03DAC5
+ #3CB371
+ #87CEFA
+ #D2E9FF
+ #B0C4DE
+ #DCDCDC
+ #FFFFFF
+ #000000
+
+
+ #54A158
+ #31708F
+ #0E1328
+ #686C7D
+ #F5F6FA
+ #F1F1F1
+ #2FA9CDFE
+ #A9CDFE
+ #3C4043
+ #008CFF
+ #999999
+ #EAEAEA
+ #333333
+
+
\ No newline at end of file
diff --git a/android/app/src/main/res/values/strings.xml b/android/app/src/main/res/values/strings.xml
new file mode 100644
index 0000000000..0776eef909
--- /dev/null
+++ b/android/app/src/main/res/values/strings.xml
@@ -0,0 +1,19 @@
+
+ FedML Edge
+ FedML
+ FedML MLOps Cloud Platform
+ https://fedml.ai
+ Please input account ID
+ Please retry
+ Set Private Path
+ IDLE
+ Waiting for training
+ Account Information[%1$s]
+ Hyper-parameters
+ number of rounds
+ local iteration number
+ Unbind Account
+ Exit
+ Round:%1$d Epoch:%2$d \nAcc:%3$.4f Loss:%4$.4f
+ FedML
+
\ No newline at end of file
diff --git a/android/app/src/main/res/values/styles.xml b/android/app/src/main/res/values/styles.xml
new file mode 100644
index 0000000000..5db20661ba
--- /dev/null
+++ b/android/app/src/main/res/values/styles.xml
@@ -0,0 +1,15 @@
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/android/app/src/main/res/xml/file_paths.xml b/android/app/src/main/res/xml/file_paths.xml
new file mode 100644
index 0000000000..8e9eb8f919
--- /dev/null
+++ b/android/app/src/main/res/xml/file_paths.xml
@@ -0,0 +1,28 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/android/app/src/test/java/ai/example/websocket/ExampleUnitTest.java b/android/app/src/test/java/ai/example/websocket/ExampleUnitTest.java
new file mode 100644
index 0000000000..7d19f3ad2f
--- /dev/null
+++ b/android/app/src/test/java/ai/example/websocket/ExampleUnitTest.java
@@ -0,0 +1,36 @@
+package ai.example.websocket;
+
+import org.junit.Test;
+
+import java.io.IOException;
+
+
+import ai.fedml.client.RetrofitManager;
+import ai.fedml.edge.request.response.BaseResponse;
+import ai.fedml.edge.request.parameter.BindingAccountReq;
+import ai.fedml.edge.request.response.BindingResponse;
+import ai.fedml.edge.request.response.UserInfoResponse;
+import ai.fedml.client.entity.VersionUpdateResponse;
+import retrofit2.Call;
+import retrofit2.Response;
+
+import static org.junit.Assert.*;
+
+/**
+ * Example local unit test, which will execute on the development machine (host).
+ *
+ * @see Testing documentation
+ */
+public class ExampleUnitTest {
+ @Test
+ public void addition_isCorrect() {
+ assertEquals(4, 2 + 2);
+ }
+
+ @Test
+ public void VersionUpdateTest() throws IOException {
+ Call call = RetrofitManager.getVersionUpdater().getLatestVersion();
+ Response response = call.execute();
+ System.out.println("bindingAccountTest onResponse: " + (response.body() != null ? response.body() : null));
+ }
+}
\ No newline at end of file
diff --git a/android/build.gradle b/android/build.gradle
new file mode 100644
index 0000000000..702f900c95
--- /dev/null
+++ b/android/build.gradle
@@ -0,0 +1,28 @@
+// Top-level build file where you can add configuration options common to all sub-projects/modules.
+buildscript {
+ repositories {
+ google()
+ mavenCentral()
+ maven { url 'https://s01.oss.sonatype.org/content/repositories/snapshots' }
+ }
+ dependencies {
+ classpath 'com.android.tools.build:gradle:7.0.4'
+ classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:1.6.10"
+
+ // NOTE: Do not place your application dependencies here; they belong
+ // in the individual module build.gradle files
+ }
+}
+
+allprojects {
+ repositories {
+ google()
+ mavenCentral()
+ maven { url 'https://s01.oss.sonatype.org/content/repositories/snapshots' }
+ maven { url 'https://jitpack.io' }
+ }
+}
+
+task clean(type: Delete) {
+ delete rootProject.buildDir
+}
\ No newline at end of file
diff --git a/android/data/MNIST/raw/t10k-images-idx3-ubyte b/android/data/MNIST/raw/t10k-images-idx3-ubyte
new file mode 100644
index 0000000000..1170b2cae9
Binary files /dev/null and b/android/data/MNIST/raw/t10k-images-idx3-ubyte differ
diff --git a/android/data/MNIST/raw/t10k-labels-idx1-ubyte b/android/data/MNIST/raw/t10k-labels-idx1-ubyte
new file mode 100644
index 0000000000..d1c3a97061
Binary files /dev/null and b/android/data/MNIST/raw/t10k-labels-idx1-ubyte differ
diff --git a/android/data/MNIST/raw/train-images-idx3-ubyte b/android/data/MNIST/raw/train-images-idx3-ubyte
new file mode 100644
index 0000000000..bbce27659e
Binary files /dev/null and b/android/data/MNIST/raw/train-images-idx3-ubyte differ
diff --git a/android/data/MNIST/raw/train-labels-idx1-ubyte b/android/data/MNIST/raw/train-labels-idx1-ubyte
new file mode 100644
index 0000000000..d6b4c5db3b
Binary files /dev/null and b/android/data/MNIST/raw/train-labels-idx1-ubyte differ
diff --git a/android/data/mnn_model/lenet.py b/android/data/mnn_model/lenet.py
new file mode 100644
index 0000000000..62e275d844
--- /dev/null
+++ b/android/data/mnn_model/lenet.py
@@ -0,0 +1,59 @@
+import MNN
+nn = MNN.nn
+F = MNN.expr
+
+
+class Lenet_cifar(nn.Module):
+ """construct a lenet 5 model"""
+ def __init__(self):
+ super(Lenet_cifar, self).__init__()
+ self.conv1 = nn.conv(3, 6, [5, 5])
+ self.conv2 = nn.conv(6, 16, [5, 5])
+ self.fc1 = nn.linear(400, 120)
+ self.fc2 = nn.linear(120, 84)
+ self.fc3 = nn.linear(84, 10)
+
+ def forward(self, x):
+ x = F.relu(self.conv1(x))
+ x = F.max_pool(x, [2, 2], [2, 2])
+ x = F.relu(self.conv2(x))
+ x = F.max_pool(x, [2, 2], [2, 2])
+ # MNN use NC4HW4 format for convs, so we need to convert it to NCHW before entering other ops
+ x = F.convert(x, F.NCHW)
+ x = F.reshape(x, [0, -1])
+ x = F.relu(self.fc1(x))
+ x = F.relu(self.fc2(x))
+ x = self.fc3(x)
+ x = F.softmax(x, 1)
+ return x
+
+class Lenet_mnist(nn.Module):
+ """construct a lenet 5 model"""
+ def __init__(self):
+ super(Lenet_mnist, self).__init__()
+ self.conv1 = nn.conv(1, 20, [5, 5])
+ self.conv2 = nn.conv(20, 50, [5, 5])
+ self.fc1 = nn.linear(800, 500)
+ self.fc2 = nn.linear(500, 10)
+
+ def forward(self, x):
+ x = F.relu(self.conv1(x))
+ x = F.max_pool(x, [2, 2], [2, 2])
+ x = F.relu(self.conv2(x))
+ x = F.max_pool(x, [2, 2], [2, 2])
+ # MNN use NC4HW4 format for convs, so we need to convert it to NCHW before entering other ops
+ x = F.convert(x, F.NCHW)
+ x = F.reshape(x, [0, -1])
+ x = F.relu(self.fc1(x))
+ x = self.fc2(x)
+ x = F.softmax(x, 1)
+ return x
+
+net = Lenet_mnist()
+# net.train(True)
+input_var = MNN.expr.placeholder([1, 1, 28, 28], MNN.expr.NCHW)
+predicts = net.forward(input_var)
+# print(predicts)
+F.save([predicts], "lenet_mnist.mnn")
+
+
diff --git a/android/data/mnn_model/lenet_cifar10.mnn b/android/data/mnn_model/lenet_cifar10.mnn
new file mode 100644
index 0000000000..a0bf711e6b
Binary files /dev/null and b/android/data/mnn_model/lenet_cifar10.mnn differ
diff --git a/android/data/mnn_model/lenet_mnist.mnn b/android/data/mnn_model/lenet_mnist.mnn
new file mode 100644
index 0000000000..1fde3dc8b0
Binary files /dev/null and b/android/data/mnn_model/lenet_mnist.mnn differ
diff --git a/android/data/mnn_model/mobilenetv2.py b/android/data/mnn_model/mobilenetv2.py
new file mode 100644
index 0000000000..fccfba810e
--- /dev/null
+++ b/android/data/mnn_model/mobilenetv2.py
@@ -0,0 +1,114 @@
+
+import re
+from turtle import forward
+import MNN
+nn = MNN.nn
+F = MNN.expr
+
+def make_divisible(v, divisor, min_value = None):
+ if min_value is None:
+ min_value = divisor
+
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+ if new_v < 0.9 * v:
+ new_v += divisor
+ return new_v
+
+class ConvBnRelu(nn.Module):
+ def __init__(self, in_planes, planes, kernel_size=3, stride=1, depthwise=False):
+ super(ConvBnRelu, self).__init__()
+ self.conv = nn.conv(in_planes, planes, kernel_size=[kernel_size,kernel_size], stride=[stride,stride], bias=False, padding_mode=MNN.expr.Padding_Mode.SAME, depthwise=depthwise)
+ self.bn = nn.batch_norm(planes)
+
+ def forward(self, x):
+ out = F.relu6(self.bn(self.conv(x)))
+ return out
+
+class BottleNeck(nn.Module):
+ def __init__(self, in_planes, planes, stride, expand_ratio):
+ super(BottleNeck, self).__init__()
+ expand_planes = in_planes * expand_ratio
+
+ self.use_shortcut = False
+ if stride == 1 and in_planes == planes:
+ self.use_shortcut = True
+ # print(in_planes)
+
+ self.layers = []
+ if expand_ratio != 1:
+ self.layers.append(ConvBnRelu(in_planes, expand_planes, 1))
+
+ self.layers.append(ConvBnRelu(expand_planes, expand_planes, 3, stride, True))
+
+ self.conv = nn.conv(expand_planes, planes, kernel_size=[1,1], stride=[1,1], bias=False, padding_mode=MNN.expr.Padding_Mode.SAME)
+ self.bn = nn.batch_norm(planes)
+
+ def forward(self, x):
+ out = x
+ for layer in self.layers:
+ out = layer.forward(out)
+
+ out = self.bn(self.conv(out))
+
+ if self.use_shortcut:
+ out += x
+
+ return out
+
+class MobilenetV2(nn.Module):
+ def __init__(self, num_classes=10, width_mult=1.0, divisor=8):
+ super(MobilenetV2, self).__init__()
+ in_planes = 32
+ last_planes = 1280
+
+ inverted_residual_setting = [
+ [1, 16, 1, 1],
+ [6, 24, 2, 1],
+ [6, 32, 3, 2],
+ [6, 64, 4, 2],
+ [6, 96, 3, 1],
+ [6, 160, 3, 2],
+ [6, 320, 1, 1]]
+
+ in_planes = make_divisible(in_planes * width_mult, divisor)
+ last_planes = make_divisible(last_planes * max(1.0, width_mult), divisor)
+
+ self.first_conv = ConvBnRelu(3, in_planes, 3, 1)
+
+ self.bottle_neck_blocks = []
+ for t, c, n, s in inverted_residual_setting:
+ out_planes = make_divisible(c * width_mult, divisor)
+
+ for i in range(n):
+ stride = s if i == 0 else 1
+
+ self.bottle_neck_blocks.append(BottleNeck(in_planes, out_planes, stride, t))
+ in_planes = out_planes
+
+ self.last_conv = ConvBnRelu(in_planes, last_planes, 1)
+ self.dropout = nn.dropout(0.1)
+ self.fc = nn.linear(last_planes, num_classes)
+
+ def forward(self, x):
+ x = self.first_conv.forward(x)
+
+ for layer in self.bottle_neck_blocks:
+ x = layer.forward(x)
+ # print(x.shape)
+
+ x = self.last_conv.forward(x)
+ print(x.shape)
+ x = F.avg_pool(x, kernel=[4,4], stride=[1,1])
+ x = F.convert(x, F.NCHW)
+ x = F.reshape(x, [0, -1])
+ # x = self.dropout(x)
+ x = self.fc(x)
+ out = F.softmax(x, 1)
+ return out
+
+net = MobilenetV2()
+net.train(True)
+input_var = MNN.expr.placeholder([1, 3, 32, 32], MNN.expr.NC4HW4)
+predicts = net.forward(input_var)
+# print(predicts)
+F.save([predicts], "mobilenetv2.mnn")
\ No newline at end of file
diff --git a/android/data/mnn_model/mobilenetv2_cifar10.mnn b/android/data/mnn_model/mobilenetv2_cifar10.mnn
new file mode 100644
index 0000000000..507794f301
Binary files /dev/null and b/android/data/mnn_model/mobilenetv2_cifar10.mnn differ
diff --git a/android/data/mnn_model/resnet18.py b/android/data/mnn_model/resnet18.py
new file mode 100644
index 0000000000..4e06d6aedd
--- /dev/null
+++ b/android/data/mnn_model/resnet18.py
@@ -0,0 +1,85 @@
+import MNN
+nn = MNN.nn
+F = MNN.expr
+
+
+class ResBlock(nn.Module):
+ def __init__(self, in_planes, planes, stride=1):
+ super(ResBlock, self).__init__()
+ self.conv1 = nn.conv(in_planes, planes, kernel_size=[3,3], stride=[stride,stride], padding=[1,1], bias=False, padding_mode=MNN.expr.Padding_Mode.SAME)
+ self.bn1 = nn.batch_norm(planes)
+ self.conv2 = nn.conv(planes, planes, kernel_size=[3,3], stride=[1,1], padding=[1,1], bias=False, padding_mode=MNN.expr.Padding_Mode.SAME)
+ self.bn2 = nn.batch_norm(planes)
+
+ def forward(self, x):
+ out = F.relu(self.bn1(self.conv1(x)))
+ out = self.bn2(self.conv2(out))
+ out += x
+ out = F.relu(out)
+ return out
+
+
+class ResBlock_conv(nn.Module):
+ def __init__(self, in_planes, planes, stride=1):
+ super(ResBlock_conv, self).__init__()
+ self.conv1 = nn.conv(in_planes, planes, kernel_size=[3,3], stride=[stride,stride], padding=[1,1], bias=False, padding_mode=MNN.expr.Padding_Mode.SAME)
+ self.bn1 = nn.batch_norm(planes)
+ self.conv2 = nn.conv(planes, planes, kernel_size=[3,3], stride=[1,1], padding=[1,1], bias=False, padding_mode=MNN.expr.Padding_Mode.SAME)
+ self.bn2 = nn.batch_norm(planes)
+
+ self.conv_shortcut = nn.conv(in_planes, planes, kernel_size=[1,1], stride=[stride,stride], bias=False, padding_mode=MNN.expr.Padding_Mode.SAME)
+ self.bn_shortcut = nn.batch_norm(planes)
+
+ def forward(self, x):
+ out = F.relu(self.bn1(self.conv1(x)))
+ out = self.bn2(self.conv2(out))
+ out += self.bn_shortcut(self.conv_shortcut(x))
+ return out
+
+
+class Resnet20(nn.Module):
+ def __init__(self, num_classes=10):
+ super(Resnet20, self).__init__()
+
+ self.conv1 = nn.conv(3, 64, kernel_size=[3,3], stride=[1,1], padding=[1,1], bias=False, padding_mode=MNN.expr.Padding_Mode.SAME)
+ self.bn1 = nn.batch_norm(64)
+
+ self.layer1 = ResBlock(64, 64, 1)
+ self.layer2 = ResBlock(64, 64, 1)
+
+ self.layer3 = ResBlock_conv(64, 128, 2)
+ self.layer4 = ResBlock(128, 128, 1)
+
+ self.layer5 = ResBlock_conv(128, 256, 2)
+ self.layer6 = ResBlock(256, 256, 1)
+
+ self.layer7 = ResBlock_conv(256, 512, 2)
+ self.layer8 = ResBlock(512, 512, 1)
+
+ self.fc = nn.linear(512, num_classes)
+
+ def forward(self, x):
+ x = F.relu(self.bn1(self.conv1(x)))
+
+ x = self.layer1.forward(x)
+ x = self.layer2.forward(x)
+ x = self.layer3.forward(x)
+ x = self.layer4.forward(x)
+ x = self.layer5.forward(x)
+ x = self.layer6.forward(x)
+ x = self.layer7.forward(x)
+ x = self.layer8.forward(x)
+ print(x.shape)
+ x = F.avg_pool(x, kernel=[4,4], stride=[4,4])
+ x = F.convert(x, F.NCHW)
+ x = F.reshape(x, [0, -1])
+ x = self.fc(x)
+ out = F.softmax(x, 1)
+ return out
+
+net = Resnet20()
+net.train(True)
+input_var = MNN.expr.placeholder([1, 3, 32, 32], MNN.expr.NCHW)
+predicts = net.forward(input_var)
+# print(predicts)
+F.save([predicts], "resnet18.mnn")
\ No newline at end of file
diff --git a/android/data/mnn_model/resnet18_cifar10.mnn b/android/data/mnn_model/resnet18_cifar10.mnn
new file mode 100644
index 0000000000..2383eaa6c7
Binary files /dev/null and b/android/data/mnn_model/resnet18_cifar10.mnn differ
diff --git a/android/data/mnn_model/resnet20.mnn b/android/data/mnn_model/resnet20.mnn
new file mode 100644
index 0000000000..ed0968f70a
Binary files /dev/null and b/android/data/mnn_model/resnet20.mnn differ
diff --git a/android/data/mnn_model/resnet20.py b/android/data/mnn_model/resnet20.py
new file mode 100644
index 0000000000..883cdd0d63
--- /dev/null
+++ b/android/data/mnn_model/resnet20.py
@@ -0,0 +1,88 @@
+import MNN
+nn = MNN.nn
+F = MNN.expr
+
+
+class ResBlock(nn.Module):
+ def __init__(self, in_planes, planes, stride=1):
+ super(ResBlock, self).__init__()
+ self.conv1 = nn.conv(in_planes, planes, kernel_size=[3,3], stride=[stride,stride], padding=[1,1], bias=False, padding_mode=MNN.expr.Padding_Mode.SAME)
+ self.bn1 = nn.batch_norm(planes)
+ self.conv2 = nn.conv(planes, planes, kernel_size=[3,3], stride=[1,1], padding=[1,1], bias=False, padding_mode=MNN.expr.Padding_Mode.SAME)
+ self.bn2 = nn.batch_norm(planes)
+
+ def forward(self, x):
+ out = F.relu(self.bn1(self.conv1(x)))
+ out = self.bn2(self.conv2(out))
+ out += x
+ out = F.relu(out)
+ return out
+
+
+class ResBlock_conv(nn.Module):
+ def __init__(self, in_planes, planes, stride=1):
+ super(ResBlock_conv, self).__init__()
+ self.conv1 = nn.conv(in_planes, planes, kernel_size=[3,3], stride=[stride,stride], padding=[1,1], bias=False, padding_mode=MNN.expr.Padding_Mode.SAME)
+ self.bn1 = nn.batch_norm(planes)
+ self.conv2 = nn.conv(planes, planes, kernel_size=[3,3], stride=[1,1], padding=[1,1], bias=False, padding_mode=MNN.expr.Padding_Mode.SAME)
+ self.bn2 = nn.batch_norm(planes)
+
+ self.conv_shortcut = nn.conv(in_planes, planes, kernel_size=[1,1], stride=[stride,stride], bias=False, padding_mode=MNN.expr.Padding_Mode.SAME)
+ self.bn_shortcut = nn.batch_norm(planes)
+
+ def forward(self, x):
+ out = F.relu(self.bn1(self.conv1(x)))
+ out = self.bn2(self.conv2(out))
+ out += self.bn_shortcut(self.conv_shortcut(x))
+ return out
+
+
+class Resnet20(nn.Module):
+ def __init__(self, num_classes=10):
+ super(Resnet20, self).__init__()
+
+ self.conv1 = nn.conv(3, 16, kernel_size=[3,3], stride=[1,1], padding=[1,1], bias=False, padding_mode=MNN.expr.Padding_Mode.SAME)
+ self.bn1 = nn.batch_norm(16)
+
+ self.layer1 = ResBlock(16, 16, 1)
+ self.layer2 = ResBlock(16, 16, 1)
+ self.layer3 = ResBlock(16, 16, 1)
+
+ self.layer4 = ResBlock_conv(16, 32, 2)
+ self.layer5 = ResBlock(32, 32, 1)
+ self.layer6 = ResBlock(32, 32, 1)
+
+ self.layer7 = ResBlock_conv(32, 64, 2)
+ self.layer8 = ResBlock(64, 64, 1)
+ self.layer9 = ResBlock(64, 64, 1)
+
+ self.fc = nn.linear(64, num_classes)
+
+ def forward(self, x):
+ x = F.relu(self.bn1(self.conv1(x)))
+
+ x = self.layer1.forward(x)
+ x = self.layer2.forward(x)
+ x = self.layer3.forward(x)
+ # print(x.shape)
+ x = self.layer4.forward(x)
+ x = self.layer5.forward(x)
+ x = self.layer6.forward(x)
+ # print(x.shape)
+ x = self.layer7.forward(x)
+ x = self.layer8.forward(x)
+ x = self.layer9.forward(x)
+ # print(x.shape)
+ x = F.avg_pool(x, kernel=[8,8], stride=[8,8])
+ x = F.convert(x, F.NCHW)
+ x = F.reshape(x, [0, -1])
+ x = self.fc(x)
+ out = F.softmax(x, 1)
+ return out
+
+net = Resnet20()
+net.train(True)
+input_var = MNN.expr.placeholder([1, 3, 32, 32], MNN.expr.NCHW)
+predicts = net.forward(input_var)
+# print(predicts)
+F.save([predicts], "resnet20.mnn")
\ No newline at end of file
diff --git a/android/data/mnn_model/vgg11.py b/android/data/mnn_model/vgg11.py
new file mode 100644
index 0000000000..7596c0d65f
--- /dev/null
+++ b/android/data/mnn_model/vgg11.py
@@ -0,0 +1,66 @@
+import MNN
+nn = MNN.nn
+F = MNN.expr
+
+
+class ConvBnRelu(nn.Module):
+ def __init__(self, in_planes, planes):
+ super(ConvBnRelu, self).__init__()
+ self.conv = nn.conv(in_planes, planes, kernel_size=[3,3], stride=[1,1], bias=False, padding=[1,1])
+ self.bn = nn.batch_norm(planes)
+
+ def forward(self, x):
+ out = F.relu(self.bn(self.conv(x)))
+ return out
+
+class VGG11(nn.Module):
+ def __init__(self, num_classes=10):
+ super(VGG11, self).__init__()
+ self.conv1 = nn.conv(3, 64, kernel_size=[3,3], stride=[1,1], padding=[1,1], bias=False, padding_mode=MNN.expr.Padding_Mode.SAME)
+ self.bn1 = nn.batch_norm(64)
+ self.conv2 = nn.conv(64, 128, kernel_size=[3,3], stride=[1,1], padding=[1,1], bias=False, padding_mode=MNN.expr.Padding_Mode.SAME)
+ self.bn2 = nn.batch_norm(128)
+ self.conv3 = nn.conv(128, 256, kernel_size=[3,3], stride=[1,1], padding=[1,1], bias=False, padding_mode=MNN.expr.Padding_Mode.SAME)
+ self.bn3 = nn.batch_norm(256)
+ self.conv4 = nn.conv(256, 256, kernel_size=[3,3], stride=[1,1], padding=[1,1], bias=False, padding_mode=MNN.expr.Padding_Mode.SAME)
+ self.bn4 = nn.batch_norm(256)
+ self.conv5 = nn.conv(256, 512, kernel_size=[3,3], stride=[1,1], padding=[1,1], bias=False, padding_mode=MNN.expr.Padding_Mode.SAME)
+ self.bn5 = nn.batch_norm(512)
+ self.conv6 = nn.conv(512, 512, kernel_size=[3,3], stride=[1,1], padding=[1,1], bias=False, padding_mode=MNN.expr.Padding_Mode.SAME)
+ self.bn6 = nn.batch_norm(512)
+ self.conv7 = nn.conv(512, 512, kernel_size=[3,3], stride=[1,1], padding=[1,1], bias=False, padding_mode=MNN.expr.Padding_Mode.SAME)
+ self.bn7 = nn.batch_norm(512)
+ self.conv8 = nn.conv(512, 512, kernel_size=[3,3], stride=[1,1], padding=[1,1], bias=False, padding_mode=MNN.expr.Padding_Mode.SAME)
+ self.bn8 = nn.batch_norm(512)
+
+ self.fc = nn.linear(512, num_classes)
+
+ def forward(self, x):
+ x = F.relu(self.bn1(self.conv1(x)))
+ x = F.max_pool(x, [2, 2], [2, 2])
+ x = F.relu(self.bn2(self.conv2(x)))
+ x = F.max_pool(x, [2, 2], [2, 2])
+ x = F.relu(self.bn3(self.conv3(x)))
+ x = F.relu(self.bn4(self.conv4(x)))
+ x = F.max_pool(x, [2, 2], [2, 2])
+ x = F.relu(self.bn5(self.conv5(x)))
+ x = F.relu(self.bn6(self.conv6(x)))
+ x = F.max_pool(x, [2, 2], [2, 2])
+ x = F.relu(self.bn7(self.conv7(x)))
+ x = F.relu(self.bn8(self.conv8(x)))
+ x = F.max_pool(x, [2, 2], [2, 2])
+ print(x.shape)
+ # x = F.avg_pool(x, kernel=[1,1], stride=[1,1])
+ # x = F.convert(x, F.NCHW)
+ x = F.reshape(x, [0, -1])
+ x = self.fc(x)
+ x = F.softmax(x, 1)
+
+ return x
+
+net = VGG11()
+net.train(True)
+input_var = MNN.expr.placeholder([1, 3, 32, 32], MNN.expr.NC4HW4)
+predicts = net.forward(input_var)
+# print(predicts)
+F.save([predicts], "vgg11.mnn")
\ No newline at end of file
diff --git a/python/examples/cross_device/mqtt_s3_fedavg_cifar10_lr_example/__init__.py b/android/data/mnn_model/vgg11_cifar10.mnn
similarity index 100%
rename from python/examples/cross_device/mqtt_s3_fedavg_cifar10_lr_example/__init__.py
rename to android/data/mnn_model/vgg11_cifar10.mnn
diff --git a/android/data/prepare.sh b/android/data/prepare.sh
new file mode 100644
index 0000000000..ceb4c571ab
--- /dev/null
+++ b/android/data/prepare.sh
@@ -0,0 +1,17 @@
+MNIST_DIR=mnist
+CIFAR10_DIR=cifar10
+ANDROID_DIR=/sdcard/ai.fedml
+
+rm -rf $MNIST_DIR
+mkdir $MNIST_DIR
+wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz -P $MNIST_DIR
+wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz -P $MNIST_DIR
+
+rm -rf $CIFAR10_DIR
+rm -rf cifar-10-binary.tar.gz
+wget https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz
+tar -xzvf cifar-10-binary.tar.gz
+mv cifar-10-batches-bin $CIFAR10_DIR
+
+adb push $MNIST_DIR $ANDROID_DIR
+adb push $CIFAR10_DIR $ANDROID_DIR
diff --git a/android/data/torch_model/lenet.py b/android/data/torch_model/lenet.py
new file mode 100644
index 0000000000..231e4ec406
--- /dev/null
+++ b/android/data/torch_model/lenet.py
@@ -0,0 +1,45 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class LeNet5(nn.Module):
+ def __init__(self):
+ super(LeNet5, self).__init__()
+ self.conv1 = nn.Conv2d(1, 20, 5)
+ self.conv2 = nn.Conv2d(20, 50, 5)
+ self.fc1 = nn.Linear(800, 500)
+ self.fc2 = nn.Linear(500, 10)
+
+ def forward(self, x):
+ x = F.relu(self.conv1(x))
+ x = F.max_pool2d(x, 2, 2)
+ x = F.relu(self.conv2(x))
+ x = F.max_pool2d(x, 2, 2)
+
+ x = x.view(x.shape[0], -1)
+ x = F.relu(self.fc1(x))
+ x = self.fc2(x)
+ x = F.log_softmax(x, 1)
+ return x
+
+
+model = LeNet5()
+example = torch.rand(1, 1, 28, 28)
+
+# traced_module = torch.jit.trace(model, example)
+# traced_module.save("traced_lenet_model.pt")
+
+scripted_module = torch.jit.script(model, example)
+# scripted_module.save("scripted_lenet_model.pt")
+
+# no optimization but necessary for c++ jit load
+optimized_scripted_module = torch.jit.optimized_execution(scripted_module)
+scripted_module._save_for_lite_interpreter("scripted_lenet_model.ptl")
+
+
+# from torchsummary import summary
+# summary(model, (1, 28, 28), device="cpu")
+
+# for p in model.parameters():
+# print(p.shape)
diff --git a/android/data/torch_model/scripted_lenet_model.ptl b/android/data/torch_model/scripted_lenet_model.ptl
new file mode 100644
index 0000000000..7d2725425e
Binary files /dev/null and b/android/data/torch_model/scripted_lenet_model.ptl differ
diff --git a/android/doc/FedML-Android-Arch.jpg b/android/doc/FedML-Android-Arch.jpg
new file mode 100644
index 0000000000..67f7ee1cb6
Binary files /dev/null and b/android/doc/FedML-Android-Arch.jpg differ
diff --git a/android/doc/android_running.jpeg b/android/doc/android_running.jpeg
new file mode 100644
index 0000000000..747e552f5b
Binary files /dev/null and b/android/doc/android_running.jpeg differ
diff --git a/android/doc/beehive_account.png b/android/doc/beehive_account.png
new file mode 100644
index 0000000000..f86b6ccb69
Binary files /dev/null and b/android/doc/beehive_account.png differ
diff --git a/android/doc/edge_devices_overview.png b/android/doc/edge_devices_overview.png
new file mode 100644
index 0000000000..141cbb6618
Binary files /dev/null and b/android/doc/edge_devices_overview.png differ
diff --git a/android/doc/mobile.png b/android/doc/mobile.png
new file mode 100644
index 0000000000..03d1a5f121
Binary files /dev/null and b/android/doc/mobile.png differ
diff --git a/android/fedmlsdk/.gitignore b/android/fedmlsdk/.gitignore
new file mode 100644
index 0000000000..90539cfbae
--- /dev/null
+++ b/android/fedmlsdk/.gitignore
@@ -0,0 +1,10 @@
+.idea
+.cxx
+build
+src/main/jni/CMakeFiles
+/libs/torch/arm64-v8a/
+/src/main/jni/build_arm_android_64/
+.DS_Store
+secring.gpg
+local.properties
+/.gradle/
diff --git a/android/fedmlsdk/MobileNN/.gitignore b/android/fedmlsdk/MobileNN/.gitignore
new file mode 100644
index 0000000000..4b2251eb5d
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/.gitignore
@@ -0,0 +1,16 @@
+.idea
+.cxx
+
+build/FedMLTrainer/build_x86_linux
+build/FedMLTrainer/build_arm_android_64
+
+build/lightsecagg/build_x86_linux
+build/lightsecagg/build_arm_android_64
+
+build/train/build_x86_linux
+build/train/build_arm_android_64
+.DS_Store
+/build/MNN/build_x86_linux/
+/build/MNN/build_arm_android_64/
+/build/torch/build_x86_linux/
+/build/torch/build_arm_android_64/
\ No newline at end of file
diff --git a/android/fedmlsdk/MobileNN/README.md b/android/fedmlsdk/MobileNN/README.md
new file mode 100644
index 0000000000..dc20c390a3
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/README.md
@@ -0,0 +1,39 @@
+# Mobile MN
+This demo provides CMake build scripts to compile a source code, `demo.cpp`, that trains a LeNet model in MNN framework, into an executable target file `demo.out` along with MNN shared libraries for Linux/macOS and Android platforms.
+
+## Environmental Requirements
+* cmake (version >=3.10 is recommended)
+* protobuf (version >= 3.0 is required)
+* gcc (version >= 4.9 is required)
+
+## Linux/macOS
+1. run `./build_x86_linux.sh`, which will generate the executable file `demo.out` under build_x86_linux folder.
+
+```
+cd build/train
+sh build_x86_linux.sh
+./build_x86_linux/demo.out mnist ../../../data/lenet_mnist.mnn ../../../data/mnist
+
+```
+
+## Android
+1. [Download and Install NDK](https://developer.android.com/ndk/downloads/), latest release version is recommended
+2. Set ANDROID_NDK path at line 3 in `build_arm_android_64.sh`, eg: ANDROID_NDK=/Users/username/path/to/Android-ndk-r14b
+3. run `./build_arm_android_64.sh`, which will generate the executable file `demo.out` under build_arm_android_64 folder.
+4. run `./test_arm_android_64.sh`, which will push `demo.out` to your android device and execute it.
+
+## Notes
+1. You can change CMake compilation options in `build.sh` as needed (i.e. turn off demo/quantools/evaluation/converter/test/benchmark options, turn on openmp/opencl/opengl/vulkan as your backend, set FP16/BF16 low precision mode, and etc). Check [MNN document](https://www.yuque.com/mnn/en/cmake_opts) and [MNN CMakeList](https://github.com/alibaba/MNN/blob/master/CMakeLists.txt) for more information.
+2. MNN compilation artifacts under `build/mnn_binary_dir`
+ * libMNN: Backend Shared Library
+ * libMNNTrain: Training Framework Shared Library
+ * libMNNExpr: Express Training API Shared Library
+3. To run `demo.out` on your linux/macOS machine, first download MNIST dataset from [Google Drive](https://drive.google.com/drive/folders/1IB1-NJgzHSEb7ucgJzM2Gj8QzxpYAjGy?usp=sharing), and run `./demo.out /path/to/data/mnist_data`.
+To run `demo.out` on your android device, adb push mnist_data to your android device under `/data/local/tmp` before running `./test_arm_android_64.sh`.
+
+## Dependency
+MNN:
+https://github.com/FedML-AI/MNN.git
+
+pytorch:
+https://github.com/FedML-AI/pytorch.git
\ No newline at end of file
diff --git a/android/fedmlsdk/MobileNN/README_LSA.md b/android/fedmlsdk/MobileNN/README_LSA.md
new file mode 100644
index 0000000000..3c98dabcf2
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/README_LSA.md
@@ -0,0 +1,11 @@
+```
+cd cpp/build/FedMLTrainer
+sh build_x86_linux.sh
+./build_x86_linux/demo.out mnist ../../../data/lenet_mnist.mnn ../../../data/mnist
+```
+
+```
+cd cpp/build/lightsecagg
+sh build_x86_linux.sh
+./build_x86_linux/encode.out ../../../data/lenet_mnist.mnn
+```
\ No newline at end of file
diff --git a/android/fedmlsdk/MobileNN/build/FedMLTrainer/CMakeLists.txt b/android/fedmlsdk/MobileNN/build/FedMLTrainer/CMakeLists.txt
new file mode 100644
index 0000000000..b4bbc82636
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/build/FedMLTrainer/CMakeLists.txt
@@ -0,0 +1,41 @@
+cmake_minimum_required(VERSION 3.0)
+
+set (CMAKE_CXX_STANDARD 11)
+set (TARTGET "main_fedml_client_mangaer.out")
+
+# path to MobileNN directory
+set(MOBILENN_HOME_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../MobileNN")
+
+# path to MNN directory
+set(MNN_HOME_DIR "${CMAKE_CURRENT_LIST_DIR}/../../MNN")
+
+# execute cmake in MNN root folder
+add_subdirectory(${MNN_HOME_DIR} mnn_binary_dir)
+
+# Source code path
+file(GLOB_RECURSE DEMO_SRC
+ ${MOBILENN_HOME_DIR}/src/MNN/*.cpp
+ ${MOBILENN_HOME_DIR}/src/train/*.cpp
+ ${MOBILENN_HOME_DIR}/src/FedMLClientManager.cpp
+ ${MOBILENN_HOME_DIR}/src/main_FedMLClientManager.cpp)
+
+# link libraries and include directories
+add_executable(${TARTGET} ${DEMO_SRC})
+target_link_libraries(${TARTGET} PRIVATE ${MNN_DEPS})
+target_link_libraries(${TARTGET} PRIVATE MNNTrain)
+
+target_include_directories(${TARTGET} PRIVATE ${MNN_HOME_DIR}/include)
+target_include_directories(${TARTGET} PRIVATE ${MNN_HOME_DIR}/tools/train/source/grad)
+target_include_directories(${TARTGET} PRIVATE ${MNN_HOME_DIR}/tools/train/source/optimizer)
+target_include_directories(${TARTGET} PRIVATE ${MNN_HOME_DIR}/tools/train/source/transformer)
+target_include_directories(${TARTGET} PRIVATE ${MNN_HOME_DIR}/tools/train/source/data)
+target_include_directories(${TARTGET} PRIVATE ${MNN_HOME_DIR}/tools/train/source/nn)
+target_include_directories(${TARTGET} PRIVATE ${MNN_HOME_DIR}/tools/train/source/models)
+target_include_directories(${TARTGET} PRIVATE ${MNN_HOME_DIR}/tools/train/source/datasets)
+
+target_include_directories(${TARTGET} PRIVATE
+ ${MOBILENN_HOME_DIR}/includes
+ ${MOBILENN_HOME_DIR}/includes/MNN
+ ${MOBILENN_HOME_DIR}/includes/security
+ ${MOBILENN_HOME_DIR}/includes/train)
+
diff --git a/android/fedmlsdk/MobileNN/build/FedMLTrainer/build_arm_android_64.sh b/android/fedmlsdk/MobileNN/build/FedMLTrainer/build_arm_android_64.sh
new file mode 100755
index 0000000000..4c24bc0e53
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/build/FedMLTrainer/build_arm_android_64.sh
@@ -0,0 +1,31 @@
+#!/usr/bin/env bash
+
+# ANDROID_NDK=/Users/leigao/Library/Android/sdk/ndk/22.1.7171670
+ANDROID_NDK=/Users/chaoyanghe/Library/Android/sdk/ndk/24.0.8215888
+
+BUILD_ROOT=`pwd`
+
+function make_or_clean_dir {
+ if [ -d $1 ]; then
+ rm -rf $1/*
+ else
+ mkdir $1
+ fi
+}
+
+make_or_clean_dir build_arm_android_64 && cd build_arm_android_64
+cmake .. \
+ -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \
+ -DCMAKE_BUILD_TYPE=Release \
+ -DANDROID_ABI="arm64-v8a" \
+ -DANDROID_STL=c++_static \
+ -DCMAKE_BUILD_TYPE=Release \
+ -DANDROID_NATIVE_API_LEVEL=android-32 \
+ -DANDROID_TOOLCHAIN=clang \
+ -DMNN_USE_LOGCAT=true \
+ -DMNN_BUILD_FOR_ANDROID_COMMAND=true \
+ -DMNN_BUILD_TRAIN=ON \
+ -DNATIVE_LIBRARY_OUTPUT=. \
+ -DNATIVE_INCLUDE_OUTPUT=. || exit 1;
+make -j16 || exit 1;
+cd $BUILD_ROOT
diff --git a/android/fedmlsdk/MobileNN/build/FedMLTrainer/build_x86_linux.sh b/android/fedmlsdk/MobileNN/build/FedMLTrainer/build_x86_linux.sh
new file mode 100755
index 0000000000..b9fa7470a3
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/build/FedMLTrainer/build_x86_linux.sh
@@ -0,0 +1,17 @@
+#!/usr/bin/env bash
+
+BUILD_ROOT=`pwd`
+
+function make_or_clean_dir {
+ if [ -d $1 ]; then
+# rm -rf $1/*
+ echo "incremental compilation"
+ else
+ mkdir $1
+ fi
+}
+
+make_or_clean_dir build_x86_linux && cd build_x86_linux
+cmake .. -DMNN_BUILD_TRAIN=ON || exit 1;
+make -j16 || exit 1;
+
diff --git a/android/fedmlsdk/MobileNN/build/MNN/CMakeLists.txt b/android/fedmlsdk/MobileNN/build/MNN/CMakeLists.txt
new file mode 100644
index 0000000000..35714d1297
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/build/MNN/CMakeLists.txt
@@ -0,0 +1,44 @@
+cmake_minimum_required(VERSION 3.0)
+
+set (CMAKE_CXX_STANDARD 11)
+set (TARGET "main_mnn_train.out")
+
+# path to MobileNN directory
+set(MOBILENN_HOME_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../MobileNN")
+
+# enable MNN backend
+add_definitions(-DUSE_MNN_BACKEND)
+
+# path to MNN directory
+set(MNN_HOME_DIR "${MOBILENN_HOME_DIR}/MNN")
+
+# execute cmake in MNN root folder
+add_subdirectory(${MNN_HOME_DIR} mnn_binary_dir)
+
+# Source code path
+file(GLOB_RECURSE DEMO_SRC
+ ${MOBILENN_HOME_DIR}/src/MNN/*.cpp
+ ${MOBILENN_HOME_DIR}/src/train/FedMLBaseTrainer.cpp
+ ${MOBILENN_HOME_DIR}/src/train/FedMLMNNTrainer.cpp
+ ${MOBILENN_HOME_DIR}/src/train/FedMLTrainer.cpp
+ ${MOBILENN_HOME_DIR}/src/main_MNN_train.cpp)
+
+# link libraries and include directories
+add_executable(${TARGET} ${DEMO_SRC})
+target_link_libraries(${TARGET} PRIVATE ${MNN_DEPS})
+target_link_libraries(${TARGET} PRIVATE MNNTrain)
+
+target_include_directories(${TARGET} PRIVATE ${MNN_HOME_DIR}/include)
+target_include_directories(${TARGET} PRIVATE ${MNN_HOME_DIR}/tools/train/source/grad)
+target_include_directories(${TARGET} PRIVATE ${MNN_HOME_DIR}/tools/train/source/optimizer)
+target_include_directories(${TARGET} PRIVATE ${MNN_HOME_DIR}/tools/train/source/transformer)
+target_include_directories(${TARGET} PRIVATE ${MNN_HOME_DIR}/tools/train/source/data)
+target_include_directories(${TARGET} PRIVATE ${MNN_HOME_DIR}/tools/train/source/nn)
+target_include_directories(${TARGET} PRIVATE ${MNN_HOME_DIR}/tools/train/source/models)
+target_include_directories(${TARGET} PRIVATE ${MNN_HOME_DIR}/tools/train/source/datasets)
+
+target_include_directories(${TARGET} PRIVATE
+ ${MOBILENN_HOME_DIR}/includes
+ ${MOBILENN_HOME_DIR}/includes/MNN
+ ${MOBILENN_HOME_DIR}/includes/train)
+
diff --git a/android/fedmlsdk/MobileNN/build/MNN/build_arm_android_64.sh b/android/fedmlsdk/MobileNN/build/MNN/build_arm_android_64.sh
new file mode 100755
index 0000000000..4c24bc0e53
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/build/MNN/build_arm_android_64.sh
@@ -0,0 +1,31 @@
+#!/usr/bin/env bash
+
+# ANDROID_NDK=/Users/leigao/Library/Android/sdk/ndk/22.1.7171670
+ANDROID_NDK=/Users/chaoyanghe/Library/Android/sdk/ndk/24.0.8215888
+
+BUILD_ROOT=`pwd`
+
+function make_or_clean_dir {
+ if [ -d $1 ]; then
+ rm -rf $1/*
+ else
+ mkdir $1
+ fi
+}
+
+make_or_clean_dir build_arm_android_64 && cd build_arm_android_64
+cmake .. \
+ -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \
+ -DCMAKE_BUILD_TYPE=Release \
+ -DANDROID_ABI="arm64-v8a" \
+ -DANDROID_STL=c++_static \
+ -DCMAKE_BUILD_TYPE=Release \
+ -DANDROID_NATIVE_API_LEVEL=android-32 \
+ -DANDROID_TOOLCHAIN=clang \
+ -DMNN_USE_LOGCAT=true \
+ -DMNN_BUILD_FOR_ANDROID_COMMAND=true \
+ -DMNN_BUILD_TRAIN=ON \
+ -DNATIVE_LIBRARY_OUTPUT=. \
+ -DNATIVE_INCLUDE_OUTPUT=. || exit 1;
+make -j16 || exit 1;
+cd $BUILD_ROOT
diff --git a/android/fedmlsdk/MobileNN/build/MNN/build_x86_linux.sh b/android/fedmlsdk/MobileNN/build/MNN/build_x86_linux.sh
new file mode 100755
index 0000000000..b9fa7470a3
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/build/MNN/build_x86_linux.sh
@@ -0,0 +1,17 @@
+#!/usr/bin/env bash
+
+BUILD_ROOT=`pwd`
+
+function make_or_clean_dir {
+ if [ -d $1 ]; then
+# rm -rf $1/*
+ echo "incremental compilation"
+ else
+ mkdir $1
+ fi
+}
+
+make_or_clean_dir build_x86_linux && cd build_x86_linux
+cmake .. -DMNN_BUILD_TRAIN=ON || exit 1;
+make -j16 || exit 1;
+
diff --git a/android/fedmlsdk/MobileNN/build/MNN/release_libs_to_android.sh b/android/fedmlsdk/MobileNN/build/MNN/release_libs_to_android.sh
new file mode 100644
index 0000000000..18469b909a
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/build/MNN/release_libs_to_android.sh
@@ -0,0 +1,5 @@
+#!/usr/bin/env bash
+
+cp -rf ./build_arm_android_64/mnn_binary_dir/libMNN.so ../../../libs/MNN/arm64-v8a
+cp -rf ./build_arm_android_64/mnn_binary_dir/libMNN_Express.so ../../../libs/MNN/arm64-v8a
+cp -rf ./build_arm_android_64/mnn_binary_dir/tools/train/libMNNTrain.so ../../../libs/MNN/arm64-v8a
\ No newline at end of file
diff --git a/android/fedmlsdk/MobileNN/build/torch/CMakeLists.txt b/android/fedmlsdk/MobileNN/build/torch/CMakeLists.txt
new file mode 100644
index 0000000000..b17b6de953
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/build/torch/CMakeLists.txt
@@ -0,0 +1,62 @@
+cmake_minimum_required(VERSION 3.0)
+
+option(BUILD_ANDROID "Build for Android" OFF)
+
+set(CMAKE_CXX_STANDARD 14)
+set(CMAKE_THREAD_LIBS_INIT "-lpthread")
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread")
+set(CMAKE_HAVE_THREADS_LIBRARY 1)
+set(CMAKE_USE_WIN32_THREADS_INIT 0)
+set(CMAKE_USE_PTHREADS_INIT 1)
+set(THREADS_PREFER_PTHREAD_FLAG ON)
+
+set(TARGET "main_torch_train.out")
+
+# path to MobileNN directory
+set(MOBILENN_HOME_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../MobileNN")
+
+# enable TORCH backend
+add_definitions(-DUSE_TORCH_BACKEND)
+
+# path to pytorch directory
+set(PYTORCH_HOME_DIR "${MOBILENN_HOME_DIR}/pytorch")
+
+# Source code path
+file(GLOB_RECURSE DEMO_SRC
+ ${MOBILENN_HOME_DIR}/src/torch/*.cpp
+ ${MOBILENN_HOME_DIR}/src/train/FedMLBaseTrainer.cpp
+ ${MOBILENN_HOME_DIR}/src/train/FedMLTorchTrainer.cpp
+ ${MOBILENN_HOME_DIR}/src/train/FedMLTrainer.cpp
+ ${MOBILENN_HOME_DIR}/src/main_torch_train.cpp)
+
+# link libraries and include directories
+add_executable(${TARGET} ${DEMO_SRC})
+target_include_directories(${TARGET} PRIVATE
+ ${MOBILENN_HOME_DIR}/includes
+ ${MOBILENN_HOME_DIR}/includes/torch
+ ${MOBILENN_HOME_DIR}/includes/train)
+
+if(${BUILD_ANDROID})
+ target_include_directories(${TARGET} PRIVATE
+ ${PYTORCH_HOME_DIR}/build_android/install/include
+ ${PYTORCH_HOME_DIR}/build_android/install/include/torch/csrc/api/include
+ ${PYTORCH_HOME_DIR}/aten/src
+ ${PYTORCH_HOME_DIR}/include)
+ target_link_libraries(${TARGET} PRIVATE
+ ${PYTORCH_HOME_DIR}/build_android/install/lib/libc10.so
+ ${PYTORCH_HOME_DIR}/build_android/install/lib/libtorch_cpu.so
+ ${PYTORCH_HOME_DIR}/build_android/install/lib/libtorch_global_deps.so
+ ${PYTORCH_HOME_DIR}/build_android/install/lib/libtorch.so
+ log)
+else()
+ target_include_directories(${TARGET} PRIVATE
+ ${PYTORCH_HOME_DIR}/build_mobile/install/include
+ ${PYTORCH_HOME_DIR}/build_mobile/install/include/torch/csrc/api/include
+ ${PYTORCH_HOME_DIR}/aten/src
+ ${PYTORCH_HOME_DIR}/include)
+ target_link_libraries(${TARGET} PRIVATE
+ ${PYTORCH_HOME_DIR}/build_mobile/install/lib/libc10.dylib
+ ${PYTORCH_HOME_DIR}/build_mobile/install/lib/libtorch_global_deps.dylib
+ ${PYTORCH_HOME_DIR}/build_mobile/install/lib/libtorch.dylib
+ ${PYTORCH_HOME_DIR}/build_mobile/install/lib/libtorch_cpu.dylib)
+endif()
diff --git a/android/fedmlsdk/MobileNN/build/torch/build_arm_android_64.sh b/android/fedmlsdk/MobileNN/build/torch/build_arm_android_64.sh
new file mode 100755
index 0000000000..c8b41201df
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/build/torch/build_arm_android_64.sh
@@ -0,0 +1,33 @@
+#!/usr/bin/env bash
+
+# use ndk 21.x
+export ANDROID_NDK=/Users/leigao/Library/Android/sdk/ndk/21.4.7075529
+export ANDROID_ABI=arm64-v8a
+export BUILD_LITE_INTERPRETER=0
+
+# build pytorch
+if [ ! -d "./../../pytorch/build_android" ]; then
+bash ./../../pytorch/scripts/build_android.sh || exit 1;
+fi
+
+function make_or_clean_dir {
+ if [ -d $1 ]; then
+# rm -rf $1/*
+ echo "incremental compilation"
+ else
+ mkdir $1
+ fi
+}
+
+# build our source code
+make_or_clean_dir build_arm_android_64 && cd build_arm_android_64
+cmake .. \
+ -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \
+ -DCMAKE_BUILD_TYPE=Release \
+ -DANDROID_ABI="arm64-v8a" \
+ -DANDROID_STL=c++_static \
+ -DANDROID_NATIVE_API_LEVEL=android-21 \
+ -DANDROID_TOOLCHAIN=clang \
+ -DBUILD_ANDROID=ON || exit 1;
+make -j16 || exit 1;
+
diff --git a/android/fedmlsdk/MobileNN/build/torch/build_x86_linux.sh b/android/fedmlsdk/MobileNN/build/torch/build_x86_linux.sh
new file mode 100755
index 0000000000..443c255dd9
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/build/torch/build_x86_linux.sh
@@ -0,0 +1,20 @@
+#!/usr/bin/env bash
+
+# build pytorch
+if [ ! -d "./../../pytorch/build_mobile" ]; then
+bash ./../../pytorch/scripts/build_mobile.sh || exit 1;
+fi
+
+function make_or_clean_dir {
+ if [ -d $1 ]; then
+# rm -rf $1/*
+ echo "incremental compilation"
+ else
+ mkdir $1
+ fi
+}
+
+# build our source code
+make_or_clean_dir build_x86_linux && cd build_x86_linux
+cmake .. || exit 1;
+make -j16 || exit 1;
\ No newline at end of file
diff --git a/android/fedmlsdk/MobileNN/build/torch/release_libs_to_android.sh b/android/fedmlsdk/MobileNN/build/torch/release_libs_to_android.sh
new file mode 100644
index 0000000000..69110eb8ff
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/build/torch/release_libs_to_android.sh
@@ -0,0 +1,3 @@
+#!/usr/bin/env bash
+
+cp -rf ../../pytorch/build_android/lib/*.so ../../../libs/torch/arm64-v8a
diff --git a/android/fedmlsdk/MobileNN/includes/FedMLClientManager.h b/android/fedmlsdk/MobileNN/includes/FedMLClientManager.h
new file mode 100644
index 0000000000..4e1e2bbcc0
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/includes/FedMLClientManager.h
@@ -0,0 +1,43 @@
+#ifndef FEDML_ANDROID_FEDMLCLIENTMANAGER_H
+#define FEDML_ANDROID_FEDMLCLIENTMANAGER_H
+
+#include "FedMLTrainer.h"
+
+class FedMLClientManager {
+
+public:
+ FedMLClientManager();
+
+ ~FedMLClientManager();
+
+ void init(const char *modelCachePath, const char *dataCachePath, const char *dataSet,
+ int trainSize, int testSize, int batchSizeNum, double LearningRate, int epochNum,
+ progressCallback progress_callback,
+ accuracyCallback accuracy_callback,
+ lossCallback loss_callback);
+
+ std::string train();
+
+ /**
+ * the local epoch index in each global epoch training, and the training loss in this local epoch
+ *
+ * @return current epoch and the loss value in this epoch (format: "epoch,loss")
+ */
+ std::string getEpochAndLoss();
+
+ /**
+ * Stop the current training
+ *
+ * @return success
+ */
+ bool stopTraining();
+
+private:
+ FedMLBaseTrainer *mFedMLTrainer;
+
+// std::string m_modelCachePath;
+// std::string m_dataSet;
+};
+
+
+#endif //FEDML_ANDROID_FEDMLCLIENTMANAGER_H
diff --git a/android/fedmlsdk/MobileNN/includes/FedMLClientManagerSA.h b/android/fedmlsdk/MobileNN/includes/FedMLClientManagerSA.h
new file mode 100644
index 0000000000..f168ce6787
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/includes/FedMLClientManagerSA.h
@@ -0,0 +1,82 @@
+#ifndef FEDML_ANDROID_FEDMLCLIENTMANAGERSA_H
+#define FEDML_ANDROID_FEDMLCLIENTMANAGERSA_H
+
+#include "train/FedMLTrainerSA.h"
+#include "LightSecAggForMNN.h"
+
+class FedMLClientManagerSA {
+
+public:
+ FedMLClientManagerSA();
+
+ ~FedMLClientManagerSA();
+
+ void init(const char *modelCachePath, const char *dataCachePath, const char *dataSet,
+ int trainSize, int testSize,
+ int batchSizeNum, double LearningRate, int epochNum,
+ int q_bits, int p, int client_num,
+ progressCallback progress_callback,
+ accuracyCallback accuracy_callback,
+ lossCallback loss_callback);
+
+ /**
+ * generate local mask and encode mask to share with other users
+ */
+ std::vector > get_local_encoded_mask();
+
+ /**
+ * receive other mask from surviving users
+ */
+ void save_mask_from_paired_clients(int client_index,
+ std::vector local_encode_mask);
+
+
+ /**
+ * receive client index from surviving users
+ */
+ std::vector get_client_IDs_that_have_sent_mask();
+
+
+ std::string train();
+
+ /**
+ * get masked model after the local training is done
+ * the model file is saved at the original path "modelCachePath"
+ */
+ void generate_masked_model();
+
+ /**
+ * the server will ask those clients that are online to send aggregated encoded masks
+ */
+ std::vector get_aggregated_encoded_mask(std::vector surviving_list_from_server);
+
+ /**
+ * the local epoch index in each global epoch training, and the training loss in this local epoch
+ *
+ * @return current epoch and the loss value in this epoch (format: "epoch,loss")
+ */
+ std::string getEpochAndLoss();
+
+ /**
+ * Stop the current training
+ *
+ * @return success
+ */
+ bool stopTraining();
+
+ /**
+ * print MNN variables
+ */
+ void printMNNVar(VARP x);
+
+private:
+ FedMLTrainerSA *mFedMLTrainer;
+ LightSecAggForMNN *mLightSecAggForMNN;
+
+ VARPS m_local_mask;
+ std::string m_modelCachePath;
+ std::string m_dataSet;
+};
+
+
+#endif //FEDML_ANDROID_FEDMLCLIENTMANAGERSA_H
diff --git a/android/fedmlsdk/MobileNN/includes/MNN/cifar10.h b/android/fedmlsdk/MobileNN/includes/MNN/cifar10.h
new file mode 100644
index 0000000000..c79b5075f2
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/includes/MNN/cifar10.h
@@ -0,0 +1,33 @@
+#ifndef Cifar10Dateset_hpp
+#define Cifar10Dateset_hpp
+
+#include
+#include "Dataset.hpp"
+#include "Example.hpp"
+
+namespace MNN {
+namespace Train {
+class MNN_PUBLIC Cifar10Dataset : public Dataset {
+public:
+ enum Mode { TRAIN, TEST };
+
+ Example get(size_t index) override;
+
+ size_t size() override;
+
+ const VARP images();
+
+ const VARP labels();
+
+ static DatasetPtr create(const std::string path, Mode mode = Mode::TRAIN, int32_t trainSize = 50000, int32_t testSize = 10000);
+private:
+ explicit Cifar10Dataset(const std::string path, Mode mode = Mode::TRAIN);
+ VARP mImages, mLabels;
+ const uint8_t* mImagePtr = nullptr;
+ const uint8_t* mLabelsPtr = nullptr;
+};
+}
+}
+
+
+#endif // Cifar10Dateset_hpp
\ No newline at end of file
diff --git a/android/fedmlsdk/MobileNN/includes/MNN/mnist.h b/android/fedmlsdk/MobileNN/includes/MNN/mnist.h
new file mode 100644
index 0000000000..a793eecb2f
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/includes/MNN/mnist.h
@@ -0,0 +1,41 @@
+//
+// MnistDataset.hpp
+// MNN
+//
+// Created by MNN on 2019/11/15.
+// Copyright Š 2018, Alibaba Group Holding Limited
+//
+
+#ifndef MnistDataset_hpp
+#define MnistDataset_hpp
+
+#include
+#include "Dataset.hpp"
+#include "Example.hpp"
+
+namespace MNN {
+namespace Train {
+class MNN_PUBLIC MnistDataset : public Dataset {
+public:
+ enum Mode { TRAIN, TEST };
+
+ Example get(size_t index) override;
+
+ size_t size() override;
+
+ const VARP images();
+
+ const VARP labels();
+
+ static DatasetPtr create(const std::string path, Mode mode = Mode::TRAIN, int32_t trainSize = 60000, int32_t testSize = 10000);
+private:
+ explicit MnistDataset(const std::string path, Mode mode = Mode::TRAIN);
+ VARP mImages, mLabels;
+ const uint8_t* mImagePtr = nullptr;
+ const uint8_t* mLabelsPtr = nullptr;
+};
+}
+}
+
+
+#endif // MnistDataset_hpp
diff --git a/android/fedmlsdk/MobileNN/includes/security/LightSecAgg.h b/android/fedmlsdk/MobileNN/includes/security/LightSecAgg.h
new file mode 100644
index 0000000000..a853f8038f
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/includes/security/LightSecAgg.h
@@ -0,0 +1,35 @@
+#ifndef LIGHTSECAGG_CPP_MPC_FUNC_H
+#define LIGHTSECAGG_CPP_MPC_FUNC_H
+
+#include
+#include
+#include
+#include
+#include
+
+
+class LightSecAgg {
+
+public:
+ std::vector> LCC_encoding_with_points(std::vector> const &X,
+ std::vector const &alpha_s,
+ std::vector const &beta_s, int p);
+
+
+ std::vector> LCC_decoding_with_points(std::vector> f_eval,
+ std::vector eval_points,
+ std::vector target_points, int p);
+private:
+ int modInverse(int a, int p);
+
+ int modDivide(int num, int den, int p);
+
+ int PI(std::vector vals, int p);
+
+ std::vector> gen_Lagrange_coeffs(std::vector const &alpha_s,
+ std::vector const &beta_s,
+ int p, int is_K1 = 0);
+};
+
+
+#endif //LIGHTSECAGG_CPP_MPC_FUNC_H
diff --git a/android/fedmlsdk/MobileNN/includes/security/LightSecAggForMNN.h b/android/fedmlsdk/MobileNN/includes/security/LightSecAggForMNN.h
new file mode 100644
index 0000000000..5331965b84
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/includes/security/LightSecAggForMNN.h
@@ -0,0 +1,69 @@
+#ifndef FEDML_ANDROID_LIGHTSECAGGFORMNN_H
+#define FEDML_ANDROID_LIGHTSECAGGFORMNN_H
+
+
+#include
+#include
+#include
+#include
+#include "mnist.h"
+#include "cifar10.h"
+#include
+#include "SGD.hpp"
+#include
+#include "Loss.hpp"
+#include "LearningRateScheduler.hpp"
+#include "Transformer.hpp"
+#include "NN.hpp"
+#include "LightSecAgg.h"
+#include
+#include
+
+class LightSecAggForMNN {
+
+public:
+ void init(int q_bits, int p, int client_num);
+
+ static void printVar(VARP x);
+
+ VARPS mask_generate(const char *modelCachePath);
+
+ std::vector > local_mask_encoding(VARPS model_mask);
+
+ void MNN_encode(const char *modelCachePath, const char *dataSet, VARPS model_mask);
+
+ std::vector mask_agg(std::vector surviving_list_from_server);
+
+ void save_mask_from_paired_clients(int client_index,
+ std::vector local_encode_mask);
+
+ std::vector get_client_IDs_that_have_sent_mask();
+
+private:
+
+ VARP my_q(VARP const &X);
+
+ VARPS transform_tensor_to_finite(VARPS const &model_params);
+
+ VARPS generate_random_mask(VARPS const &model_params);
+
+ void model_masking(VARPS &weights_finite, VARPS const &local_mask, int prime_number);
+
+ std::vector mask_transform(VARPS model_mask);
+
+ std::vector >
+ mask_encoding(int num_clients, int prime_number, std::vector const &local_mask);
+
+ std::vector
+ z_tilde_sum(std::vector > const &z_tilde_buffer, std::vector const &sur_list);
+
+private:
+ int m_q_bits;
+ int m_p;
+ int m_client_num;
+
+ std::vector > m_local_received_mask_from_other_clients;
+ std::vector m_surviving_clients;
+};
+
+#endif //FEDML_ANDROID_LIGHTSECAGGFORMNN_H
diff --git a/android/fedmlsdk/MobileNN/includes/torch/cifar10.h b/android/fedmlsdk/MobileNN/includes/torch/cifar10.h
new file mode 100644
index 0000000000..8317add530
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/includes/torch/cifar10.h
@@ -0,0 +1,33 @@
+#include
+
+#include
+
+namespace torch {
+namespace data {
+namespace datasets {
+class CIFAR10 : public torch::data::datasets::Dataset {
+ public:
+ // The mode in which the dataset is loaded.
+ enum class Mode { kTrain, kTest };
+
+ explicit CIFAR10(const std::string &root, uint32_t kSize, Mode mode = Mode::kTrain);
+
+ // https://pytorch.org/cppdocs/api/structtorch_1_1data_1_1_example.html#structtorch_1_1data_1_1_example
+ torch::data::Example<> get(size_t index) override;
+
+ torch::optional size() const override;
+
+ bool is_train() const noexcept;
+
+ // Returns all images stacked into a single tensor.
+ const torch::Tensor &images() const;
+
+ const torch::Tensor &targets() const;
+
+ private:
+ // Returns all targets stacked into a single tensor.
+ torch::Tensor images_, targets_;
+};
+} // namespace datasets
+} // namespace data
+} // namespace torch
\ No newline at end of file
diff --git a/android/fedmlsdk/MobileNN/includes/torch/mnist.h b/android/fedmlsdk/MobileNN/includes/torch/mnist.h
new file mode 100644
index 0000000000..960e665182
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/includes/torch/mnist.h
@@ -0,0 +1,42 @@
+#pragma once
+
+#include
+#include
+#include
+
+#include
+
+#include
+#include
+
+/// The MNIST dataset.
+class TORCH_API MNIST : public torch::data::datasets::Dataset {
+ public:
+ /// The mode in which the dataset is loaded.
+ enum class Mode { kTrain, kTest };
+
+ /// Loads the MNIST dataset from the `root` path.
+ ///
+ /// The supplied `root` path should contain the *content* of the unzipped
+ /// MNIST dataset, available from http://yann.lecun.com/exdb/mnist.
+ explicit MNIST(const std::string& root, uint32_t kSize, Mode mode = Mode::kTrain);
+
+ /// Returns the `Example` at the given `index`.
+ torch::data::Example<> get(size_t index) override;
+
+ /// Returns the size of the dataset.
+ torch::optional size() const override;
+
+ /// Returns true if this is the training subset of MNIST.
+ // NOLINTNEXTLINE(bugprone-exception-escape)
+ // bool is_train() const noexcept;
+
+ /// Returns all images stacked into a single tensor.
+ const torch::Tensor& images() const;
+
+ /// Returns all targets stacked into a single tensor.
+ const torch::Tensor& targets() const;
+
+ private:
+ torch::Tensor images_, targets_;
+};
diff --git a/android/fedmlsdk/MobileNN/includes/train/FedMLBaseTrainer.h b/android/fedmlsdk/MobileNN/includes/train/FedMLBaseTrainer.h
new file mode 100644
index 0000000000..bc9ffecf52
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/includes/train/FedMLBaseTrainer.h
@@ -0,0 +1,49 @@
+#ifndef FEDML_ANDROID_FEDMLBASETRAINER_H
+#define FEDML_ANDROID_FEDMLBASETRAINER_H
+
+#include
+#include
+
+typedef std::function progressCallback;
+
+typedef std::function accuracyCallback;
+
+typedef std::function lossCallback;
+
+class FedMLBaseTrainer {
+
+public:
+ void init(const char *modelCachePath, const char *dataCachePath,
+ const char *dataSet, int trainSize, int testSize,
+ int batchSizeNum, double learningRate, int epochNum,
+ progressCallback progress_callback,
+ accuracyCallback accuracy_callback,
+ lossCallback loss_callback);
+
+ virtual std::string train() {return nullptr;};
+
+ std::string getEpochAndLoss();
+
+ bool stopTraining();
+
+protected:
+ std::string m_modelCachePath;
+ std::string m_dataCachePath;
+ std::string m_dataSet;
+ int m_trainSize;
+ int m_testSize;
+ int m_batchSizeNum;
+ double m_LearningRate;
+ int m_epochNum;
+
+ int curEpoch = 0;
+ float curLoss = 0.0;
+ bool bRunStopFlag = false;
+
+ progressCallback m_progress_callback;
+ accuracyCallback m_accuracy_callback;
+ lossCallback m_loss_callback;
+};
+
+
+#endif //FEDML_ANDROID_FEDMLBASETRAINER_H
diff --git a/android/fedmlsdk/MobileNN/includes/train/FedMLMNNTrainer.h b/android/fedmlsdk/MobileNN/includes/train/FedMLMNNTrainer.h
new file mode 100644
index 0000000000..8a8a0055a2
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/includes/train/FedMLMNNTrainer.h
@@ -0,0 +1,25 @@
+#ifndef FEDML_ANDROID_FEDMLMNNTRAINER_H
+#define FEDML_ANDROID_FEDMLMNNTRAINER_H
+
+#include "FedMLBaseTrainer.h"
+#include "mnist.h"
+#include "cifar10.h"
+#include
+#include "DataLoader.hpp"
+#include "SGD.hpp"
+#include "Loss.hpp"
+#include "LearningRateScheduler.hpp"
+#include "Transformer.hpp"
+#include "NN.hpp"
+
+using namespace MNN;
+using namespace MNN::Express;
+using namespace MNN::Train;
+
+class FedMLMNNTrainer: public FedMLBaseTrainer {
+ public:
+ std::string train() override;
+};
+
+
+#endif //FEDML_ANDROID_FEDMLMNNTRAINER_H
diff --git a/android/fedmlsdk/MobileNN/includes/train/FedMLTorchTrainer.h b/android/fedmlsdk/MobileNN/includes/train/FedMLTorchTrainer.h
new file mode 100644
index 0000000000..6e7a2a9cb5
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/includes/train/FedMLTorchTrainer.h
@@ -0,0 +1,22 @@
+#ifndef FEDML_ANDROID_FEDMLTORCHTRAINER_H
+#define FEDML_ANDROID_FEDMLTORCHTRAINER_H
+
+#include "FedMLBaseTrainer.h"
+#include
+#include "cifar10.h"
+#include "mnist.h"
+
+#include
+#include
+#include
+#include
+#include
+
+
+class FedMLTorchTrainer : public FedMLBaseTrainer {
+ public:
+ std::string train() override;
+};
+
+
+#endif //FEDML_ANDROID_FEDMLTORCHTRAINER_H
diff --git a/android/fedmlsdk/MobileNN/includes/train/FedMLTrainer.h b/android/fedmlsdk/MobileNN/includes/train/FedMLTrainer.h
new file mode 100644
index 0000000000..0df9302f8b
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/includes/train/FedMLTrainer.h
@@ -0,0 +1,24 @@
+#ifndef FEDML_ANDROID_FEDMLTRAINER_H
+#define FEDML_ANDROID_FEDMLTRAINER_H
+
+#include "FedMLBaseTrainer.h"
+
+#ifdef USE_MNN_BACKEND
+#include "FedMLMNNTrainer.h"
+#endif
+#ifdef USE_TORCH_BACKEND
+#include "FedMLTorchTrainer.h"
+#endif
+
+
+class FedMLTrainer {
+ public:
+ FedMLTrainer();
+ FedMLBaseTrainer* getTrainer() {return m_trainer;}
+ std::string getEpochAndLoss() {return m_trainer->getEpochAndLoss();}
+
+ private:
+ FedMLBaseTrainer* m_trainer;
+};
+
+#endif //FEDML_ANDROID_FEDMLTRAINER_H
\ No newline at end of file
diff --git a/android/fedmlsdk/MobileNN/includes/train/FedMLTrainerSA.h b/android/fedmlsdk/MobileNN/includes/train/FedMLTrainerSA.h
new file mode 100644
index 0000000000..b10f8e524d
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/includes/train/FedMLTrainerSA.h
@@ -0,0 +1,55 @@
+#ifndef FEDML_ANDROID_FEDMLTRAINERSA_H
+#define FEDML_ANDROID_FEDMLTRAINERSA_H
+
+#include
+#include
+#include "mnist.h"
+#include "cifar10.h"
+#include
+#include "SGD.hpp"
+#include
+#include "Loss.hpp"
+#include "LearningRateScheduler.hpp"
+#include "Transformer.hpp"
+#include "NN.hpp"
+#include "LightSecAggForMNN.h"
+
+typedef std::function progressCallback;
+
+typedef std::function accuracyCallback;
+
+typedef std::function lossCallback;
+
+class FedMLTrainerSA {
+
+public:
+ void init(const char *modelCachePath, const char *dataCachePath,
+ const char *dataSet, int trainSize,
+ int testSize, int batchSizeNum, double learningRate, int epochNum,
+ progressCallback progress_callback,
+ accuracyCallback accuracy_callback,
+ lossCallback loss_callback);
+
+ std::string train();
+
+ std::string getEpochAndLoss();
+
+ bool stopTraining();
+
+private:
+ std::string m_modelCachePath;
+ std::string m_dataCachePath;
+ std::string m_dataSet;
+ int m_trainSize;
+ int m_testSize;
+ int m_batchSizeNum;
+ double m_LearningRate;
+ int m_epochNum;
+
+ progressCallback m_progress_callback;
+ accuracyCallback m_accuracy_callback;
+ lossCallback m_loss_callback;
+};
+
+
+#endif //FEDML_ANDROID_FEDMLTRAINERSA_H
diff --git a/android/fedmlsdk/MobileNN/src/FedMLClientManager.cpp b/android/fedmlsdk/MobileNN/src/FedMLClientManager.cpp
new file mode 100644
index 0000000000..560c49f0b2
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/src/FedMLClientManager.cpp
@@ -0,0 +1,38 @@
+#include "FedMLClientManager.h"
+
+FedMLClientManager::FedMLClientManager() {
+ this->mFedMLTrainer = FedMLTrainer().getTrainer(); // input 1 to use MNN as backend
+}
+
+FedMLClientManager::~FedMLClientManager() {
+ delete mFedMLTrainer;
+}
+
+void FedMLClientManager::init(const char *modelCachePath, const char *dataCachePath, const char *dataSet,
+ int trainSize, int testSize,
+ int batchSizeNum, double LearningRate, int epochNum,
+ progressCallback progress_callback,
+ accuracyCallback accuracy_callback,
+ lossCallback loss_callback) {
+ this->mFedMLTrainer->init(modelCachePath, dataCachePath, dataSet, trainSize, testSize,
+ batchSizeNum, LearningRate, epochNum,
+ progress_callback, accuracy_callback, loss_callback);
+
+// this->m_modelCachePath = modelCachePath;
+// this->m_dataSet = dataSet;
+}
+
+std::string FedMLClientManager::train() {
+ std::string result = this->mFedMLTrainer->train();
+ return result;
+}
+
+std::string FedMLClientManager::getEpochAndLoss() {
+ std::string result = this->mFedMLTrainer->getEpochAndLoss();
+ return result;
+}
+
+bool FedMLClientManager::stopTraining() {
+ bool result = this->mFedMLTrainer->stopTraining();
+ return result;
+}
diff --git a/android/fedmlsdk/MobileNN/src/FedMLClientManagerSA.cpp b/android/fedmlsdk/MobileNN/src/FedMLClientManagerSA.cpp
new file mode 100644
index 0000000000..ebf63f78d2
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/src/FedMLClientManagerSA.cpp
@@ -0,0 +1,80 @@
+#include "FedMLClientManagerSA.h"
+
+using namespace MNN;
+using namespace MNN::Express;
+using namespace MNN::Train;
+
+FedMLClientManagerSA::FedMLClientManagerSA() {
+ this->mFedMLTrainer = new FedMLTrainerSA();
+ this->mLightSecAggForMNN = new LightSecAggForMNN();
+}
+
+FedMLClientManagerSA::~FedMLClientManagerSA() {
+ delete mFedMLTrainer;
+ delete mLightSecAggForMNN;
+}
+
+void FedMLClientManagerSA::init(const char *modelCachePath, const char *dataCachePath, const char *dataSet,
+ int trainSize, int testSize,
+ int batchSizeNum, double LearningRate, int epochNum,
+ int q_bits, int p, int client_num,
+ progressCallback progress_callback,
+ accuracyCallback accuracy_callback,
+ lossCallback loss_callback) {
+ this->mFedMLTrainer->init(modelCachePath, dataCachePath, dataSet, trainSize, testSize,
+ batchSizeNum, LearningRate, epochNum,
+ progress_callback, accuracy_callback, loss_callback);
+
+ this->mLightSecAggForMNN->init(q_bits, p, client_num);
+
+ this->m_modelCachePath = modelCachePath;
+ this->m_dataSet = dataSet;
+}
+
+std::vector > FedMLClientManagerSA::get_local_encoded_mask() {
+ this->m_local_mask = this->mLightSecAggForMNN->mask_generate(this->m_modelCachePath.data());
+ std::vector > local_encode_mask = this->mLightSecAggForMNN->local_mask_encoding(this->m_local_mask);
+ return local_encode_mask;
+}
+
+
+void FedMLClientManagerSA::save_mask_from_paired_clients(int client_index,
+ std::vector local_encode_mask) {
+
+ this->mLightSecAggForMNN->save_mask_from_paired_clients(client_index,
+ local_encode_mask);
+
+}
+
+std::vector FedMLClientManagerSA::get_client_IDs_that_have_sent_mask() {
+ return this->mLightSecAggForMNN->get_client_IDs_that_have_sent_mask();
+}
+
+std::string FedMLClientManagerSA::train() {
+ std::string result = this->mFedMLTrainer->train();
+ return result;
+}
+
+void FedMLClientManagerSA::generate_masked_model() {
+ this->mLightSecAggForMNN->MNN_encode(this->m_modelCachePath.data(), this->m_dataSet.data(), this->m_local_mask);
+}
+
+std::vector FedMLClientManagerSA::get_aggregated_encoded_mask(std::vector surviving_list_from_server) {
+ std::vector sum_mask = this->mLightSecAggForMNN->mask_agg(surviving_list_from_server);
+ return sum_mask;
+}
+
+
+std::string FedMLClientManagerSA::getEpochAndLoss() {
+ std::string result = this->mFedMLTrainer->getEpochAndLoss();
+ return result;
+}
+
+bool FedMLClientManagerSA::stopTraining() {
+ bool result = this->mFedMLTrainer->stopTraining();
+ return result;
+}
+
+void FedMLClientManagerSA::printMNNVar(VARP x) {
+ this->mLightSecAggForMNN->printVar(x);
+}
\ No newline at end of file
diff --git a/android/fedmlsdk/MobileNN/src/MNN/cifar10.cpp b/android/fedmlsdk/MobileNN/src/MNN/cifar10.cpp
new file mode 100644
index 0000000000..d6bd426e04
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/src/MNN/cifar10.cpp
@@ -0,0 +1,119 @@
+#include "cifar10.h"
+#include
+#include
+#include
+
+namespace MNN {
+namespace Train {
+
+static int32_t kTrainSize;
+static int32_t kTestSize;
+static uint32_t kSizePerBatch;
+const uint32_t kImageRows = 32;
+const uint32_t kImageColumns = 32;
+const uint32_t kBytesPerRow = 3073;
+const uint32_t kBytesPerChannelPerRow = 1024;
+
+const std::vector kTrainDataBatchFiles = {
+ "data_batch_1.bin",
+ "data_batch_2.bin",
+ "data_batch_3.bin",
+ "data_batch_4.bin",
+ "data_batch_5.bin",
+};
+
+const std::vector kTestDataBatchFiles = {
+ "test_batch.bin"
+};
+
+std::pair read_data(const std::string& root, bool train) {
+ const auto& files = train ? kTrainDataBatchFiles : kTestDataBatchFiles;
+ const auto num_samples = train ? kTrainSize : kTestSize;
+
+ uint32_t kBytesPerBatchFile = kBytesPerRow * kSizePerBatch;
+
+ std::vector data_buffer;
+ data_buffer.reserve(files.size() * kBytesPerBatchFile);
+
+ for (const auto& file : files) {
+ auto path = root;
+ if (path.back() != '/') {
+ path.push_back('/');
+ }
+ path += file;
+ std::ifstream data(path, std::ios::binary);
+ if (!data.is_open()) {
+ MNN_PRINT("Error opening data file at %s", path.c_str());
+ MNN_ASSERT(false);
+ }
+
+ data_buffer.insert(data_buffer.end(), std::istreambuf_iterator(data), {});
+ }
+
+ MNN_ASSERT(data_buffer.size() == files.size() * kBytesPerBatchFile);
+
+ auto images = _Input({num_samples, 3, kImageRows, kImageColumns}, NCHW, halide_type_of());
+ auto labels = _Input({num_samples}, NCHW, halide_type_of());
+
+ for (uint32_t i = 0; i != num_samples; ++i) {
+ // The first byte of each row is the target class index.
+ uint32_t start_index = i * kBytesPerRow;
+ labels->writeMap()[i] = data_buffer[start_index];
+
+ // The next bytes correspond to the rgb channel values in the following order:
+ // red (32 *32 = 1024 bytes) | green (1024 bytes) | blue (1024 bytes)
+ uint32_t image_start = start_index + 1;
+ uint32_t image_end = image_start + 3 * kBytesPerChannelPerRow;
+ std::copy(data_buffer.begin() + image_start, data_buffer.begin() + image_end,
+ reinterpret_cast(images->writeMap() + (i * 3 * kBytesPerChannelPerRow)));
+ }
+
+ return {images, labels};
+}
+
+Cifar10Dataset::Cifar10Dataset(const std::string root, Mode mode) {
+ auto data = read_data(root, mode == Mode::TRAIN);
+ mImages = data.first;
+ mLabels = data.second;
+ mImagePtr = mImages->readMap();
+ mLabelsPtr = mLabels->readMap();
+}
+
+Example Cifar10Dataset::get(size_t index) {
+ auto data = _Input({3, kImageRows, kImageColumns}, NCHW, halide_type_of());
+ auto label = _Input({}, NCHW, halide_type_of());
+
+ auto dataPtr = mImagePtr + index * 3 * kImageRows * kImageColumns;
+ ::memcpy(data->writeMap(), dataPtr, 3 * kImageRows * kImageColumns);
+
+ auto labelPtr = mLabelsPtr + index;
+ ::memcpy(label->writeMap(), labelPtr, 1);
+
+ auto returnIndex = _Const(index);
+ // return the index for test
+ return {{data, returnIndex}, {label}};
+}
+
+size_t Cifar10Dataset::size() {
+ return mImages->getInfo()->dim[0];
+}
+
+const VARP Cifar10Dataset::images() {
+ return mImages;
+}
+
+const VARP Cifar10Dataset::labels() {
+ return mLabels;
+}
+
+DatasetPtr Cifar10Dataset::create(const std::string path, Mode mode, int32_t trainSize, int32_t testSize) {
+ kTrainSize = trainSize;
+ kTestSize = testSize;
+
+ DatasetPtr res;
+ res.mDataset.reset(new Cifar10Dataset(path, mode));
+ return res;
+}
+
+}
+}
diff --git a/android/fedmlsdk/MobileNN/src/MNN/mnist.cpp b/android/fedmlsdk/MobileNN/src/MNN/mnist.cpp
new file mode 100644
index 0000000000..077bc66795
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/src/MNN/mnist.cpp
@@ -0,0 +1,154 @@
+//
+// MnistDataset.cpp
+// MNN
+//
+// Created by MNN on 2019/11/15.
+// Copyright Š 2018, Alibaba Group Holding Limited
+//
+
+#include "mnist.h"
+#include
+#include
+#include
+namespace MNN {
+namespace Train {
+
+// referenced from pytorch C++ frontend mnist.cpp
+// https://github.com/pytorch/pytorch/blob/master/torch/csrc/api/src/data/datasets/mnist.cpp
+static int32_t kTrainSize;
+static int32_t kTestSize;
+const int32_t kImageMagicNumber = 2051;
+const int32_t kTargetMagicNumber = 2049;
+const int32_t kImageRows = 28;
+const int32_t kImageColumns = 28;
+const char* kTrainImagesFilename = "train-images-idx3-ubyte";
+const char* kTrainTargetsFilename = "train-labels-idx1-ubyte";
+const char* kTestImagesFilename = "t10k-images-idx3-ubyte";
+const char* kTestTargetsFilename = "t10k-labels-idx1-ubyte";
+
+bool check_is_little_endian() {
+ const uint32_t word = 1;
+ return reinterpret_cast(&word)[0] == 1;
+}
+
+constexpr uint32_t flip_endianness(uint32_t value) {
+ return ((value & 0xffu) << 24u) | ((value & 0xff00u) << 8u) | ((value & 0xff0000u) >> 8u) |
+ ((value & 0xff000000u) >> 24u);
+}
+
+uint32_t read_int32(std::ifstream& stream) {
+ static const bool is_little_endian = check_is_little_endian();
+ uint32_t value;
+ stream.read(reinterpret_cast(&value), sizeof value);
+ return is_little_endian ? flip_endianness(value) : value;
+}
+
+uint32_t expect_int32(std::ifstream& stream, uint32_t expected) {
+ const auto value = read_int32(stream);
+ // clang-format off
+ MNN_ASSERT(value == expected);
+ // clang-format on
+ return value;
+}
+
+std::string join_paths(std::string head, const std::string& tail) {
+ if (head.back() != '/') {
+ head.push_back('/');
+ }
+ head += tail;
+ return head;
+}
+
+VARP read_images(const std::string& root, bool train) {
+ const auto path = join_paths(root, train ? kTrainImagesFilename : kTestImagesFilename);
+ std::ifstream images(path, std::ios::binary);
+ if (!images.is_open()) {
+ MNN_PRINT("Error opening images file at %s", path.c_str());
+ MNN_ASSERT(false);
+ }
+
+ const auto count = train ? kTrainSize : kTestSize;
+
+ // From http://yann.lecun.com/exdb/mnist/
+ expect_int32(images, kImageMagicNumber);
+ expect_int32(images, count);
+ expect_int32(images, kImageRows);
+ expect_int32(images, kImageColumns);
+
+ std::vector dims = {count, 1, kImageRows, kImageColumns};
+ int length = 1;
+ for (int i = 0; i < dims.size(); ++i) {
+ length *= dims[i];
+ }
+ auto data = _Input(dims, NCHW, halide_type_of());
+ images.read(reinterpret_cast(data->writeMap()), length);
+ return data;
+}
+
+VARP read_targets(const std::string& root, bool train) {
+ const auto path = join_paths(root, train ? kTrainTargetsFilename : kTestTargetsFilename);
+ std::ifstream targets(path, std::ios::binary);
+ if (!targets.is_open()) {
+ MNN_PRINT("Error opening images file at %s", path.c_str());
+ MNN_ASSERT(false);
+ }
+
+ const auto count = train ? kTrainSize : kTestSize;
+
+ expect_int32(targets, kTargetMagicNumber);
+ expect_int32(targets, count);
+
+ std::vector dims = {count};
+ int length = 1;
+ for (int i = 0; i < dims.size(); ++i) {
+ length *= dims[i];
+ }
+ auto labels = _Input(dims, NCHW, halide_type_of());
+ targets.read(reinterpret_cast(labels->writeMap()), length);
+
+ return labels;
+}
+
+MnistDataset::MnistDataset(const std::string root, Mode mode)
+ : mImages(read_images(root, mode == Mode::TRAIN)), mLabels(read_targets(root, mode == Mode::TRAIN)) {
+ mImagePtr = mImages->readMap();
+ mLabelsPtr = mLabels->readMap();
+}
+
+Example MnistDataset::get(size_t index) {
+ auto data = _Input({1, kImageRows, kImageColumns}, NCHW, halide_type_of());
+ auto label = _Input({}, NCHW, halide_type_of());
+
+ auto dataPtr = mImagePtr + index * kImageRows * kImageColumns;
+ ::memcpy(data->writeMap(), dataPtr, kImageRows * kImageColumns);
+
+ auto labelPtr = mLabelsPtr + index;
+ ::memcpy(label->writeMap(), labelPtr, 1);
+
+ auto returnIndex = _Const(index);
+ // return the index for test
+ return {{data, returnIndex}, {label}};
+}
+
+size_t MnistDataset::size() {
+ return mImages->getInfo()->dim[0];
+}
+
+const VARP MnistDataset::images() {
+ return mImages;
+}
+
+const VARP MnistDataset::labels() {
+ return mLabels;
+}
+
+DatasetPtr MnistDataset::create(const std::string path, Mode mode, int32_t trainSize, int32_t testSize) {
+ kTrainSize = trainSize;
+ kTestSize = testSize;
+
+ DatasetPtr res;
+ res.mDataset.reset(new MnistDataset(path, mode));
+ return res;
+}
+}
+}
diff --git a/android/fedmlsdk/MobileNN/src/main_FedMLClientManager.cpp b/android/fedmlsdk/MobileNN/src/main_FedMLClientManager.cpp
new file mode 100644
index 0000000000..0121f52288
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/src/main_FedMLClientManager.cpp
@@ -0,0 +1,49 @@
+#include "FedMLClientManager.h"
+
+static void onProgressCallback(float progress) {
+ printf("callback. progress = %f\n", progress);
+}
+
+static void onLossCallback(int epoch, float loss) {
+ printf("callback. epoch = %d, loss = %f\n", epoch, loss);
+}
+
+static void onAccuracyCallback(int epoch, float accuracy) {
+ printf("callback. epoch = %d, accuracy = %f\n", epoch, accuracy);
+}
+
+
+int main(int argc, char *argv[]) {
+ std::cout << "You have entered " << argc
+ << " arguments:" << "\n";
+
+ for (int i = 0; i < argc; ++i)
+ std::cout << argv[i] << "\n";
+
+ /*
+ * usage:
+ * ./build_x86_linux/main_fedml_client_mangaer.out mnist ../../../../data/lenet_mnist.mnn ../../../../data/MNIST/raw
+ */
+ const char* datasetName = argv[1];
+ const char* modelPath = argv[2];
+ const char* dataPath = argv[3];
+
+ int trainSize = 600;
+ int testSize = 100;
+ int batchSize = 8;
+ double learningRate = 0.01;
+ int epochNum = 1;
+
+ MobileNNBackend backend = USE_TORCH;
+ FedMLClientManager *mFedMLClientManager = new FedMLClientManager(backend);
+
+ mFedMLClientManager->init(modelPath, dataPath, datasetName,
+ trainSize, testSize, batchSize, learningRate, epochNum,
+ std::bind(&onProgressCallback, std::placeholders::_1),
+ std::bind(&onLossCallback, std::placeholders::_1, std::placeholders::_2),
+ std::bind(&onAccuracyCallback, std::placeholders::_1, std::placeholders::_2));
+
+ mFedMLClientManager->train();
+ std::cout << mFedMLClientManager->getEpochAndLoss() << std::endl;
+ return 0;
+}
\ No newline at end of file
diff --git a/android/fedmlsdk/MobileNN/src/main_FedMLTrainerSA.cpp b/android/fedmlsdk/MobileNN/src/main_FedMLTrainerSA.cpp
new file mode 100644
index 0000000000..32699dae80
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/src/main_FedMLTrainerSA.cpp
@@ -0,0 +1,126 @@
+#include "FedMLClientManagerSA.h"
+#include
+
+
+
+static void onProgressCallback(float progress) {
+ printf("callback. progress = %f\n", progress);
+}
+
+static void onLossCallback(int epoch, float loss) {
+ printf("callback. epoch = %d, loss = %f\n", epoch, loss);
+}
+
+static void onAccuracyCallback(int epoch, float accuracy) {
+ printf("callback. epoch = %d, accuracy = %f\n", epoch, accuracy);
+}
+
+
+int main(int argc, char *argv[]) {
+ std::cout << "You have entered 1" << argc
+ << " arguments:" << "\n";
+
+ for (int i = 0; i < argc; ++i)
+ std::cout << argv[i] << "\n";
+
+ /*
+ * usage:
+ * ./build_x86_linux/fedml_trainer.out mnist ../../../../data/lenet_mnist.mnn ../../../../data/mnist
+ */
+ const char *dataSetType = argv[1];
+ const char *modelPath = argv[2];
+ const char *dataPath = argv[3];
+
+ int epochNum = 1;
+ float learningRate = 0.01;
+ int batchSizeNum = 8;
+ int trainSize = 20000;
+ int testSize = 1000;
+
+ //test parameter for the encoding part
+ int client_num = 10;
+ int q_bits = 15;
+ int p = pow(2, 15) - 19;
+ printf("debug22");
+ printf("main::CreateModelFromFile(%s, %s, %s, %d, %f, %d, %d, %d)\n", modelPath, dataPath, dataSetType,
+ batchSizeNum,
+ learningRate, epochNum, trainSize, testSize);
+ printf("debug11");
+ FedMLClientManager *mFedMLClientManager = new FedMLClientManager();
+
+ //init all required parameters
+ printf("debug1");
+
+ mFedMLClientManager->init(modelPath, dataPath, dataSetType,
+ trainSize, testSize, batchSizeNum,
+ learningRate, epochNum,
+ q_bits, p, client_num,
+ std::bind(&onProgressCallback, std::placeholders::_1),
+ std::bind(&onLossCallback, std::placeholders::_1, std::placeholders::_2),
+ std::bind(&onAccuracyCallback, std::placeholders::_1, std::placeholders::_2));
+
+ /**
+ * 1. generate mask and encode local mask for others
+ */
+ printf("debug2");
+ std::vector > encoded_mask = mFedMLClientManager->get_local_encoded_mask();
+ std::cout << encoded_mask[0].size() << std::endl;
+
+ /**
+ * 2. share and receive local mask from others via server (including share to self)
+ * call this function repeatedly during listening phase, once we receive a pair, store it via this function
+ */
+ printf("debug3");
+ int client_index = 1;
+ std::vector local_encode_mask = encoded_mask[0];
+ mFedMLClientManager->save_mask_from_paired_clients(client_index, local_encode_mask);
+
+ int client_index_another = 3;
+ std::vector local_encode_mask_another = encoded_mask[9];
+ mFedMLClientManager->save_mask_from_paired_clients(client_index_another, local_encode_mask_another);
+
+ /**
+ * 3. report receive online users to server
+ */
+ printf("debug4");
+ std::vector online_user = mFedMLClientManager->get_client_IDs_that_have_sent_mask();
+
+ /**
+ * 4. do training
+ */
+ printf("debug5");
+ mFedMLClientManager->train();
+
+ /**
+ * 5. save masked model
+ */
+ printf("debug6");
+ mFedMLClientManager->generate_masked_model();
+
+ /**
+ * 6. receive online user list from server
+ * aggregate received mask
+ * surviving list represents the users that is online confirmed by server
+ */
+ printf("debug7");
+ std::vector surviving_list_from_server;
+ surviving_list_from_server.push_back(1);
+ surviving_list_from_server.push_back(3);
+ std::vector upload_agg_mask = mFedMLClientManager->get_aggregated_encoded_mask(surviving_list_from_server);
+ std::cout << upload_agg_mask[0];
+
+ /**
+ * test loading function for the encoded mnn
+ *
+ */
+ // load computational graph
+ auto varMap = Variable::loadMap(modelPath);
+ auto inputOutputs = Variable::getInputAndOutput(varMap);
+ auto inputs = Variable::mapToSequence(inputOutputs.first);
+ auto outputs = Variable::mapToSequence(inputOutputs.second);
+
+ // convert to trainable module
+ std::shared_ptr model(NN::extract(inputs, outputs, true));
+
+ return 0;
+}
\ No newline at end of file
diff --git a/android/fedmlsdk/MobileNN/src/main_MNN_train.cpp b/android/fedmlsdk/MobileNN/src/main_MNN_train.cpp
new file mode 100644
index 0000000000..cd659cd2b2
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/src/main_MNN_train.cpp
@@ -0,0 +1,54 @@
+#include
+#include
+#include
+#include
+#include "FedMLTrainer.h"
+
+
+static void onProgressCallback(float progress) {
+ printf("callback. progress = %f\n", progress);
+}
+
+static void onLossCallback(int epoch, float loss) {
+ printf("callback. epoch = %d, loss = %f\n", epoch, loss);
+}
+
+static void onAccuracyCallback(int epoch, float accuracy) {
+ printf("callback. epoch = %d, accuracy = %f\n", epoch, accuracy);
+}
+
+
+int main(int argc, char *argv[]) {
+ std::cout << "You have entered " << argc
+ << " arguments:" << "\n";
+
+ for (int i = 0; i < argc; ++i)
+ std::cout << argv[i] << "\n";
+
+ /*
+ * usage:
+ * ./build_x86_linux/main_mnn_train.out mnist ../../../../data/mnn_model/lenet_mnist.mnn ../../../../data/MNIST/raw
+ */
+ const char* datasetName = argv[1];
+ const char* modelPath = argv[2];
+ const char* dataPath = argv[3];
+
+
+ int trainSize = 60000;
+ int testSize = 10000;
+ int batchSize = 8;
+ double learningRate = 0.01;
+ int epochNum = 10;
+
+ FedMLBaseTrainer *pFedMLTrainer = FedMLTrainer().getTrainer();
+ pFedMLTrainer->init(modelPath, dataPath,
+ datasetName, trainSize, testSize,
+ batchSize, learningRate, epochNum,
+ std::bind(&onProgressCallback, std::placeholders::_1),
+ std::bind(&onAccuracyCallback, std::placeholders::_1, std::placeholders::_2),
+ std::bind(&onLossCallback, std::placeholders::_1, std::placeholders::_2));
+ pFedMLTrainer->train();
+
+ return 0;
+
+}
diff --git a/android/fedmlsdk/MobileNN/src/main_torch_train.cpp b/android/fedmlsdk/MobileNN/src/main_torch_train.cpp
new file mode 100644
index 0000000000..2f86966a83
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/src/main_torch_train.cpp
@@ -0,0 +1,54 @@
+#include
+#include
+#include
+#include
+#include "FedMLTrainer.h"
+
+
+static void onProgressCallback(float progress) {
+ printf("callback. progress = %f\n", progress);
+}
+
+static void onLossCallback(int epoch, float loss) {
+ printf("callback. epoch = %d, loss = %f\n", epoch, loss);
+}
+
+static void onAccuracyCallback(int epoch, float accuracy) {
+ printf("callback. epoch = %d, accuracy = %f\n", epoch, accuracy);
+}
+
+
+int main(int argc, char *argv[]) {
+ std::cout << "You have entered " << argc
+ << " arguments:" << "\n";
+
+ for (int i = 0; i < argc; ++i)
+ std::cout << argv[i] << "\n";
+
+ /*
+ * usage:
+ * ./build_x86_linux/main_torch_train.out mnist ../../../../data/torch_model/scripted_lenet_model.ptl ../../../../data/MNIST/raw
+ */
+ const char* datasetName = argv[1];
+ const char* modelPath = argv[2];
+ const char* dataPath = argv[3];
+
+
+ int trainSize = 6000;
+ int testSize = 1000;
+ int batchSize = 32;
+ double learningRate = 0.01;
+ int epochNum = 1;
+
+ FedMLBaseTrainer *pFedMLTrainer = FedMLTrainer().getTrainer();
+ pFedMLTrainer->init(modelPath, dataPath,
+ datasetName, trainSize, testSize,
+ batchSize, learningRate, epochNum,
+ std::bind(&onProgressCallback, std::placeholders::_1),
+ std::bind(&onAccuracyCallback, std::placeholders::_1, std::placeholders::_2),
+ std::bind(&onLossCallback, std::placeholders::_1, std::placeholders::_2));
+ pFedMLTrainer->train();
+
+ return 0;
+
+}
diff --git a/android/fedmlsdk/MobileNN/src/security/LightSecAgg.cpp b/android/fedmlsdk/MobileNN/src/security/LightSecAgg.cpp
new file mode 100644
index 0000000000..6255d6da7f
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/src/security/LightSecAgg.cpp
@@ -0,0 +1,135 @@
+#include "LightSecAgg.h"
+
+
+int LightSecAgg::modInverse(int a, int p) {
+ int m = p;
+ int y = 0, x = 1;
+ int q = 0;
+
+ while (a > 1) {
+ if (m != 0) {
+ q = a / m;
+ int t = m;
+ m = a % m, a = t;
+ t = y;
+ y = x - q * y;
+ x = t;
+ } else {
+ q = 0;
+ a = 0;
+ int t = y;
+ y = x - q * y;
+ x = t;
+ }
+ if (x < 0) {
+ x = x + p;
+ }
+ }
+ x = x % p;
+ if (x < 0)
+ x = x + p;
+ return x;
+}
+
+int LightSecAgg::modDivide(int num, int den, int p) {
+ num = num % p;
+ den = den % p;
+ int inv = modInverse(den, p);
+ int c = (inv * num) % p;
+ return c;
+}
+
+
+int LightSecAgg::PI(std::vector vals, int p) {
+ int accum = 1;
+ for (auto v: vals) {
+ if (v < 0)
+ v = v + p;
+ int tmp = v % p;
+ accum = accum * tmp % p;
+ }
+ return accum;
+}
+
+
+std::vector > LightSecAgg::gen_Lagrange_coeffs(std::vector const &alpha_s,
+ std::vector const &beta_s,
+ int p, int is_K1) {
+ int num_alpha = (is_K1 == 1) ? 1 : alpha_s.size();
+ std::vector > U(num_alpha, std::vector(beta_s.size(), 0));
+
+ std::vector w(beta_s.size(), 0);
+ for (int j = 0; j < beta_s.size(); j++) {
+ int cur_beta = beta_s[j];
+ std::vector val;
+ for (auto o: beta_s) {
+ if (cur_beta != o)
+ val.push_back(cur_beta - o);
+ }
+ int den = PI(val, p);
+ w[j] = den;
+ }
+
+ std::vector l(num_alpha, 0);
+ for (int i = 0; i < num_alpha; i++) {
+ std::vector val;
+ for (auto o: beta_s) {
+ val.push_back(alpha_s[i] - o);
+ }
+ l[i] = PI(val, p);
+ }
+
+ for (int j = 0; j < beta_s.size(); j++) {
+ for (int i = 0; i < num_alpha; i++) {
+ int tmp = alpha_s[i] - beta_s[j];
+ if (tmp < 0)
+ tmp = tmp + p;
+ int den = (tmp % p) * w[j] % p;
+ // int den = ((alpha_s[i] - beta_s[j]) % p) * w[j] % p;
+ U[i][j] = modDivide(l[i], den, p);
+ }
+ }
+
+ return U;
+}
+
+std::vector > LightSecAgg::LCC_encoding_with_points(std::vector > const &X,
+ std::vector const &alpha_s,
+ std::vector const &beta_s, int p) {
+ int m = X.size();
+ int d = X[0].size();
+ auto U = gen_Lagrange_coeffs(beta_s, alpha_s, p);
+ std::vector > X_LCC(beta_s.size(), std::vector(d, 0.0));
+ for (int i = 0; i < U.size(); i++) {
+ for (int j = 0; j < d; j++) {
+ X_LCC[i][j] = 0;
+ for (int k = 0; k < U[0].size(); k++) {
+ X_LCC[i][j] += U[i][k] * X[k][j];
+ }
+ X_LCC[i][j] = std::fmod(X_LCC[i][j], p);
+ }
+ }
+
+ return X_LCC;
+}
+
+std::vector > LightSecAgg::LCC_decoding_with_points(std::vector > f_eval,
+ std::vector eval_points,
+ std::vector target_points, int p) {
+ auto alpha_s_eval = eval_points;
+ auto beta_s = target_points;
+ auto U_dec = gen_Lagrange_coeffs(beta_s, alpha_s_eval, p);
+
+ std::vector > f_recon(U_dec.size(), std::vector(f_eval[0].size(), 0.0));
+ for (int i = 0; i < U_dec.size(); i++) {
+ for (int j = 0; j < f_eval[0].size(); j++) {
+ f_recon[i][j] = 0;
+ for (int k = 0; k < U_dec[0].size(); k++) {
+ f_recon[i][j] += U_dec[i][k] * f_eval[k][j];
+ }
+ f_recon[i][j] = std::fmod(f_recon[i][j], p);
+ }
+ }
+ return f_recon;
+
+}
\ No newline at end of file
diff --git a/android/fedmlsdk/MobileNN/src/security/LightSecAggForMNN.cpp b/android/fedmlsdk/MobileNN/src/security/LightSecAggForMNN.cpp
new file mode 100644
index 0000000000..da07e9c963
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/src/security/LightSecAggForMNN.cpp
@@ -0,0 +1,211 @@
+#include "LightSecAggForMNN.h"
+
+using namespace MNN;
+using namespace MNN::Train;
+using namespace MNN::Express;
+
+
+void LightSecAggForMNN::init(int q_bits, int p, int client_num) {
+ this->m_q_bits = q_bits;
+ this->m_p = p;
+ this->m_client_num = client_num;
+}
+
+
+void LightSecAggForMNN::printVar(VARP x) {
+ auto size = x->getInfo()->size;
+ auto ptr = x->readMap();
+ for (int i = 0; i < size; ++i) {
+ MNN_PRINT("%f, ", ptr[i]);
+ }
+ MNN_PRINT("\n");
+}
+
+
+VARPS LightSecAggForMNN::mask_generate(const char *modelCachePath) {
+ auto varMap = Variable::loadMap(modelCachePath);
+ auto inputOutputs = Variable::getInputAndOutput(varMap);
+ auto inputs = Variable::mapToSequence(inputOutputs.first);
+ auto outputs = Variable::mapToSequence(inputOutputs.second);
+ LightSecAggForMNN mec;
+ // convert to trainable module and get weights
+ std::shared_ptr model(NN::extract(inputs, outputs, true));
+ auto param = model->parameters();
+
+ // generate the mask for the model
+ VARPS finite_weights = mec.transform_tensor_to_finite(param);
+ VARPS model_mask = mec.generate_random_mask(finite_weights);
+
+ return model_mask;
+}
+
+
+std::vector > LightSecAggForMNN::local_mask_encoding(VARPS model_mask) {
+ std::vector local_mask = mask_transform(model_mask);
+ std::vector > encode_mask;
+ encode_mask = mask_encoding(this->m_client_num, this->m_p, local_mask);
+ return encode_mask;
+}
+
+
+void
+LightSecAggForMNN::MNN_encode(const char *modelCachePath, const char *dataSet, VARPS model_mask) {
+
+ auto varMap = Variable::loadMap(modelCachePath);
+ auto inputOutputs = Variable::getInputAndOutput(varMap);
+ auto inputs = Variable::mapToSequence(inputOutputs.first);
+ auto outputs = Variable::mapToSequence(inputOutputs.second);
+
+ // convert to trainable module and get weights
+ std::shared_ptr model(NN::extract(inputs, outputs, true));
+ auto param = model->parameters();
+
+ // encode the model
+ VARPS finite_weights = transform_tensor_to_finite(param);
+ // mec.printVar(finite_weights[0]);
+ model_masking(finite_weights, model_mask, this->m_p);
+ model->loadParameters(finite_weights);
+
+ //save the model
+ VARP forwardInput;
+ if (strcmp(dataSet, "mnist") == 0) { // mnist dataset
+ forwardInput = _Input({1, 1, 28, 28}, NC4HW4);
+ } else if (strcmp(dataSet, "cifar10") == 0) { // cifar10 dataset
+ forwardInput = _Input({1, 3, 32, 32}, NC4HW4);
+ }
+ model->setIsTraining(true); // save the training state computation graph
+ forwardInput->setName("data");
+ auto inputPredict = model->forward(forwardInput);
+ inputPredict->setName("prob");
+ Variable::save({inputPredict}, modelCachePath);
+ printf("masked model save done");
+}
+
+
+std::vector LightSecAggForMNN::mask_agg(std::vector surviving_list_from_server) {
+ std::vector sum_mask(this->m_local_received_mask_from_other_clients[0].size(), 0);
+ for (int j = 0; j < sum_mask.size(); j++) {
+ for (int i = 0; i < this->m_surviving_clients.size(); i++) {
+ if(std::find(surviving_list_from_server.begin(), surviving_list_from_server.end(), this->m_surviving_clients[i]) != surviving_list_from_server.end() )
+ sum_mask[j] = sum_mask[j] + this->m_local_received_mask_from_other_clients[i][j];
+ }
+ }
+ return sum_mask;
+}
+
+void LightSecAggForMNN::save_mask_from_paired_clients(int client_index,
+ std::vector received_encode_mask) {
+ this->m_local_received_mask_from_other_clients.push_back(received_encode_mask);
+ this->m_surviving_clients.push_back(client_index);
+}
+
+std::vector LightSecAggForMNN::get_client_IDs_that_have_sent_mask() {
+ return this->m_surviving_clients;
+}
+
+//private methods:
+
+VARP LightSecAggForMNN::my_q(VARP const &X) {
+ VARP result = _Input(X->getInfo()->dim, X->getInfo()->order, halide_type_of());
+ auto size = X->getInfo()->size;
+ auto x = X->readMap();
+ auto y = result->writeMap();
+ for (int i = 0; i < size; ++i) {
+ y[i] = round(x[i] * pow(2, this->m_q_bits));
+ if (y[i] < 0) {
+ y[i] += this->m_p;
+ }
+ }
+ return result;
+}
+
+VARPS LightSecAggForMNN::transform_tensor_to_finite(VARPS const &model_params) {
+ VARPS result;
+ for (auto tmp: model_params) {
+ tmp = my_q(tmp);
+ result.push_back(tmp);
+ }
+ return result;
+}
+
+VARPS LightSecAggForMNN::generate_random_mask(VARPS const &model_params) {
+ VARPS model_mask;
+ for (auto tmp: model_params) {
+ VARP layer_weights = _Input(tmp->getInfo()->dim, tmp->getInfo()->order, halide_type_of());
+ auto size = layer_weights->getInfo()->size;
+ auto mask = layer_weights->writeMap();
+ for (int i = 0; i < size; ++i) {
+ mask[i] = rand() % this->m_p;
+ }
+ model_mask.push_back(layer_weights);
+ }
+ return model_mask;
+}
+
+
+void LightSecAggForMNN::model_masking(VARPS &weights_finite, VARPS const &local_mask, int prime_number) {
+ std::vector g;
+ int l = 1;
+ g.push_back(l);
+ VARP p = _Input(g, weights_finite[0]->getInfo()->order, halide_type_of());
+ auto prime = p->writeMap();
+ prime[0] = prime_number;
+ for (int i = 0; i < weights_finite.size(); i++) {
+ weights_finite[i] = weights_finite[i] + local_mask[i];
+ weights_finite[i] = _FloorMod(weights_finite[i], p);
+ }
+
+}
+
+std::vector LightSecAggForMNN::mask_transform(VARPS model_mask) {
+ std::vector local_mask;
+ for (auto tmp: model_mask) {
+ auto size = tmp->getInfo()->size;
+ auto x = tmp->readMap();
+ for (int i = 0; i < size; ++i) {
+ local_mask.push_back(x[i]);
+ }
+ }
+ return local_mask;
+}
+
+std::vector >
+LightSecAggForMNN::mask_encoding(int num_clients, int prime_number, std::vector const &local_mask) {
+ int d = local_mask.size();
+ int N = num_clients;
+ int T = N / 2;
+ int U = T + 1;
+ int p = prime_number;
+
+ std::vector beta_s(N);
+ std::iota(std::begin(beta_s), std::end(beta_s), 1);
+ std::vector alpha_s(U);
+ std::iota(std::begin(alpha_s), std::end(alpha_s), N + 1);
+
+ auto local_mask_rand = local_mask;
+ std::random_device dev;
+ std::mt19937 rng(dev());
+ std::uniform_real_distribution dist(1, p);
+ for (int i = 0; i < (T * d / (U - T)); i++)
+ local_mask_rand.push_back(dist(rng));
+
+ int y = d / (U - T);
+ std::vector > LCC_in(U, std::vector(y));
+
+ for (int i = 0; i < local_mask_rand.size(); i++)
+ LCC_in[i / y][i % y] = local_mask_rand[i];
+ LightSecAgg lsa;
+ std::vector > encoded_mask_set = lsa.LCC_encoding_with_points(LCC_in, alpha_s, beta_s, p);
+ return encoded_mask_set;
+}
+
+std::vector LightSecAggForMNN::z_tilde_sum(std::vector > const &z_tilde_buffer,
+ std::vector const &sur_list) {
+ std::vector w(z_tilde_buffer[0].size(), 0.0);
+ for (int i = 0; i < w.size(); ++i) {
+ for (int j: sur_list)
+ w[i] = w[i] + z_tilde_buffer[i][j];
+ }
+ return w;
+}
+
diff --git a/android/fedmlsdk/MobileNN/src/torch/cifar10.cpp b/android/fedmlsdk/MobileNN/src/torch/cifar10.cpp
new file mode 100644
index 0000000000..f2c7b422cb
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/src/torch/cifar10.cpp
@@ -0,0 +1,174 @@
+// Mimicing
+// https://github.com/pytorch/pytorch/blob/fc804b5def5e7d7ecad24c4d1ca4ac575e588ae8/torch/csrc/api/src/data/datasets/mnist.cpp
+
+// CIFAR dataset
+// https://www.cs.toronto.edu/~kriz/cifar.html
+
+// #include
+#include "cifar10.h"
+
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+namespace torch {
+namespace data {
+namespace datasets {
+const std::vector train_set_file_names{
+ "data_batch_1.bin"};
+const std::vector test_set_file_names{"test_batch.bin"};
+const std::string meta_data_file_name{"batches.meta.txt"};
+
+constexpr const uint32_t num_samples_per_file{10000};
+constexpr const uint32_t image_height{32};
+constexpr const uint32_t image_width{32};
+constexpr const uint32_t image_channels{3};
+
+std::string join_paths(std::string head, const std::string& tail) {
+ if (head.back() != '/') {
+ head.push_back('/');
+ }
+ head += tail;
+ return head;
+}
+
+torch::Tensor read_targets_from_file(const std::string& file_path, uint32_t kSize) {
+ torch::Tensor targets =
+ torch::empty({kSize * 1}, torch::kUInt8);
+ uint8_t* ptr_data = targets.data_ptr();
+
+ std::fstream f;
+ f.open(file_path, f.binary | f.in);
+ if (!f.is_open()) {
+ std::cerr << "Failed to open " << file_path << std::endl;
+ // TORCH_CHECK(f, "Error opening targets file at ", file_path);
+ } else {
+ for(uint32_t i=0; i(ptr_data + i * 1), 1);
+ f.ignore(image_height * image_width * image_channels * 1);
+ }
+ }
+
+ // assert(
+ // (count == num_samples_per_file) &&
+ // "Insufficient number of targets. Data file might have been corrupted.");
+
+ // targets = targets.reshape({num_samples_per_file, 1});
+
+ return targets;
+}
+
+torch::Tensor read_images_from_file(const std::string& file_path, uint32_t kSize) {
+ constexpr const uint32_t num_image_bytes{image_height * image_width *
+ image_channels * 1};
+
+ torch::Tensor images =
+ torch::empty({kSize * num_image_bytes}, torch::kUInt8);
+ uint8_t* ptr_data = images.data_ptr();
+
+ std::fstream f;
+ f.open(file_path, f.binary | f.in);
+ if (!f.is_open()) {
+ std::cerr << "Failed to open " << file_path << std::endl;
+ // TORCH_CHECK(f, "Error opening images file at ", file_path);
+ } else {
+ for(uint32_t i=0; i(ptr_data + i * num_image_bytes),
+ num_image_bytes);
+ }
+ }
+
+ // assert((count == num_samples_per_file) &&
+ // "Insufficient number of images. Data file might have been corrupted.");
+
+ // The next 3072 bytes are the values of the pixels of the image.
+ // The first 1024 bytes are the red channel values, the next 1024 the green,
+ // and the final 1024 the blue. The values are stored in row-major order, so
+ // the first 32 bytes are the red channel values of the first row of the
+ // image. NCHW format
+ images = images.reshape(
+ {kSize, image_channels, image_height, image_width});
+
+ return images;
+}
+
+torch::Tensor read_images(const std::string& root, uint32_t kSize, bool train) {
+ std::vector data_set_file_names;
+ if (train) {
+ data_set_file_names = train_set_file_names;
+ } else {
+ data_set_file_names = test_set_file_names;
+ }
+
+ std::vector data_set_file_paths;
+ for (const std::string& data_set_file_name : data_set_file_names) {
+ data_set_file_paths.push_back(join_paths(root, data_set_file_name));
+ }
+
+ std::vector image_tensors;
+
+ for (const std::string& data_set_file_path : data_set_file_paths) {
+ torch::Tensor images = read_images_from_file(data_set_file_path, kSize);
+ image_tensors.push_back(images);
+ }
+
+ torch::Tensor images = torch::cat(image_tensors, 0);
+
+ images = images.to(torch::kFloat32).div_(255);
+
+ return images;
+}
+
+torch::Tensor read_targets(const std::string& root, uint32_t kSize, bool train) {
+ std::vector data_set_file_names;
+ if (train) {
+ data_set_file_names = train_set_file_names;
+ } else {
+ data_set_file_names = test_set_file_names;
+ }
+
+ std::vector data_set_file_paths;
+ for (const std::string& data_set_file_name : data_set_file_names) {
+ data_set_file_paths.push_back(join_paths(root, data_set_file_name));
+ }
+
+ std::vector target_tensors;
+
+ for (const std::string& data_set_file_path : data_set_file_paths) {
+ torch::Tensor targets = read_targets_from_file(data_set_file_path, kSize);
+ target_tensors.push_back(targets);
+ }
+
+ torch::Tensor targets = torch::cat(target_tensors, 0);
+
+ targets = targets.to(torch::kInt64);
+
+ return targets;
+}
+
+CIFAR10::CIFAR10(const std::string& root, uint32_t kSize, Mode mode)
+ : images_(read_images(root, kSize, mode == Mode::kTrain)),
+ targets_(read_targets(root, kSize, mode == Mode::kTrain)) {}
+
+torch::data::Example<> CIFAR10::get(size_t index) {
+ return {images_[index], targets_[index]};
+}
+
+torch::optional CIFAR10::size() const { return images_.size(0); }
+
+// bool CIFAR10::is_train() const noexcept {
+// return images_.size(0) == num_samples_per_file * train_set_file_names.size();
+// }
+
+const torch::Tensor& CIFAR10::images() const { return images_; }
+
+const torch::Tensor& CIFAR10::targets() const { return targets_; }
+
+} // namespace datasets
+} // namespace data
+} // namespace torch
\ No newline at end of file
diff --git a/android/fedmlsdk/MobileNN/src/torch/mnist.cpp b/android/fedmlsdk/MobileNN/src/torch/mnist.cpp
new file mode 100644
index 0000000000..f0fced64ad
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/src/torch/mnist.cpp
@@ -0,0 +1,120 @@
+#include
+
+#include
+#include
+
+#include
+
+#include
+#include
+#include
+#include
+
+namespace {
+// constexpr uint32_t kTrainSize = 60000;
+// constexpr uint32_t kTestSize = 10000;
+constexpr uint32_t kImageMagicNumber = 2051;
+constexpr uint32_t kTargetMagicNumber = 2049;
+constexpr uint32_t kImageRows = 28;
+constexpr uint32_t kImageColumns = 28;
+constexpr const char* kTrainImagesFilename = "train-images-idx3-ubyte";
+constexpr const char* kTrainTargetsFilename = "train-labels-idx1-ubyte";
+constexpr const char* kTestImagesFilename = "t10k-images-idx3-ubyte";
+constexpr const char* kTestTargetsFilename = "t10k-labels-idx1-ubyte";
+
+bool check_is_little_endian() {
+ const uint32_t word = 1;
+ return reinterpret_cast(&word)[0] == 1;
+}
+
+constexpr uint32_t flip_endianness(uint32_t value) {
+ return ((value & 0xffu) << 24u) | ((value & 0xff00u) << 8u) |
+ ((value & 0xff0000u) >> 8u) | ((value & 0xff000000u) >> 24u);
+}
+
+uint32_t read_int32(std::ifstream& stream) {
+ static const bool is_little_endian = check_is_little_endian();
+ // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+ uint32_t value;
+ AT_ASSERT(stream.read(reinterpret_cast(&value), sizeof value));
+ return is_little_endian ? flip_endianness(value) : value;
+}
+
+uint32_t expect_int32(std::ifstream& stream, uint32_t expected) {
+ const auto value = read_int32(stream);
+ // clang-format off
+ // TORCH_CHECK(value == expected,
+ // "Expected to read number ", expected, " but found ", value, " instead");
+ // clang-format on
+ return value;
+}
+
+std::string join_paths(std::string head, const std::string& tail) {
+ if (head.back() != '/') {
+ head.push_back('/');
+ }
+ head += tail;
+ return head;
+}
+
+torch::Tensor read_images(const std::string& root, uint32_t kSize, bool train) {
+ const auto path =
+ join_paths(root, train ? kTrainImagesFilename : kTestImagesFilename);
+ std::ifstream images(path, std::ios::binary);
+ TORCH_CHECK(images, "Error opening images file at ", path);
+
+ const auto count = kSize;
+
+ // From http://yann.lecun.com/exdb/mnist/
+ expect_int32(images, kImageMagicNumber);
+ expect_int32(images, count);
+ expect_int32(images, kImageRows);
+ expect_int32(images, kImageColumns);
+
+ auto tensor =
+ torch::empty({count, 1, kImageRows, kImageColumns}, torch::kByte);
+ images.read(reinterpret_cast(tensor.data_ptr()), tensor.numel());
+ return tensor.to(torch::kFloat32).div_(255);
+}
+
+torch::Tensor read_targets(const std::string& root, uint32_t kSize, bool train) {
+ const auto path =
+ join_paths(root, train ? kTrainTargetsFilename : kTestTargetsFilename);
+ std::ifstream targets(path, std::ios::binary);
+ TORCH_CHECK(targets, "Error opening targets file at ", path);
+
+ const auto count = kSize;
+
+ expect_int32(targets, kTargetMagicNumber);
+ expect_int32(targets, count);
+
+ auto tensor = torch::empty(count, torch::kByte);
+ targets.read(reinterpret_cast(tensor.data_ptr()), count);
+ return tensor.to(torch::kInt64);
+}
+} // namespace
+
+MNIST::MNIST(const std::string& root, uint32_t kSize, Mode mode)
+ : images_(read_images(root, kSize, mode == Mode::kTrain)),
+ targets_(read_targets(root, kSize, mode == Mode::kTrain)) {}
+
+torch::data::Example<> MNIST::get(size_t index) {
+ return {images_[index], targets_[index]};
+}
+
+torch::optional MNIST::size() const {
+ return images_.size(0);
+}
+
+// NOLINTNEXTLINE(bugprone-exception-escape)
+// bool MNIST::is_train() const noexcept {
+// return images_.size(0) == 60000;
+// }
+
+const torch::Tensor& MNIST::images() const {
+ return images_;
+}
+
+const torch::Tensor& MNIST::targets() const {
+ return targets_;
+}
\ No newline at end of file
diff --git a/android/fedmlsdk/MobileNN/src/train/FedMLBaseTrainer.cpp b/android/fedmlsdk/MobileNN/src/train/FedMLBaseTrainer.cpp
new file mode 100644
index 0000000000..2cd2922235
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/src/train/FedMLBaseTrainer.cpp
@@ -0,0 +1,34 @@
+#include "FedMLBaseTrainer.h"
+
+void FedMLBaseTrainer::init(const char *modelCachePath, const char *dataCachePath,
+ const char *dataSet, int trainSize, int testSize,
+ int batchSizeNum, double learningRate, int epochNum,
+ progressCallback progress_callback,
+ accuracyCallback accuracy_callback,
+ lossCallback loss_callback) {
+
+ m_modelCachePath = modelCachePath;
+ m_dataCachePath = dataCachePath;
+ m_dataSet = dataSet;
+
+ m_trainSize = trainSize;
+ m_testSize = testSize;
+ m_batchSizeNum = batchSizeNum;
+ m_LearningRate = learningRate;
+ m_epochNum = epochNum;
+
+ m_progress_callback = progress_callback;
+ m_accuracy_callback = accuracy_callback;
+ m_loss_callback = loss_callback;
+}
+
+std::string FedMLBaseTrainer::getEpochAndLoss() {
+ std::string result = std::to_string(curEpoch) + "," + std::to_string(curLoss);
+ return result;
+}
+
+bool FedMLBaseTrainer::stopTraining() {
+ bRunStopFlag = true;
+ printf("stopTraining By User.");
+ return true;
+}
\ No newline at end of file
diff --git a/android/fedmlsdk/MobileNN/src/train/FedMLMNNTrainer.cpp b/android/fedmlsdk/MobileNN/src/train/FedMLMNNTrainer.cpp
new file mode 100644
index 0000000000..f9705aa44c
--- /dev/null
+++ b/android/fedmlsdk/MobileNN/src/train/FedMLMNNTrainer.cpp
@@ -0,0 +1,138 @@
+#include "FedMLMNNTrainer.h"
+
+std::string FedMLMNNTrainer::train() {
+ const char* modelCachePath = m_modelCachePath.c_str();
+ const char* dataCachePath = m_dataCachePath.c_str();
+ const char* dataSet = m_dataSet.c_str();
+
+ // load model
+ auto varMap = Variable::loadMap(modelCachePath);
+ auto inputOutputs = Variable::getInputAndOutput(varMap);
+ auto inputs = Variable::mapToSequence(inputOutputs.first);
+ auto outputs = Variable::mapToSequence(inputOutputs.second);
+
+ std::shared_ptr model(NN::extract(inputs, outputs, true));
+
+ // set executor
+ auto exe = Executor::getGlobalExecutor();
+ BackendConfig config;
+ exe->setGlobalExecutorConfig(MNN_FORWARD_CPU, config, 4);
+
+ // set optimizer
+ std::shared_ptr sgd(new SGD(model));
+ sgd->setLearningRate(m_LearningRate);
+ sgd->setMomentum(0.1f);
+
+ m_progress_callback(10.0f);
+ if (bRunStopFlag) {printf("Training Stop By User."); return nullptr;}
+
+ // load data
+ DatasetPtr dataset;
+ DatasetPtr testDataset;
+ VARP forwardInput;
+
+ if (strcmp(dataSet, "mnist") == 0) {
+ printf("loading mnist\n");
+ dataset = MnistDataset::create(dataCachePath, MnistDataset::Mode::TRAIN, m_trainSize, m_testSize);
+ testDataset = MnistDataset::create(dataCachePath, MnistDataset::Mode::TEST, m_trainSize, m_testSize);
+ forwardInput = _Input({1, 1, 28, 28}, NC4HW4);
+ } else if (strcmp(dataSet, "cifar10") == 0) {
+ printf("loading cifar10\n");
+ dataset = Cifar10Dataset::create(dataCachePath, Cifar10Dataset::Mode::TRAIN, m_trainSize, m_testSize);
+ testDataset = Cifar10Dataset::create(dataCachePath, Cifar10Dataset::Mode::TEST, m_trainSize, m_testSize);
+ forwardInput = _Input({1, 3, 32, 32}, NC4HW4);
+ }
+ auto dataLoader = std::shared_ptr(dataset.createLoader(m_batchSizeNum, true, true, 0));
+ size_t iterations = dataLoader->iterNumber();
+ size_t trainSamples = dataLoader->size();
+
+ auto testDataLoader = std::shared_ptr(testDataset.createLoader(m_batchSizeNum, true, false, 0));
+ size_t testIterations = testDataLoader->iterNumber();
+ size_t testSamples = testDataLoader->size();
+
+ m_progress_callback(20.0f);
+ if (bRunStopFlag) {printf("Training Stop By User."); return nullptr;}
+
+ // model training
+ for (int epoch = 0; epoch < m_epochNum; ++epoch) {
+ curEpoch = epoch;
+
+ model->clearCache();
+ exe->gc(Executor::FULL);
+ exe->resetProfile();
+ {
+ dataLoader->reset();
+ model->setIsTraining(true);
+
+ int moveBatchSize = 0;
+ for (int i = 0; i < iterations; i++) {
+ if (bRunStopFlag) {printf("Training Stop By User."); return nullptr;}
+
+ auto trainData = dataLoader->next();
+ auto example = trainData[0];
+ auto cast = _Cast(example.first[0]);
+ example.first[0] = cast * _Const(1.0f / 255.0f);
+ moveBatchSize += example.first[0]->getInfo()->dim[0];
+
+ // Compute One-Hot
+ auto newTarget = _OneHot(_Cast(example.second[0]), _Scalar(10), _Scalar(1.0f),
+ _Scalar(0.0f));
+
+ auto predict = model->forward(example.first[0]);
+ auto loss = _CrossEntropy(predict, newTarget);
+ sgd->step(loss);
+
+ curLoss = loss->readMap