Skip to content

Commit

Permalink
return inference information
Browse files Browse the repository at this point in the history
  • Loading branch information
holgerroth committed Dec 6, 2024
1 parent eef8e92 commit a2faac9
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ class BioNeMoConstants(object):
TASK_INFERENCE = "bionemo_inference"
NUMBER_SEQUENCES = "bionemo_number_sequences"
DATA_INFO = "bionemo_data_info"
#CONFIG = "bionemo_config"
RESULT_SHAPES = "bionemo_result_shapes"
24 changes: 20 additions & 4 deletions examples/advanced/bionemo/task_fitting/src/bionemo_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time
import torch

from nvflare.apis.dxo import MetaKey, DXO, from_shareable
from nvflare.apis.event_type import EventType
Expand Down Expand Up @@ -67,14 +69,28 @@ def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort
if task_name == self._task_name:
if not self.is_initialized:
self._init_launcher(fl_ctx)

success = self._launch_script(fl_ctx)

if success:
n_sequences = 1 # TODO: get from inference command

# Get results path from inference script arguments
args = self.launcher._script.split()
results_path = args[args.index("--results-path")+1]
if os.path.isfile(results_path):
self.log_info(fl_ctx, f"Get result info from: {results_path}")
results = torch.load(results_path)

result_shapes = {}
for k, v in results.items():
if v is not None:
result_shapes[k] = list(v.shape) # turn torch Size type into a simple list for sharing with server

n_sequences = len(results["embeddings"])
else:
n_sequences, result_shapes = "n/a", "n/a"

# Prepare a DXO for our updated model. Create shareable and return
data_info = {BioNeMoConstants.NUMBER_SEQUENCES: n_sequences}
data_info = {BioNeMoConstants.NUMBER_SEQUENCES: n_sequences, BioNeMoConstants.RESULT_SHAPES: result_shapes}

outgoing_dxo = DXO(data_kind=BioNeMoConstants.DATA_INFO, data=data_info)
return outgoing_dxo.to_shareable()
Expand Down

0 comments on commit a2faac9

Please sign in to comment.