diff --git a/docs/Use Case Tutorials/AI Alignment/Reinforcement Learning.ipynb b/docs/Use Case Tutorials/AI Alignment/Reinforcement Learning.ipynb new file mode 100644 index 00000000..3297a107 --- /dev/null +++ b/docs/Use Case Tutorials/AI Alignment/Reinforcement Learning.ipynb @@ -0,0 +1,364 @@ +{ + "cells": [ + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "# Reinforcement Learning\n", + "\n", + "Here, we will implement a simple two-armed bandit task. We then run the same task on a language model specifically trained on tasks like these ([centaur](https://marcelbinz.github.io/centaur/)) and compare the results.\n", + "\n", + "## Two-Armed Bandit Task\n", + "\n", + "### Imports" + ], + "id": "9b59ed742e18a0f5" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-12-01T18:55:31.993583Z", + "start_time": "2024-12-01T18:55:31.966327Z" + } + }, + "cell_type": "code", + "source": [ + "from sweetbean import Block, Experiment\n", + "from sweetbean.stimulus import Bandit, Text\n", + "from sweetbean.variable import (\n", + " DataVariable,\n", + " FunctionVariable,\n", + " SharedVariable,\n", + " SideEffect,\n", + " TimelineVariable,\n", + ")" + ], + "id": "465d8af802c206ec", + "outputs": [], + "execution_count": 1 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "### Timeline\n", + "\n", + "Here, we gradually change the value of `bandit_1` from 10 to 0 and the value of `bandit_2` in reverse order, from 0 to 10.\n" + ], + "id": "838416866ba97dbd" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-12-01T18:55:32.918504Z", + "start_time": "2024-12-01T18:55:32.916393Z" + } + }, + "cell_type": "code", + "source": [ + "timeline = []\n", + "for i in range(11):\n", + " timeline.append(\n", + " {\n", + " \"bandit_1\": {\"color\": \"orange\", \"value\": 10 - i},\n", + " \"bandit_2\": {\"color\": \"blue\", \"value\": i},\n", + " }\n", + " )" + ], + "id": "98d345f996176fad", + "outputs": [], + "execution_count": 2 + }, + { + "metadata": {}, 
+ "cell_type": "markdown", + "source": [ + "### Implementation\n", + "\n", + "We also keep track of the score with a shared variable to present it between the bandit tasks." + ], + "id": "3496e0ca5f0ee9ce" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-12-01T18:55:34.101112Z", + "start_time": "2024-12-01T18:55:34.098339Z" + } + }, + "cell_type": "code", + "source": [ + "bandit_1 = TimelineVariable(\"bandit_1\")\n", + "bandit_2 = TimelineVariable(\"bandit_2\")\n", + "\n", + "score = SharedVariable(\"score\", 0)\n", + "value = DataVariable(\"value\", 0)\n", + "\n", + "update_score = FunctionVariable(\n", + " \"update_score\", lambda sc, val: sc + val, [score, value]\n", + ")\n", + "\n", + "update_score_side_effect = SideEffect(score, update_score)\n", + "\n", + "bandit_task = Bandit(\n", + " bandits=[bandit_1, bandit_2],\n", + " side_effects=[update_score_side_effect],\n", + ")\n", + "\n", + "score_text = FunctionVariable(\"score_text\", lambda sc: f\"Score: {sc}\", [score])\n", + "\n", + "show_score = Text(duration=2000, text=score_text)\n", + "\n", + "trial_sequence = Block([bandit_task, show_score], timeline=timeline)\n", + "experiment = Experiment([trial_sequence])" + ], + "id": "6a1ada08dec8c348", + "outputs": [], + "execution_count": 3 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "Export the experiment to an HTML file and run it in the browser.", + "id": "65ae4706ed556de1" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-12-01T18:45:14.067346Z", + "start_time": "2024-12-01T18:44:53.841870Z" + } + }, + "cell_type": "code", + "source": "experiment.to_html(\"bandit.html\", path_local_download=\"bandit.json\")", + "id": "80fbd261e9ae251a", + "outputs": [], + "execution_count": 4 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "### Results\n", + "After running `bandit.html`, there should be a file called `bandit.json` in the download directory. 
You can open the file in your browser to see the results. First, we process it so that it only contains relevant data:" + ], + "id": "94f8c0dff2ef6200" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-12-01T18:55:38.963190Z", + "start_time": "2024-12-01T18:55:38.959118Z" + } + }, + "cell_type": "code", + "source": [ + "import json\n", + "from sweetbean.data import process_js, get_n_responses, until_response\n", + "\n", + "with open(\"bandit.json\") as f:\n", + " data_raw = json.load(f)\n", + " \n", + "data = process_js(data_raw)" + ], + "id": "55d66aac9b404c84", + "outputs": [], + "execution_count": 4 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "", + "id": "d41ef6e70a2174c5" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "We can now count how many responses were made and extract the data up to and including the third response (that is, everything before the fourth response):", + "id": "b1332cb3777464bf" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-12-01T18:55:40.224472Z", + "start_time": "2024-12-01T18:55:40.219974Z" + } + }, + "cell_type": "code", + "source": [ + "n_responses = get_n_responses(data)\n", + "data_third_response = until_response(data, 3)\n", + "data_third_response" + ], + "id": "89f6f6a3996ac454", + "outputs": [ + { + "data": { + "text/plain": [ + "[{'rt': 1154,\n", + " 'stimulus': ['
',\n", + " '
'],\n", + " 'response': 0,\n", + " 'trial_duration': None,\n", + " 'duration': None,\n", + " 'html_array': ['
',\n", + " '
'],\n", + " 'values': [10, 0],\n", + " 'time_after_response': 2000,\n", + " 'type': 'jsPsychHtmlChoice',\n", + " 'bandits': [{'color': 'orange', 'value': 10}, {'color': 'blue', 'value': 0}],\n", + " 'value': 10,\n", + " 'score': 10},\n", + " {'rt': None,\n", + " 'stimulus': \"
Score: 10
\",\n", + " 'response': None,\n", + " 'trial_duration': 2000,\n", + " 'duration': 2000,\n", + " 'choices': [],\n", + " 'correct_key': '',\n", + " 'type': 'jsPsychHtmlKeyboardResponse',\n", + " 'text': 'Score: 10',\n", + " 'color': 'white',\n", + " 'correct': False},\n", + " {'rt': 378,\n", + " 'stimulus': ['
',\n", + " '
'],\n", + " 'response': 0,\n", + " 'trial_duration': None,\n", + " 'duration': None,\n", + " 'html_array': ['
',\n", + " '
'],\n", + " 'values': [9, 1],\n", + " 'time_after_response': 2000,\n", + " 'type': 'jsPsychHtmlChoice',\n", + " 'bandits': [{'color': 'orange', 'value': 9}, {'color': 'blue', 'value': 1}],\n", + " 'value': 9,\n", + " 'score': 19},\n", + " {'rt': None,\n", + " 'stimulus': \"
Score: 19
\",\n", + " 'response': None,\n", + " 'trial_duration': 2000,\n", + " 'duration': 2000,\n", + " 'choices': [],\n", + " 'correct_key': '',\n", + " 'type': 'jsPsychHtmlKeyboardResponse',\n", + " 'text': 'Score: 19',\n", + " 'color': 'white',\n", + " 'correct': False},\n", + " {'rt': 360,\n", + " 'stimulus': ['
',\n", + " '
'],\n", + " 'response': 0,\n", + " 'trial_duration': None,\n", + " 'duration': None,\n", + " 'html_array': ['
',\n", + " '
'],\n", + " 'values': [8, 2],\n", + " 'time_after_response': 2000,\n", + " 'type': 'jsPsychHtmlChoice',\n", + " 'bandits': [{'color': 'orange', 'value': 8}, {'color': 'blue', 'value': 2}],\n", + " 'value': 8,\n", + " 'score': 27},\n", + " {'rt': None,\n", + " 'stimulus': \"
Score: 27
\",\n", + " 'response': None,\n", + " 'trial_duration': 2000,\n", + " 'duration': 2000,\n", + " 'choices': [],\n", + " 'correct_key': '',\n", + " 'type': 'jsPsychHtmlKeyboardResponse',\n", + " 'text': 'Score: 27',\n", + " 'color': 'white',\n", + " 'correct': False}]" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 5 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "# Experiment on language model\n", + "\n", + "With the partial data, we can now run the experiment up to that point and then run the rest of the experiment on language input. To test this, we run it manually:" + ], + "id": "e7d854441716de90" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-12-01T18:55:56.760055Z", + "start_time": "2024-12-01T18:55:41.207411Z" + } + }, + "cell_type": "code", + "source": "data_new, _ = experiment.run_on_language(input, data=data_third_response)", + "id": "df2b43db78303c0", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "hi\n" + ] + } + ], + "execution_count": 6 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-12-01T18:56:05.300751Z", + "start_time": "2024-12-01T18:56:05.298031Z" + } + }, + "cell_type": "code", + "source": "print(data_new)", + "id": "9619a73c6f650c7e", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "([{'rt': 1154, 'stimulus': ['
', '
'], 'response': 0, 'trial_duration': None, 'duration': None, 'html_array': ['
', '
'], 'values': [10, 0], 'time_after_response': 2000, 'type': 'jsPsychHtmlChoice', 'bandits': [{'color': 'orange', 'value': 10}, {'color': 'blue', 'value': 0}], 'value': 10, 'score': 10}, {'rt': None, 'stimulus': \"
Score: 10
\", 'response': None, 'trial_duration': 2000, 'duration': 2000, 'html_array': ['
', '
'], 'values': [10, 0], 'time_after_response': 2000, 'type': 'jsPsychHtmlKeyboardResponse', 'bandits': [{'color': 'orange', 'value': 10}, {'color': 'blue', 'value': 0}], 'value': 10, 'score': 10, 'choices': [], 'correct_key': '', 'text': 'Score: 10', 'color': 'white', 'correct': False}, {'rt': 378, 'stimulus': ['
', '
'], 'response': 0, 'trial_duration': None, 'duration': None, 'html_array': ['
', '
'], 'values': [9, 1], 'time_after_response': 2000, 'type': 'jsPsychHtmlChoice', 'bandits': [{'color': 'orange', 'value': 9}, {'color': 'blue', 'value': 1}], 'value': 9, 'score': 19, 'choices': [], 'correct_key': '', 'text': 'Score: 10', 'color': 'white', 'correct': False}, {'rt': None, 'stimulus': \"
Score: 19
\", 'response': None, 'trial_duration': 2000, 'duration': 2000, 'html_array': ['
', '
'], 'values': [9, 1], 'time_after_response': 2000, 'type': 'jsPsychHtmlKeyboardResponse', 'bandits': [{'color': 'orange', 'value': 9}, {'color': 'blue', 'value': 1}], 'value': 9, 'score': 19, 'choices': [], 'correct_key': '', 'text': 'Score: 19', 'color': 'white', 'correct': False}, {'rt': 360, 'stimulus': ['
', '
'], 'response': 0, 'trial_duration': None, 'duration': None, 'html_array': ['
', '
'], 'values': [8, 2], 'time_after_response': 2000, 'type': 'jsPsychHtmlChoice', 'bandits': [{'color': 'orange', 'value': 8}, {'color': 'blue', 'value': 2}], 'value': 8, 'score': 27, 'choices': [], 'correct_key': '', 'text': 'Score: 19', 'color': 'white', 'correct': False}, {'rt': None, 'stimulus': \"
Score: 27
\", 'response': None, 'trial_duration': 2000, 'duration': 2000, 'html_array': ['
', '
'], 'values': [8, 2], 'time_after_response': 2000, 'type': 'jsPsychHtmlKeyboardResponse', 'bandits': [{'color': 'orange', 'value': 8}, {'color': 'blue', 'value': 2}], 'value': 8, 'score': 27, 'choices': [], 'correct_key': '', 'text': 'Score: 27', 'color': 'white', 'correct': False}, {'duration': None, 'html_array': ['
', '
'], 'values': [7, 3], 'time_after_response': 2000, 'type': 'jsPsychHtmlChoice', 'bandits': [{'color': 'orange', 'value': 7}, {'color': 'blue', 'value': 3}], 'response': -1, 'value': 0}, {'duration': 2000, 'stimulus': \"
Score: 40
\", 'choices': [], 'correct_key': '', 'type': 'jsPsychHtmlKeyboardResponse', 'text': 'Score: 40', 'color': 'white'}, {'duration': None, 'html_array': ['
', '
'], 'values': [6, 4], 'time_after_response': 2000, 'type': 'jsPsychHtmlChoice', 'bandits': [{'color': 'orange', 'value': 6}, {'color': 'blue', 'value': 4}], 'response': -1, 'value': 0}, {'duration': 2000, 'stimulus': \"
Score: 50
\", 'choices': [], 'correct_key': '', 'type': 'jsPsychHtmlKeyboardResponse', 'text': 'Score: 50', 'color': 'white'}, {'duration': None, 'html_array': ['
', '
'], 'values': [5, 5], 'time_after_response': 2000, 'type': 'jsPsychHtmlChoice', 'bandits': [{'color': 'orange', 'value': 5}, {'color': 'blue', 'value': 5}], 'response': -1, 'value': 0}, {'duration': 2000, 'stimulus': \"
Score: 60
\", 'choices': [], 'correct_key': '', 'type': 'jsPsychHtmlKeyboardResponse', 'text': 'Score: 60', 'color': 'white'}, {'duration': None, 'html_array': ['
', '
'], 'values': [4, 6], 'time_after_response': 2000, 'type': 'jsPsychHtmlChoice', 'bandits': [{'color': 'orange', 'value': 4}, {'color': 'blue', 'value': 6}], 'response': -1, 'value': 0}, {'duration': 2000, 'stimulus': \"
Score: 70
\", 'choices': [], 'correct_key': '', 'type': 'jsPsychHtmlKeyboardResponse', 'text': 'Score: 70', 'color': 'white'}, {'duration': None, 'html_array': ['
', '
'], 'values': [3, 7], 'time_after_response': 2000, 'type': 'jsPsychHtmlChoice', 'bandits': [{'color': 'orange', 'value': 3}, {'color': 'blue', 'value': 7}], 'response': -1, 'value': 0}, {'duration': 2000, 'stimulus': \"
Score: 80
\", 'choices': [], 'correct_key': '', 'type': 'jsPsychHtmlKeyboardResponse', 'text': 'Score: 80', 'color': 'white'}, {'duration': None, 'html_array': ['
', '
'], 'values': [2, 8], 'time_after_response': 2000, 'type': 'jsPsychHtmlChoice', 'bandits': [{'color': 'orange', 'value': 2}, {'color': 'blue', 'value': 8}], 'response': 0, 'value': 2}, {'duration': 2000, 'stimulus': \"
Score: 90
\", 'choices': [], 'correct_key': '', 'type': 'jsPsychHtmlKeyboardResponse', 'text': 'Score: 90', 'color': 'white'}, {'duration': None, 'html_array': ['
', '
'], 'values': [1, 9], 'time_after_response': 2000, 'type': 'jsPsychHtmlChoice', 'bandits': [{'color': 'orange', 'value': 1}, {'color': 'blue', 'value': 9}], 'response': 0, 'value': 1}, {'duration': 2000, 'stimulus': \"
Score: 100
\", 'choices': [], 'correct_key': '', 'type': 'jsPsychHtmlKeyboardResponse', 'text': 'Score: 100', 'color': 'white'}, {'duration': None, 'html_array': ['
', '
'], 'values': [0, 10], 'time_after_response': 2000, 'type': 'jsPsychHtmlChoice', 'bandits': [{'color': 'orange', 'value': 0}, {'color': 'blue', 'value': 10}], 'response': 0, 'value': 0}, {'duration': 2000, 'stimulus': \"
Score: 110
\", 'choices': [], 'correct_key': '', 'type': 'jsPsychHtmlKeyboardResponse', 'text': 'Score: 110', 'color': 'white'}], [' You see 2 bandits. Bandit 1 is orange. Bandit 2 is blue. Choose a bandit by naming the number of the bandit. You name <<1>>. The value of the chosen bandit was 10.', 'You see \"Score: 10\" in \"white\" for 2000ms.', ' You see 2 bandits. Bandit 1 is orange. Bandit 2 is blue. Choose a bandit by naming the number of the bandit. You name <<1>>. The value of the chosen bandit was 9.', 'You see \"Score: 19\" in \"white\" for 2000ms.', ' You see 2 bandits. Bandit 1 is orange. Bandit 2 is blue. Choose a bandit by naming the number of the bandit. You name <<1>>. The value of the chosen bandit was 8.', 'You see \"Score: 27\" in \"white\" for 2000ms.', ' You see 2 bandits. Bandit 1 is orange. Bandit 2 is blue. Choose a bandit by naming the number of the bandit. You name <<0>>. The response was invalid.', 'You see \"Score: 40\" in \"white\" for 2000ms.', ' You see 2 bandits. Bandit 1 is orange. Bandit 2 is blue. Choose a bandit by naming the number of the bandit. You name <<0>>. The response was invalid.', 'You see \"Score: 50\" in \"white\" for 2000ms.', ' You see 2 bandits. Bandit 1 is orange. Bandit 2 is blue. Choose a bandit by naming the number of the bandit. You name <<0>>. The response was invalid.', 'You see \"Score: 60\" in \"white\" for 2000ms.', ' You see 2 bandits. Bandit 1 is orange. Bandit 2 is blue. Choose a bandit by naming the number of the bandit. You name <<0>>. The response was invalid.', 'You see \"Score: 70\" in \"white\" for 2000ms.', ' You see 2 bandits. Bandit 1 is orange. Bandit 2 is blue. Choose a bandit by naming the number of the bandit. You name <<0>>. The response was invalid.', 'You see \"Score: 80\" in \"white\" for 2000ms.', ' You see 2 bandits. Bandit 1 is orange. Bandit 2 is blue. Choose a bandit by naming the number of the bandit. You name <<01>>. 
def get_n_responses(data):
    """Count the entries of *data* that carry a response.

    An entry counts when it has a ``"response"`` key whose value is not
    ``None`` (a response of ``0`` therefore counts).

    Args:
        data: iterable of per-trial dicts.

    Returns:
        The number of entries with a non-``None`` response.
    """
    return sum(1 for d in data if "response" in d and d["response"] is not None)


def until_response(data, n):
    """Return the leading part of *data* that contains the first *n* responses.

    The data is cut right before the entry holding response number ``n + 1``,
    so the result includes the n-th response and every entry that follows it
    up to (but excluding) the next response.

    Args:
        data: list of per-trial dicts.
        n: number of responses to keep; expected to be non-negative.

    Returns:
        A new list with the selected prefix of *data*. When *data* holds
        *n* or fewer responses the whole of *data* is returned (as a shallow
        copy) instead of implicitly returning ``None``.
    """
    seen = 0
    for idx, d in enumerate(data):
        if "response" in d and d["response"] is not None:
            seen += 1
        if seen > n:
            return data[:idx]
    # Fewer than n + 1 responses: there is nothing to cut off.
    return data[:]
a/src/sweetbean/experiment.py +++ b/src/sweetbean/experiment.py @@ -8,7 +8,6 @@ TEXT_APPENDIX, ) from sweetbean.block import Block -from sweetbean.variable import CodeVariable, SharedVariable class Experiment: @@ -23,11 +22,9 @@ def to_js(self, path_local_download=None): for b in self.blocks: b.to_js() for s in b.stimuli: - for key in s.arg: - if isinstance(s.arg[key], SharedVariable) or isinstance( - s.arg[key], CodeVariable - ): - self.js += f"{s.arg[key].set()}\n" + shared_variables = s.return_shared_variables() + for s_key in shared_variables: + self.js += f"{shared_variables[s_key].set()}\n" if path_local_download: self.js += ( "jsPsych = initJsPsych(" @@ -60,11 +57,9 @@ def to_js_string(self, as_function=True, is_async=True): for b in self.blocks: b.to_js() for s in b.stimuli: - for key in s.arg: - if isinstance(s.arg[key], SharedVariable) or isinstance( - s.arg[key], CodeVariable - ): - text += f"{s.arg[key].set()}\n" + shared_variables = s.return_shared_variables() + for s_key in shared_variables: + text += f"{shared_variables[s_key].set()}\n" text += "const jsPsych = initJsPsych()\n" text += "const trials = [\n" for b in self.blocks: @@ -78,43 +73,60 @@ def run_on_language( self, get_input=input, multi_turn=False, + data=None, ): - data = [] + out_data = [] prompts = [] shared_variables = {} for b in self.blocks: for s in b.stimuli: - for key in s.arg: - if isinstance(s.arg[key], SharedVariable) or isinstance( - s.arg[key], CodeVariable - ): - shared_variables[s.arg[key].name] = s.arg[key].value + + _shared_variables = s.return_shared_variables() + for s_key in _shared_variables: + shared_variables[s_key] = _shared_variables[s_key].value + datum_index = 0 for b in self.blocks: timeline = b.timeline stimuli = b.stimuli if not timeline: timeline = [{}] for timeline_element in timeline: - data, prompts, shared_variables = run_stimuli( + out_data, prompts, shared_variables, datum_index = run_stimuli( stimuli, timeline_element, - data, + out_data, 
def run_stimuli(
    stimuli,
    timeline_element,
    out_data,
    shared_variables,
    prompts,
    get_input,
    multi_turn,
    datum_index=0,
    data=None,
):
    """Run every stimulus once for a single timeline element.

    For each stimulus the recorded entry at ``data[datum_index]`` (if any) is
    handed to the stimulus so it can replay that entry instead of asking
    ``get_input``; once the recorded data is exhausted the stimulus receives
    ``None``.

    Args:
        stimuli: iterable of stimulus objects providing ``_prepare_args_l``,
            ``process_l``, ``_resolve_side_effects``, ``side_effects`` and
            ``l_ses``.
        timeline_element: the timeline entry forwarded to each stimulus.
        out_data: list collecting one result per processed stimulus; it is
            appended to in place.
        shared_variables: dict of shared values; updated in place with each
            stimulus' ``l_ses``.
        prompts: prompt list threaded through ``process_l``.
        get_input: callable forwarded to ``process_l``.
        multi_turn: flag forwarded to ``process_l``.
        datum_index: index into *data* of the next recorded entry.
        data: optional sequence of recorded entries; ``None`` means there is
            nothing to replay.

    Returns:
        Tuple ``(out_data, prompts, shared_variables, datum_index)`` where
        ``datum_index`` has been advanced by the number of stimuli processed.
    """
    # `run_on_language` passes data=None by default; treating that as "no
    # recorded data" avoids calling len() on None.
    n_recorded = len(data) if data is not None else 0
    for s in stimuli:
        datum = data[datum_index] if datum_index < n_recorded else None

        s._prepare_args_l(timeline_element, out_data, shared_variables, datum)
        s_out_data, prompts = s.process_l(prompts, get_input, multi_turn, datum)
        out_data.append(s_out_data)
        if s.side_effects:
            s._resolve_side_effects(timeline_element, out_data, shared_variables)
        shared_variables.update(s.l_ses)
        datum_index += 1
    return out_data, prompts, shared_variables, datum_index
def return_shared_variables(self):
    """Collect every ``SharedVariable`` reachable from this stimulus' arguments.

    Each value in ``self.arg`` is searched recursively: dict values, list
    items and the ``args`` of a ``FunctionVariable`` are all followed.

    Returns:
        Dict mapping each found variable's ``name`` to the variable object.
        If two variables share a name, the one encountered last wins.
    """
    # NOTE(review): the inline code this helper replaces in Experiment also
    # collected CodeVariable instances; confirm CodeVariable is still covered
    # (e.g. by subclassing SharedVariable) before relying on this.
    shared_variables = {}

    # Defined once, outside the loop, instead of once per argument.
    def extract_shared_variables(value):
        if isinstance(value, SharedVariable):
            shared_variables[value.name] = value
        elif isinstance(value, dict):
            for v in value.values():
                extract_shared_variables(v)
        elif isinstance(value, list):
            for item in value:
                extract_shared_variables(item)
        elif isinstance(value, FunctionVariable):
            for arg in value.args:
                extract_shared_variables(arg)

    for value in self.arg.values():
        extract_shared_variables(value)
    return shared_variables
+ if key not in self.arg_js: + self._param_to_js_arg(key, self.arg[key]) self._add_special_param() self._process_response() self._set_before() if self.side_effects: self._set_side_effects() - def _prepare_args_l(self, timeline_element, data, shared_variables): - self.l_args = {} - self.l_ses = {} - for key, value in self.arg.items(): - key_ = key - value_ = _parse_variable(value, timeline_element, data, shared_variables) - self.l_args[key_] = value_ + def _prepare_args_l(self, timeline_element, data, shared_variables, datum=None): + if not datum: + self.l_args = {} + self.l_ses = {} + for key, value in self.arg.items(): + key_ = key + value_ = _parse_variable( + value, timeline_element, data, shared_variables + ) + self.l_args[key_] = value_ + else: + for key, value in datum.items(): + self.l_args[key] = value def _resolve_side_effects(self, timeline_element, data, shared_variables): if self.side_effects: @@ -80,7 +108,7 @@ def _resolve_side_effects(self, timeline_element, data, shared_variables): ) self.l_ses[se.set_variable.name] = get_variable - def process_l(self, prompts, get_input, multi_turn): + def process_l(self, prompts, get_input, multi_turn, datum=None): prompts.append(self._get_prompt_l()) prompt_response = self._get_response_prompt_l() s_data = {} @@ -91,7 +119,10 @@ def process_l(self, prompts, get_input, multi_turn): _in_prompt = prompts[-1] else: _in_prompt = " ".join([p for p in prompts]) - response = get_input(_in_prompt).upper() + if not datum: + response = get_input(_in_prompt).upper() + else: + response = datum["response"].upper() s_data = self._process_response_l(response) prompts[-1] += f"{response}>>" data.update(s_data) @@ -114,6 +145,10 @@ def _param_to_js(self, key, param): self.js_body += body self.js_data += data + def _param_to_js_arg(self, key, param): + _, data = _set_param_js(key, param) + self.js_data += data + def _set_side_effects(self): for se in self.side_effects: self.js_data += se.to_js() diff --git 
a/src/sweetbean/stimulus/Survey.py b/src/sweetbean/stimulus/Survey.py index f9b26ea5..17ebb5f9 100644 --- a/src/sweetbean/stimulus/Survey.py +++ b/src/sweetbean/stimulus/Survey.py @@ -32,7 +32,7 @@ def get_prompts(_prompts): questions_ = FunctionVariable("questions", get_prompts, [questions]) super().__init__(questions_, side_effects=side_effects) - def process_l(self, prompts, get_input, multi_turn): + def process_l(self, prompts, get_input, multi_turn, datum=None): current_prompt = [] responses = {} data = self.l_args.copy() @@ -46,9 +46,12 @@ def process_l(self, prompts, get_input, multi_turn): ) else: _in_prompt = current_prompt[-1] + "<<" - response = get_input(_in_prompt) + if not datum: + response = get_input(_in_prompt) + else: + response = datum["response"][f"Q{str(idx)}"] current_prompt[-1] += f"<<{response}>>" - responses[f"Q_{str(idx + 1)}"] = response + responses[f"Q{str(idx)}"] = response data.update({"response": responses}) prompts += current_prompt return data, prompts @@ -70,7 +73,7 @@ def get_prompts(_prompts): questions_ = FunctionVariable("questions", get_prompts, [questions]) super().__init__(questions_, side_effects=side_effects) - def process_l(self, prompts, get_input, multi_turn): + def process_l(self, prompts, get_input, multi_turn, datum=None): current_prompt = [] responses = {} data = self.l_args.copy() @@ -87,7 +90,10 @@ def process_l(self, prompts, get_input, multi_turn): ) else: _in_prompt = current_prompt[-1] + "<<" - response = get_input(_in_prompt) + if not datum: + response = get_input(_in_prompt) + else: + response = datum["response"][f"Q{str(idx)}"] current_prompt[-1] += f"<<{response}>>" responses[f"Q{str(idx)}"] = response data.update({"response": responses}) @@ -125,7 +131,7 @@ def from_scale(cls, prompts=None, scale=None, side_effects=None): prompts_.append({"prompt": p, "labels": scale}) return cls(prompts_, side_effects=side_effects) - def process_l(self, prompts, get_input, multi_turn): + def process_l(self, prompts, 
get_input, multi_turn, datum=None): current_prompt = [] responses = {} data = self.l_args.copy() @@ -142,9 +148,12 @@ def process_l(self, prompts, get_input, multi_turn): ) else: _in_prompt = current_prompt[-1] + "<<" - response = get_input(_in_prompt) + if not datum: + response = get_input(_in_prompt) + else: + response = datum["response"][f"Q{str(idx)}"] current_prompt[-1] += f"<<{response}>>" - responses[f"Q_{str(idx + 1)}"] = response + responses[f"Q{str(idx)}"] = response data.update({"response": responses}) prompts += current_prompt return data, prompts diff --git a/src/sweetbean/stimulus/_Template_.py b/src/sweetbean/stimulus/_Template_.py index 05e3e983..c19f5cce 100644 --- a/src/sweetbean/stimulus/_Template_.py +++ b/src/sweetbean/stimulus/_Template_.py @@ -13,7 +13,7 @@ def __init__(self, args, side_effects=None): """ super().__init__(args, side_effects) - def process_l(self, prompts, get_input, multi_turn): + def process_l(self, prompts, get_input, multi_turn, datum): """ This is used to process the arguments, generate a prompt and get a response from language input