Skip to content

Commit

Permalink
Merge pull request #8837 from khoaguin/remove-jax-haiku
Browse files Browse the repository at this point in the history
Remove Jax and Haiku. Use Torch instead
  • Loading branch information
khoaguin authored May 24, 2024
2 parents cb6dfe6 + f41c592 commit 870456b
Show file tree
Hide file tree
Showing 16 changed files with 339 additions and 286 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ domain_client = sy.login(
- <a href="notebooks/api/0.8/01-submit-code.ipynb">01-submit-code.ipynb</a>
- <a href="notebooks/api/0.8/02-review-code-and-approve.ipynb">02-review-code-and-approve.ipynb</a>
- <a href="notebooks/api/0.8/03-data-scientist-download-result.ipynb">03-data-scientist-download-result.ipynb</a>
- <a href="notebooks/api/0.8/04-jax-example.ipynb">04-jax-example.ipynb</a>
- <a href="notebooks/api/0.8/04-pytorch-example.ipynb">04-pytorch-example.ipynb</a>
- <a href="notebooks/api/0.8/05-custom-policy.ipynb">05-custom-policy.ipynb</a>
- <a href="notebooks/api/0.8/06-multiple-code-requests.ipynb">06-multiple-code-requests.ipynb</a>
- <a href="notebooks/api/0.8/07-domain-register-control-flow.ipynb">07-domain-register-control-flow.ipynb</a>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
"outputs": [],
"source": [
"# third party\n",
"import haiku as hk\n",
"import jax\n",
"from jax import random\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"# syft absolute\n",
"import syft as sy\n",
Expand All @@ -43,7 +43,7 @@
},
"outputs": [],
"source": [
"node = sy.orchestra.launch(name=\"test-domain-1\", port=\"auto\", dev_mode=True)"
"node = sy.orchestra.launch(name=\"test-domain-1\", dev_mode=True, reset=True)"
]
},
{
Expand All @@ -67,7 +67,8 @@
},
"outputs": [],
"source": [
"key = random.PRNGKey(42)"
"# Set the random seed for reproducibility\n",
"torch.manual_seed(42)"
]
},
{
Expand All @@ -79,19 +80,19 @@
},
"outputs": [],
"source": [
"train_data = random.uniform(key, shape=(4, 28, 28, 1))"
"# Generate random data\n",
"train_data = torch.rand((4, 28, 28, 1))\n",
"train_data.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6",
"metadata": {
"tags": []
},
"metadata": {},
"outputs": [],
"source": [
"assert round(train_data.sum()) == 1602"
"assert torch.round(train_data.sum()) == 1557"
]
},
{
Expand Down Expand Up @@ -127,55 +128,60 @@
},
"outputs": [],
"source": [
"train_domain_obj = domain_client.api.services.action.set(train)"
"train_domain_obj = domain_client.api.services.action.set(train)\n",
"type(train_domain_obj)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "10",
"metadata": {
"tags": []
},
"metadata": {},
"outputs": [],
"source": [
"assert torch.round(train_domain_obj.syft_action_data.sum()) == 1557"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "11",
"metadata": {},
"outputs": [],
"source": [
"class MLP(hk.Module):\n",
" def __init__(self, out_dims, name=None):\n",
" super().__init__(name=name)\n",
"class MLP(nn.Module):\n",
" def __init__(self, out_dims):\n",
" super().__init__()\n",
" self.out_dims = out_dims\n",
" self.linear1 = nn.Linear(784, 128)\n",
" self.linear2 = nn.Linear(128, out_dims)\n",
"\n",
" def __call__(self, x):\n",
" x = x.reshape((x.shape[0], -1))\n",
" x = hk.Linear(128)(x)\n",
" x = jax.nn.relu(x)\n",
" x = hk.Linear(self.out_dims)(x)\n",
" def forward(self, x):\n",
" x = x.view(x.size(0), -1)\n",
" x = self.linear1(x)\n",
" x = F.relu(x)\n",
" x = self.linear2(x)\n",
" return x\n",
"\n",
"\n",
"def _forward_fn_linear1(x):\n",
" module = MLP(out_dims=10)\n",
" return module(x)\n",
"\n",
"\n",
"model = hk.transform(_forward_fn_linear1)"
"model = MLP(out_dims=10)\n",
"model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "11",
"metadata": {
"tags": []
},
"id": "12",
"metadata": {},
"outputs": [],
"source": [
"weights = model.init(key, train.syft_action_data)"
"weights = model.state_dict()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "12",
"id": "13",
"metadata": {
"tags": []
},
Expand All @@ -187,7 +193,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "13",
"id": "14",
"metadata": {
"tags": []
},
Expand All @@ -199,7 +205,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "14",
"id": "15",
"metadata": {
"tags": []
},
Expand All @@ -211,7 +217,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "15",
"id": "16",
"metadata": {
"tags": []
},
Expand All @@ -223,7 +229,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "16",
"id": "17",
"metadata": {
"tags": []
},
Expand All @@ -235,35 +241,42 @@
")\n",
"def train_mlp(weights, data):\n",
" # third party\n",
" import haiku as hk\n",
" import jax\n",
" import torch\n",
" import torch.nn as nn\n",
" import torch.nn.functional as F\n",
"\n",
" class MLP(hk.Module):\n",
" def __init__(self, out_dims, name=None):\n",
" super().__init__(name=name)\n",
" class MLP(nn.Module):\n",
" def __init__(self, out_dims):\n",
" super().__init__()\n",
" self.out_dims = out_dims\n",
" self.linear1 = nn.Linear(784, 128)\n",
" self.linear2 = nn.Linear(128, out_dims)\n",
"\n",
" def __call__(self, x):\n",
" x = x.reshape((x.shape[0], -1))\n",
" x = hk.Linear(128)(x)\n",
" x = jax.nn.relu(x)\n",
" x = hk.Linear(self.out_dims)(x)\n",
" def forward(self, x):\n",
" x = x.view(x.size(0), -1)\n",
" x = self.linear1(x)\n",
" x = F.relu(x)\n",
" x = self.linear2(x)\n",
" return x\n",
"\n",
" def _forward_fn_linear1(x):\n",
" module = MLP(out_dims=10)\n",
" return module(x)\n",
" # Initialize the model\n",
" model = MLP(out_dims=10)\n",
"\n",
" # Load weights into the model\n",
" model.load_state_dict(weights)\n",
"\n",
" # Perform a forward pass\n",
" model.eval() # Set the model to evaluation mode\n",
" with torch.no_grad(): # Disable gradient calculation\n",
" output = model(data)\n",
"\n",
" model = hk.transform(_forward_fn_linear1)\n",
" rng_key = jax.random.PRNGKey(42)\n",
" output = model.apply(params=weights, x=data, rng=rng_key)\n",
" return output"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "17",
"id": "18",
"metadata": {
"tags": []
},
Expand All @@ -276,19 +289,17 @@
{
"cell_type": "code",
"execution_count": null,
"id": "18",
"metadata": {
"tags": []
},
"id": "19",
"metadata": {},
"outputs": [],
"source": [
"assert round(output.sum(), 2) == -0.86"
"assert torch.allclose(torch.sum(output), torch.tensor(1.3907))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "19",
"id": "20",
"metadata": {
"tags": []
},
Expand All @@ -301,7 +312,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "20",
"id": "21",
"metadata": {
"tags": []
},
Expand All @@ -313,7 +324,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "21",
"id": "22",
"metadata": {
"tags": []
},
Expand All @@ -326,7 +337,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "22",
"id": "23",
"metadata": {
"tags": []
},
Expand All @@ -338,7 +349,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "23",
"id": "24",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -348,19 +359,19 @@
{
"cell_type": "code",
"execution_count": null,
"id": "24",
"id": "25",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"assert round(float(result.sum()), 2) == -0.86"
"assert torch.allclose(torch.sum(result), torch.tensor(1.3907))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "25",
"id": "26",
"metadata": {
"tags": []
},
Expand All @@ -373,7 +384,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "26",
"id": "27",
"metadata": {},
"outputs": [],
"source": []
Expand All @@ -395,7 +406,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.12.2"
},
"toc": {
"base_numbering": 1,
Expand Down
2 changes: 1 addition & 1 deletion notebooks/tutorials/deployments/01-deploy-python.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@
"- [01-submit-code.ipynb](../../api/0.8/01-submit-code.ipynb)\n",
"- [02-review-code-and-approve.ipynb](../../api/0.8/02-review-code-and-approve.ipynb)\n",
"- [03-data-scientist-download-result.ipynb](../../api/0.8/03-data-scientist-download-result.ipynb)\n",
"- [04-jax-example.ipynb](../../api/0.8/04-jax-example.ipynb)\n",
"- [04-pytorch-example.ipynb](../../api/0.8/04-pytorch-example.ipynb)\n",
"- [05-custom-policy.ipynb](../../api/0.8/05-custom-policy.ipynb)\n",
"- [06-multiple-code-requests.ipynb](../../api/0.8/06-multiple-code-requests.ipynb)\n",
"- [07-domain-register-control-flow.ipynb](../../api/0.8/07-domain-register-control-flow.ipynb)\n",
Expand Down
2 changes: 1 addition & 1 deletion notebooks/tutorials/deployments/02-deploy-container.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@
"- [01-submit-code.ipynb](../../api/0.8/01-submit-code.ipynb)\n",
"- [02-review-code-and-approve.ipynb](../../api/0.8/02-review-code-and-approve.ipynb)\n",
"- [03-data-scientist-download-result.ipynb](../../api/0.8/03-data-scientist-download-result.ipynb)\n",
"- [04-jax-example.ipynb](../../api/0.8/04-jax-example.ipynb)\n",
"- [04-pytorch-example.ipynb](../../api/0.8/04-pytorch-example.ipynb)\n",
"- [05-custom-policy.ipynb](../../api/0.8/05-custom-policy.ipynb)\n",
"- [06-multiple-code-requests.ipynb](../../api/0.8/06-multiple-code-requests.ipynb)\n",
"- [07-domain-register-control-flow.ipynb](../../api/0.8/07-domain-register-control-flow.ipynb)\n",
Expand Down
7 changes: 6 additions & 1 deletion notebooks/tutorials/deployments/03-deploy-k8s-k3d.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@
"- [01-submit-code.ipynb](../../api/0.8/01-submit-code.ipynb)\n",
"- [02-review-code-and-approve.ipynb](../../api/0.8/02-review-code-and-approve.ipynb)\n",
"- [03-data-scientist-download-result.ipynb](../../api/0.8/03-data-scientist-download-result.ipynb)\n",
"- [04-jax-example.ipynb](../../api/0.8/04-jax-example.ipynb)\n",
"- [04-pytorch-example.ipynb](../../api/0.8/04-pytorch-example.ipynb)\n",
"- [05-custom-policy.ipynb](../../api/0.8/05-custom-policy.ipynb)\n",
"- [06-multiple-code-requests.ipynb](../../api/0.8/06-multiple-code-requests.ipynb)\n",
"- [07-domain-register-control-flow.ipynb](../../api/0.8/07-domain-register-control-flow.ipynb)\n",
Expand All @@ -167,6 +167,11 @@
"\n",
"Feel free to explore these notebooks to get started with PySyft and unlock its full potential for privacy-preserving machine learning!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
}
],
"metadata": {
Expand Down
Loading

0 comments on commit 870456b

Please sign in to comment.