Skip to content

Commit

Permalink
change variables
Browse files Browse the repository at this point in the history
  • Loading branch information
porteratzo committed Jun 5, 2024
1 parent d1b0b1e commit 0a2ce17
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions openfl-tutorials/experimental/Phi3/Workflow_Interface_Phi3.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,8 @@
" test_size=val_set_size, shuffle=True, seed=42\n",
")\n",
"\n",
"processed_train_dataset = train_val[\"train\"].shuffle().map(generate_and_tokenize_prompt)\n",
"processed_test_dataset = train_val[\"test\"].shuffle().map(generate_and_tokenize_prompt)\n"
"processed_train_dataset = train_val[\"train\"].shuffle().map(generate_and_tokenize_prompt).select(range(3))\n",
"processed_test_dataset = train_val[\"test\"].shuffle().map(generate_and_tokenize_prompt).select(range(3))\n"
]
},
{
Expand Down Expand Up @@ -482,6 +482,7 @@
" print(f\"Average local model validation values = {self.local_model_accuracy}\")\n",
"\n",
" self.model = FedAvg([input.peft_params for input in inputs], self.model)\n",
" self.peft_params = get_peft_model_state_dict(self.model)\n",
"\n",
" self.model.save_pretrained(\"./aggregated/model\")\n",
" tokenizer.save_pretrained(\"./aggregated/tokenizer\")\n",
Expand All @@ -490,7 +491,7 @@
" self.next(\n",
" self.aggregated_model_validation,\n",
" foreach=\"collaborators\",\n",
" exclude=[\"private\"],\n",
" exclude=[\"model\"],\n",
" )\n",
" else:\n",
" self.next(self.end)\n",
Expand Down Expand Up @@ -518,29 +519,29 @@
"outputs": [],
"source": [
"# Setup participants\n",
"aggregator = Aggregator()\n",
"aggregator.private_attributes = {}\n",
"_aggregator = Aggregator()\n",
"_aggregator.private_attributes = {}\n",
"\n",
"# Setup collaborators with private attributes\n",
"collaborator_names = [\n",
" \"Portland\",\n",
" \"Seattle\",\n",
"]\n",
"collaborators = [Collaborator(name=name) for name in collaborator_names]\n",
"_collaborators = [Collaborator(name=name) for name in collaborator_names]\n",
"\n",
"for idx, current_collaborator in enumerate(collaborators):\n",
"for idx, current_collaborator in enumerate(_collaborators):\n",
" # Set the private attributes of the Collaborator to include their specific training and testing data loaders\n",
" current_collaborator.private_attributes = {\n",
" \"train_dataset\": processed_train_dataset.shard(\n",
" num_shards=len(collaborators), index=idx\n",
" num_shards=len(_collaborators), index=idx\n",
" ),\n",
" \"eval_dataset\": processed_test_dataset.shard(\n",
" num_shards=len(collaborators), index=idx\n",
" num_shards=len(_collaborators), index=idx\n",
" ),\n",
" }\n",
"\n",
"local_runtime = LocalRuntime(\n",
" aggregator=aggregator, collaborators=collaborators, backend=\"single_process\"\n",
" aggregator=_aggregator, collaborators=_collaborators, backend=\"single_process\"\n",
")\n",
"print(f\"Local runtime collaborators = {local_runtime.collaborators}\")"
]
Expand Down

0 comments on commit 0a2ce17

Please sign in to comment.