From 087d2f0e594bde65737a4148c224acbb864e8def Mon Sep 17 00:00:00 2001 From: KeremP Date: Thu, 27 Apr 2023 23:54:49 -0400 Subject: [PATCH 1/2] get logits at each token as numpy array --- pygptj/model.py | 17 ++++++++++++++++- src/gptj.cpp | 11 +++++++++-- src/gptj.h | 2 +- 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/pygptj/model.py b/pygptj/model.py index cc5bdb0..a0feeb5 100644 --- a/pygptj/model.py +++ b/pygptj/model.py @@ -18,6 +18,7 @@ import sys import _pygptj as pp from pygptj._logger import set_log_level +import numpy as np class Model: @@ -36,6 +37,7 @@ def new_text_callback(text): ``` """ _new_text_callback = None + _logits_callback = None def __init__(self, model_path: str, @@ -64,6 +66,8 @@ def __init__(self, self.res = "" + self.logits = [] + def _load_model(self): """ Helper function to load the model @@ -84,10 +88,18 @@ def _call_new_text_callback(self, text_bytes) -> None: except UnicodeDecodeError: logging.warning(f"UnicodeDecodeError of bytes {text_bytes}") # save res + + def _call_logits_callback(self, logits: np.ndarray): + if Model._logits_callback is not None: + self.logits.append(logits.tolist()) + + def braindump(self, path: str): + np.save(path, np.asarray(self.logits)) def generate(self, prompt: str, new_text_callback: Callable[[str], None] = None, + logits_callback: Callable = None, n_predict: int = 128, seed: int = -1, n_threads: int = 4, @@ -124,8 +136,11 @@ def generate(self, self.res = "" Model._new_text_callback = new_text_callback + # assign _logits_callback used for saving logits, token by token + Model._logits_callback = logits_callback + # run the prediction - pp.gptj_generate(self.gpt_params, self._model, self._vocab, self._call_new_text_callback) + pp.gptj_generate(self.gpt_params, self._model, self._vocab, self._call_new_text_callback, self._call_logits_callback) return self.res @staticmethod diff --git a/src/gptj.cpp b/src/gptj.cpp index 97a1fc2..d0c4e80 100644 --- a/src/gptj.cpp +++ b/src/gptj.cpp @@ -21,6 +21,7 @@ #include #include +#include namespace py = pybind11; @@ -595,7 +596,7 @@ bool gptj_eval( return true; } -int gptj_generate(gpt_params params, struct gptj_model & model, struct gpt_vocab & vocab, py::function new_text_callback) { +int gptj_generate(gpt_params params, struct gptj_model & model, struct gpt_vocab & vocab, py::function new_text_callback, py::function logits_callback) { // auto model = context->model; // auto vocab = context->vocab; @@ -666,7 +667,9 @@ int gptj_generate(gpt_params params, struct gptj_model & model, struct gpt_vocab printf("Failed to predict\n"); return 1; } - + // collect logits for each token + py::array_t _logits = py::array_t{model.hparams.n_vocab, logits.data(), py::none()}; + logits_callback(_logits); t_predict_us += ggml_time_us() - t_start_us; } @@ -729,6 +732,10 @@ int gptj_generate(gpt_params params, struct gptj_model & model, struct gpt_vocab printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f); } + // py::array_t _logits = py::cast(logits.data()); + // py::array_t _logits = py::array_t{50400, logits.data(), py::none()}; + // // printf("%d",logits.size()); + // logits_callback(_logits); ggml_free(model.ctx); return 0; diff --git a/src/gptj.h b/src/gptj.h index 77d87d0..d48c8cb 100644 --- a/src/gptj.h +++ b/src/gptj.h @@ -85,4 +85,4 @@ bool gptj_eval( std::vector & embd_w, size_t & mem_per_token); -int gptj_generate(gpt_params params, struct gptj_model & model, struct gpt_vocab & vocab, py::function new_text_callback); \ No newline at end of file +int gptj_generate(gpt_params params, struct gptj_model & model, struct gpt_vocab & vocab, py::function new_text_callback, py::function logits_callback); \ No newline at end of file From 9ec6bfd4920a8564d9e8fc85a6ba9f3691937742 Mon Sep 17 00:00:00 2001 From: KeremP Date: Fri, 28 Apr 2023 00:04:32 -0400 Subject: [PATCH 2/2] minor changes. --- pygptj/model.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pygptj/model.py b/pygptj/model.py index a0feeb5..a121f54 100644 --- a/pygptj/model.py +++ b/pygptj/model.py @@ -90,8 +90,14 @@ def _call_new_text_callback(self, text_bytes) -> None: # save res def _call_logits_callback(self, logits: np.ndarray): + """ + Internal logits_callback that saves the logit representation at each token. + :return: None + """ + self.logits.append(logits.tolist()) + if Model._logits_callback is not None: - self.logits.append(logits.tolist()) + Model._logits_callback(logits) def braindump(self, path: str): np.save(path, np.asarray(self.logits))