Get logits as numpy array #4

Merged
merged 3 commits on Apr 28, 2023
23 changes: 22 additions & 1 deletion pygptj/model.py
@@ -18,6 +18,7 @@
 import sys
 import _pygptj as pp
 from pygptj._logger import set_log_level
+import numpy as np


 class Model:
@@ -36,6 +37,7 @@ def new_text_callback(text):
     ```
     """
     _new_text_callback = None
+    _logits_callback = None

     def __init__(self,
                  model_path: str,
@@ -64,6 +66,8 @@ def __init__(self,

         self.res = ""

+        self.logits = []
+
     def _load_model(self):
         """
         Helper function to load the model
@@ -84,10 +88,24 @@ def _call_new_text_callback(self, text_bytes) -> None:
         except UnicodeDecodeError:
             logging.warning(f"UnicodeDecodeError of bytes {text_bytes}")
         # save res
+
+    def _call_logits_callback(self, logits: np.ndarray):
+        """
+        Internal logits_callback that saves the logit representation at each token.
+        :return: None
+        """
+        self.logits.append(logits.tolist())
+
+        if Model._logits_callback is not None:
+            Model._logits_callback(logits)
+
+    def braindump(self, path: str):
+        np.save(path, np.asarray(self.logits))

     def generate(self,
                  prompt: str,
                  new_text_callback: Callable[[str], None] = None,
+                 logits_callback: Callable = None,
                  n_predict: int = 128,
                  seed: int = -1,
                  n_threads: int = 4,
@@ -124,8 +142,11 @@ def generate(self,
         self.res = ""
         Model._new_text_callback = new_text_callback

+        # assign _logits_callback used for saving logits, token by token
+        Model._logits_callback = logits_callback
+
         # run the prediction
-        pp.gptj_generate(self.gpt_params, self._model, self._vocab, self._call_new_text_callback)
+        pp.gptj_generate(self.gpt_params, self._model, self._vocab, self._call_new_text_callback, self._call_logits_callback)
         return self.res

     @staticmethod
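For context, a minimal sketch of how the new `logits_callback` argument and `braindump` helper could be used once this change lands. The model path, prompt, and output file name below are placeholders, not part of the PR:

```python
# Usage sketch only; model path, prompt, and output file are placeholders.
import numpy as np
from pygptj.model import Model

per_token_logits = []

def logits_callback(logits: np.ndarray):
    # Called once per generated token with the raw logits over the vocabulary.
    # Copy, since the array may be a view over a buffer reused on the next step.
    per_token_logits.append(np.array(logits, copy=True))

model = Model(model_path="./ggml-gptj.bin")
model.generate("Once upon a time, ", n_predict=16, logits_callback=logits_callback)

# The Model instance also accumulates logits internally; dump them to a .npy file.
model.braindump("logits.npy")
print(np.asarray(per_token_logits).shape)  # (n_tokens, n_vocab)
```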
7 changes: 5 additions & 2 deletions src/gptj.cpp
@@ -21,6 +21,7 @@
 #include <iostream>

 #include <pybind11/pybind11.h>
+#include <pybind11/numpy.h>
 namespace py = pybind11;

@@ -595,7 +596,7 @@ bool gptj_eval(
     return true;
 }

-int gptj_generate(gpt_params params, struct gptj_model & model, struct gpt_vocab & vocab, py::function new_text_callback) {
+int gptj_generate(gpt_params params, struct gptj_model & model, struct gpt_vocab & vocab, py::function new_text_callback, py::function logits_callback) {
     // auto model = context->model;
     // auto vocab = context->vocab;

@@ -656,7 +657,9 @@ int gptj_generate(gpt_params params, struct gptj_model & model, struct gpt_vocab
                 printf("Failed to predict\n");
                 return 1;
             }
-
+            // collect logits for each token
+            py::array_t<float> _logits = py::array_t<float>{model.hparams.n_vocab, logits.data(), py::none()};
+            logits_callback(_logits);
             t_predict_us += ggml_time_us() - t_start_us;
         }

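The `py::array_t<float>{model.hparams.n_vocab, logits.data(), py::none()}` construction hands the existing per-token logits buffer to Python as a NumPy array without copying it (the `py::none()` base keeps pybind11 from taking ownership of the data); the Python side then stores `logits.tolist()`, which copies the values before the buffer is overwritten on the next evaluation. A small post-processing sketch of the logits that `Model` accumulates, assuming a `Model` instance named `model` that has already finished a `generate()` call:

```python
# Post-processing sketch; assumes `model` already ran generate().
import numpy as np

logits = np.asarray(model.logits)   # shape (n_tokens, n_vocab), one row per generated token
# Softmax over the vocabulary axis to turn logits into probabilities.
probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
probs /= probs.sum(axis=-1, keepdims=True)
top_tokens = probs.argmax(axis=-1)  # most likely token id at each step
```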
2 changes: 1 addition & 1 deletion src/gptj.h
@@ -85,4 +85,4 @@ bool gptj_eval(
         std::vector<float> & embd_w,
         size_t & mem_per_token);

-int gptj_generate(gpt_params params, struct gptj_model & model, struct gpt_vocab & vocab, py::function new_text_callback);
+int gptj_generate(gpt_params params, struct gptj_model & model, struct gpt_vocab & vocab, py::function new_text_callback, py::function logits_callback);