Skip to content

Commit

Permalink
Merge pull request #48 from AutoResearch/various-fixes
Browse files Browse the repository at this point in the history
docs: restructure documentation
  • Loading branch information
younesStrittmatter authored Dec 1, 2024
2 parents 9ebeca9 + cd76014 commit bc7effc
Show file tree
Hide file tree
Showing 8 changed files with 489 additions and 49 deletions.
364 changes: 364 additions & 0 deletions docs/Use Case Tutorials/AI Alignment/Reinforcement Learning.ipynb

Large diffs are not rendered by default.

19 changes: 18 additions & 1 deletion src/sweetbean/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,26 @@ def process_js(js_data):
for d in js_data:
res_dict = {}
for key, value in d.items():
if key in ["rt", "stimulus", "type"]:
if key in ["rt", "stimulus", "type", "response"]:
res_dict[key] = value
if key.startswith("bean_"):
res_dict[key[5:]] = value
res.append(res_dict)
return res


def get_n_responses(data):
n = 0
for d in data:
if "response" in d and d["response"] is not None:
n += 1
return n


def until_response(data, n):
i = 0
for idx, d in enumerate(data):
if "response" in d and d["response"] is not None:
i += 1
if i > n:
return data[:idx]
64 changes: 38 additions & 26 deletions src/sweetbean/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
TEXT_APPENDIX,
)
from sweetbean.block import Block
from sweetbean.variable import CodeVariable, SharedVariable


class Experiment:
Expand All @@ -23,11 +22,9 @@ def to_js(self, path_local_download=None):
for b in self.blocks:
b.to_js()
for s in b.stimuli:
for key in s.arg:
if isinstance(s.arg[key], SharedVariable) or isinstance(
s.arg[key], CodeVariable
):
self.js += f"{s.arg[key].set()}\n"
shared_variables = s.return_shared_variables()
for s_key in shared_variables:
self.js += f"{shared_variables[s_key].set()}\n"
if path_local_download:
self.js += (
"jsPsych = initJsPsych("
Expand Down Expand Up @@ -60,11 +57,9 @@ def to_js_string(self, as_function=True, is_async=True):
for b in self.blocks:
b.to_js()
for s in b.stimuli:
for key in s.arg:
if isinstance(s.arg[key], SharedVariable) or isinstance(
s.arg[key], CodeVariable
):
text += f"{s.arg[key].set()}\n"
shared_variables = s.return_shared_variables()
for s_key in shared_variables:
text += f"{shared_variables[s_key].set()}\n"
text += "const jsPsych = initJsPsych()\n"
text += "const trials = [\n"
for b in self.blocks:
Expand All @@ -78,43 +73,60 @@ def run_on_language(
self,
get_input=input,
multi_turn=False,
data=None,
):
data = []
out_data = []
prompts = []
shared_variables = {}
for b in self.blocks:
for s in b.stimuli:
for key in s.arg:
if isinstance(s.arg[key], SharedVariable) or isinstance(
s.arg[key], CodeVariable
):
shared_variables[s.arg[key].name] = s.arg[key].value

_shared_variables = s.return_shared_variables()
for s_key in _shared_variables:
shared_variables[s_key] = _shared_variables[s_key].value
datum_index = 0
for b in self.blocks:
timeline = b.timeline
stimuli = b.stimuli
if not timeline:
timeline = [{}]
for timeline_element in timeline:
data, prompts, shared_variables = run_stimuli(
out_data, prompts, shared_variables, datum_index = run_stimuli(
stimuli,
timeline_element,
data,
out_data,
shared_variables,
prompts,
get_input,
multi_turn,
datum_index,
data,
)
return data, prompts
return out_data, prompts


def run_stimuli(
stimuli, timeline_element, data, shared_variables, prompts, get_input, multi_turn
stimuli,
timeline_element,
out_data,
shared_variables,
prompts,
get_input,
multi_turn,
datum_index,
data,
):
for s in stimuli:
s._prepare_args_l(timeline_element, data, shared_variables)
s_data, prompts = s.process_l(prompts, get_input, multi_turn)
data.append(s_data)
if datum_index < len(data):
datum = data[datum_index]
else:
datum = None

s._prepare_args_l(timeline_element, out_data, shared_variables, datum)
s_out_data, prompts = s.process_l(prompts, get_input, multi_turn, datum)
out_data.append(s_out_data)
if s.side_effects:
s._resolve_side_effects(timeline_element, data, shared_variables)
s._resolve_side_effects(timeline_element, out_data, shared_variables)
shared_variables.update(s.l_ses)
return data, prompts, shared_variables
datum_index += 1
return out_data, prompts, shared_variables, datum_index
7 changes: 5 additions & 2 deletions src/sweetbean/stimulus/Choice.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def _set_before(self):
)
self.js_before = f"on_load:()=>{{{res}}},"

def process_l(self, prompts, get_input, multi_turn):
def process_l(self, prompts, get_input, multi_turn, datum=None):
current_prompt = f' You see {len(self.l_args["bandits"])} bandits.'
for idx, bandit in enumerate(self.l_args["bandits"]):
current_prompt += f' Bandit {idx + 1} is {bandit["color"]}.'
Expand All @@ -132,7 +132,10 @@ def process_l(self, prompts, get_input, multi_turn):
in_prompt = " ".join([p for p in prompts]) + current_prompt + "<<"
else:
in_prompt = current_prompt + "<<"
response = get_input(in_prompt)
if not datum:
response = get_input(in_prompt)
else:
response = datum["response"] + 1
if int(response) < 1 or int(response) > len(self.l_args["bandits"]):
prompts.append(
current_prompt + f"<<{response}>>. " f"The response was invalid."
Expand Down
2 changes: 1 addition & 1 deletion src/sweetbean/stimulus/RO.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def _process_response(self):
def _set_before(self):
pass

def process_l(self, prompts, get_input, multi_turn):
def process_l(self, prompts, get_input, multi_turn, datum=None):
raise NotImplementedError


Expand Down
55 changes: 45 additions & 10 deletions src/sweetbean/stimulus/Stimulus.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,26 @@ def __init__(self, args, side_effects=None):
self.arg_js["trial_duration"] = self.arg["duration"]
for key in self.arg:
self.arg_js[key] = args[key]
# self.to_js()

def return_shared_variables(self):
shared_variables = {}
for key in self.arg:

def extract_shared_variables(value):
if isinstance(value, SharedVariable):
shared_variables[value.name] = value
elif isinstance(value, dict):
for v in value.values():
extract_shared_variables(v)
elif isinstance(value, list):
for item in value:
extract_shared_variables(item)
elif isinstance(value, FunctionVariable):
for arg in value.args:
extract_shared_variables(arg)

extract_shared_variables(self.arg[key])
return shared_variables

def to_js(self):
self.js = ""
Expand All @@ -58,19 +77,28 @@ def _params_to_js(self):
self.js_body += f'type: {self.arg["type"]},'
for key in self.arg_js:
self._param_to_js(key, self.arg_js[key])
for key in self.arg:
if key not in self.arg_js:
self._param_to_js_arg(key, self.arg[key])
self._add_special_param()
self._process_response()
self._set_before()
if self.side_effects:
self._set_side_effects()

def _prepare_args_l(self, timeline_element, data, shared_variables):
self.l_args = {}
self.l_ses = {}
for key, value in self.arg.items():
key_ = key
value_ = _parse_variable(value, timeline_element, data, shared_variables)
self.l_args[key_] = value_
def _prepare_args_l(self, timeline_element, data, shared_variables, datum=None):
if not datum:
self.l_args = {}
self.l_ses = {}
for key, value in self.arg.items():
key_ = key
value_ = _parse_variable(
value, timeline_element, data, shared_variables
)
self.l_args[key_] = value_
else:
for key, value in datum.items():
self.l_args[key] = value

def _resolve_side_effects(self, timeline_element, data, shared_variables):
if self.side_effects:
Expand All @@ -80,7 +108,7 @@ def _resolve_side_effects(self, timeline_element, data, shared_variables):
)
self.l_ses[se.set_variable.name] = get_variable

def process_l(self, prompts, get_input, multi_turn):
def process_l(self, prompts, get_input, multi_turn, datum=None):
prompts.append(self._get_prompt_l())
prompt_response = self._get_response_prompt_l()
s_data = {}
Expand All @@ -91,7 +119,10 @@ def process_l(self, prompts, get_input, multi_turn):
_in_prompt = prompts[-1]
else:
_in_prompt = " ".join([p for p in prompts])
response = get_input(_in_prompt).upper()
if not datum:
response = get_input(_in_prompt).upper()
else:
response = datum["response"].upper()
s_data = self._process_response_l(response)
prompts[-1] += f"{response}>>"
data.update(s_data)
Expand All @@ -114,6 +145,10 @@ def _param_to_js(self, key, param):
self.js_body += body
self.js_data += data

def _param_to_js_arg(self, key, param):
_, data = _set_param_js(key, param)
self.js_data += data

def _set_side_effects(self):
for se in self.side_effects:
self.js_data += se.to_js()
Expand Down
25 changes: 17 additions & 8 deletions src/sweetbean/stimulus/Survey.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def get_prompts(_prompts):
questions_ = FunctionVariable("questions", get_prompts, [questions])
super().__init__(questions_, side_effects=side_effects)

def process_l(self, prompts, get_input, multi_turn):
def process_l(self, prompts, get_input, multi_turn, datum=None):
current_prompt = []
responses = {}
data = self.l_args.copy()
Expand All @@ -46,9 +46,12 @@ def process_l(self, prompts, get_input, multi_turn):
)
else:
_in_prompt = current_prompt[-1] + "<<"
response = get_input(_in_prompt)
if not datum:
response = get_input(_in_prompt)
else:
response = datum["response"][f"Q{str(idx)}"]
current_prompt[-1] += f"<<{response}>>"
responses[f"Q_{str(idx + 1)}"] = response
responses[f"Q{str(idx)}"] = response
data.update({"response": responses})
prompts += current_prompt
return data, prompts
Expand All @@ -70,7 +73,7 @@ def get_prompts(_prompts):
questions_ = FunctionVariable("questions", get_prompts, [questions])
super().__init__(questions_, side_effects=side_effects)

def process_l(self, prompts, get_input, multi_turn):
def process_l(self, prompts, get_input, multi_turn, datum=None):
current_prompt = []
responses = {}
data = self.l_args.copy()
Expand All @@ -87,7 +90,10 @@ def process_l(self, prompts, get_input, multi_turn):
)
else:
_in_prompt = current_prompt[-1] + "<<"
response = get_input(_in_prompt)
if not datum:
response = get_input(_in_prompt)
else:
response = datum["response"][f"Q{str(idx)}"]
current_prompt[-1] += f"<<{response}>>"
responses[f"Q{str(idx)}"] = response
data.update({"response": responses})
Expand Down Expand Up @@ -125,7 +131,7 @@ def from_scale(cls, prompts=None, scale=None, side_effects=None):
prompts_.append({"prompt": p, "labels": scale})
return cls(prompts_, side_effects=side_effects)

def process_l(self, prompts, get_input, multi_turn):
def process_l(self, prompts, get_input, multi_turn, datum=None):
current_prompt = []
responses = {}
data = self.l_args.copy()
Expand All @@ -142,9 +148,12 @@ def process_l(self, prompts, get_input, multi_turn):
)
else:
_in_prompt = current_prompt[-1] + "<<"
response = get_input(_in_prompt)
if not datum:
response = get_input(_in_prompt)
else:
response = datum["response"][f"Q{str(idx)}"]
current_prompt[-1] += f"<<{response}>>"
responses[f"Q_{str(idx + 1)}"] = response
responses[f"Q{str(idx)}"] = response
data.update({"response": responses})
prompts += current_prompt
return data, prompts
2 changes: 1 addition & 1 deletion src/sweetbean/stimulus/_Template_.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def __init__(self, args, side_effects=None):
"""
super().__init__(args, side_effects)

def process_l(self, prompts, get_input, multi_turn):
def process_l(self, prompts, get_input, multi_turn, datum):
"""
This is used to process the arguments, generate a prompt and
get a response from language input
Expand Down

0 comments on commit bc7effc

Please sign in to comment.