Skip to content

Commit

Permalink
Merge pull request #57 from AutoResearch/55-doc-add-docstrings-in-src…
Browse files Browse the repository at this point in the history
…-code

docs: add docstrings
  • Loading branch information
younesStrittmatter authored Dec 2, 2024
2 parents 735e2c1 + 02a3a7e commit ad04eb3
Show file tree
Hide file tree
Showing 11 changed files with 152 additions and 1 deletion.
2 changes: 2 additions & 0 deletions mkdocs/gen_ref_pages.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
ignore = [
Path("./src/sweetbean/_const.py"),
Path("./src/sweetbean/util/parse.py"),
Path("./src/sweetbean/stimulus/_Template_.py"),
Path("./src/sweetbean/stimulus/Stimulus.py"),
]

source_paths = sorted(Path("./").rglob("**/src"))
Expand Down
9 changes: 9 additions & 0 deletions src/sweetbean/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,20 @@


class Block:
"""
A block of stimuli (for example, an instruction, training, or test block)
"""

stimuli: List[Any] = []
js = ""
timeline = None

def __init__(self, stimuli, timeline=None):
"""
Arguments:
stimuli: a list of stimuli
timeline: a list of dictionaries with the name of the timeline variables
"""
if timeline is None:
timeline = []
self.stimuli = stimuli
Expand Down
11 changes: 11 additions & 0 deletions src/sweetbean/data.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
def process_js(js_data):
"""
Process the data from a SweetBean generated jsPsych experiment.
"""
res = []
for d in js_data:
res_dict = {}
Expand All @@ -12,6 +15,9 @@ def process_js(js_data):


def get_n_responses(data):
"""
Get the number of responses in the data.
"""
n = 0
for d in data:
if "response" in d and d["response"] is not None:
Expand All @@ -20,6 +26,11 @@ def get_n_responses(data):


def until_response(data, n):
"""
Get the data until the nth response.
(This is helpful, for example, to get the data up a certain point
and then run the rest of the experiment with `run_on_language`)
"""
i = 0
for idx, d in enumerate(data):
if "response" in d and d["response"] is not None:
Expand Down
27 changes: 27 additions & 0 deletions src/sweetbean/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,18 @@


class Experiment:
"""
An experiment consisting of blocks
"""

blocks: List[Block] = []
js = ""

def __init__(self, blocks: List[Block]):
"""
Arguments:
blocks: a list of blocks
"""
self.blocks = blocks

def to_js(self, path_local_download=None):
Expand All @@ -41,6 +49,9 @@ def to_js(self, path_local_download=None):
self.js += ";jsPsych.run(trials)"

def to_html(self, path, path_local_download=None):
"""
Save the experiment to an HTML file
"""
self.to_js(path_local_download)
html = HTML_PREAMBLE
blocks = 0
Expand All @@ -53,6 +64,9 @@ def to_html(self, path, path_local_download=None):
f.write(html)

def to_js_string(self, as_function=True, is_async=True):
"""
Return the experiment as a JavaScript string
"""
text = FUNCTION_PREAMBLE(is_async) if as_function else ""
for b in self.blocks:
b.to_js()
Expand All @@ -75,6 +89,19 @@ def run_on_language(
multi_turn=False,
data=None,
):
"""
Run the experiment in a language
Arguments:
get_input: a function to get input from the response
(for example, a function that prompts language model and returns the response)
multi_turn: a boolean to allow multi-turn input.
If True, the prompts are not concatenated.
data: a list of dictionaries with the data.
This will rerun the experiment with the data as input.
If the data is not provided for the full experiment,
the rest of it will be simulated with the get_input function.
"""
out_data = []
prompts = []
shared_variables = {}
Expand Down
18 changes: 18 additions & 0 deletions src/sweetbean/stimulus/Choice.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@ def __init__(
time_after_response=3000,
side_effects=None,
):
"""
Arguments:
duration (int): The duration of the stimulus
html_array (list): An array of html elements that can be clicked
values (list): An array of values corresponding to the html elements
time_after_response (int): The time after a response is made
(for example, for animations)
side_effects (dict): A dictionary of side effects
"""
if values is None:
values = []
if html_array is None:
Expand Down Expand Up @@ -54,6 +63,15 @@ def __init__(
time_after_response=2000,
side_effects=None,
):
"""
Arguments:
duration (int): The duration of the stimulus
bandits (list): A list of bandits
(in the form of dictionaries with entries for color and value)
time_after_response (int): The time after a response is made
(for example, for animations)
side_effects (dict): A dictionary of side effects
"""
if bandits is None:
bandits = []

Expand Down
12 changes: 12 additions & 0 deletions src/sweetbean/stimulus/HtmlKeyboardResponse.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(
color: the color of the text
choices: the keys that will be recorded if pressed
correct_key: the correct key to press
side_effects: a dictionary of side effects
"""
if choices is None:
choices = []
Expand Down Expand Up @@ -89,6 +90,7 @@ def __init__(
duration: time in ms the stimulus is presented
choices: the keys that will be recorded if pressed
correct_key: the correct key to press
side_effects: a dictionary of side effects
"""
super().__init__(
duration=duration,
Expand All @@ -108,6 +110,7 @@ def __init__(self, duration=None, side_effects=None):
"""
Arguments:
duration: time in ms the stimulus is presented
side_effects: a dictionary of side effects
"""
super().__init__(
duration=duration,
Expand Down Expand Up @@ -137,8 +140,13 @@ def __init__(
"""
Arguments:
duration: time in ms the stimulus is presented
correct_message: the message to show if the response was correct
false_message: the message to show if the response was false
correct_color: the color of the message if the response was correct
false_color: the color of the message if the response was false
window: how far back is the stimulus to check
(that stimulus needs to have a choice and a correct_key parameter)
side_effects: a dictionary of side effects
"""
correct = DataVariable("correct", window)

Expand Down Expand Up @@ -187,6 +195,9 @@ def __init__(
distractor: the direction of the distractor (allowed: left, right, l, r, L, R)
choices: the keys that will be recorded if pressed
correct_key: the correct key to press
color: the color of the text
n_flankers: the number of distractors
side_effects: a dictionary of side effects
"""

def _txt(dr, dst, n):
Expand Down Expand Up @@ -238,6 +249,7 @@ def __init__(
color: the color of the symbol
choices: the keys that will be recorded if pressed
correct_key: the correct key to press
side_effects: a dictionary of side effects
"""

def stim(symbl, clr):
Expand Down
3 changes: 2 additions & 1 deletion src/sweetbean/stimulus/Image.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ def __init__(
"""
Arguments:
duration: time in ms the stimulus is presented
src: the path to the image
stimulus: the path to the image
choices: the keys that will be recorded if pressed
correct_key: the correct key to press
side_effects: a dictionary of side effects
"""
super().__init__(locals(), side_effects)

Expand Down
2 changes: 2 additions & 0 deletions src/sweetbean/stimulus/RO.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(
background_color: the background color
choices: the valid keys that the subject can press to indicate a response
correct_key: the correct key to press
side_effects: a dictionary of side effects
"""
if choices is None:
choices = []
Expand Down Expand Up @@ -104,6 +105,7 @@ def __init__(
background_color: the background color
choices: the valid keys that the subject can press to indicate a response
correct_key: the correct key to press
side_effects: a dictionary of side effects
"""
if choices is None:
choices = []
Expand Down
39 changes: 39 additions & 0 deletions src/sweetbean/stimulus/Survey.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@


class _Survey(_BaseStimulus):
"""
A base class for surveys
"""

def __init__(self, questions=None, side_effects=None):
super().__init__(locals(), side_effects=side_effects)

Expand All @@ -17,9 +21,18 @@ def _set_before(self):


class TextSurvey(_Survey):
"""
A survey that asks for text input
"""

type = "jsPsychSurveyText"

def __init__(self, questions=None, side_effects=None):
"""
Arguments:
questions: a list of strings representing the questions
side_effects: a dictionary of side effects
"""
if not questions:
questions = []

Expand Down Expand Up @@ -58,9 +71,18 @@ def process_l(self, prompts, get_input, multi_turn, datum=None):


class MultiChoiceSurvey(_Survey):
"""
A survey that asks for multiple choice input
"""

type = "jsPsychSurveyMultiChoice"

def __init__(self, questions=None, side_effects=None):
"""
Arguments:
questions: a list of dictionaries with the keys "prompt" and "options"
side_effects: a dictionary of side effects
"""
if not questions:
questions = []

Expand Down Expand Up @@ -104,9 +126,18 @@ def process_l(self, prompts, get_input, multi_turn, datum=None):
#
#
class LikertSurvey(_Survey):
"""
A survey that asks for Likert scale input
"""

type = "jsPsychSurveyLikert"

def __init__(self, questions=None, side_effects=None):
"""
Arguments:
questions: a list of dictionaries with the keys "prompt" and "labels"
side_effects: a dictionary of side effects
"""
if not questions:
questions = []

Expand All @@ -122,6 +153,14 @@ def get_prompts(_prompts):
#
@classmethod
def from_scale(cls, prompts=None, scale=None, side_effects=None):
"""
Create a LikertSurvey from a scale
Arguments:
prompts: a list of strings representing the prompts
scale: a list of strings representing the scale
side_effects: a dictionary of side effects
"""
if not prompts:
prompts = []
if not scale:
Expand Down
1 change: 1 addition & 0 deletions src/sweetbean/stimulus/Video.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def __init__(
src: the path to the image
choices: the keys that will be recorded if pressed
correct_key: the correct key to press
side_effects: a dictionary of side effects
"""
super().__init__(locals(), side_effects=side_effects)

Expand Down
29 changes: 29 additions & 0 deletions src/sweetbean/variable/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@


class Variable(ABC):
"""
A base class for variables.
"""

def __init__(self, name):
self.name = name

Expand All @@ -27,6 +31,11 @@ class DataVariable(Variable):
"""

def __init__(self, name, window):
"""
Arguments:
name: the name of the variable
window: the window of the data
"""
super().__init__(name)
self.name = f'data["bean_{name}"]'
self.raw_name = name
Expand All @@ -50,6 +59,12 @@ class FunctionVariable(Variable):
"""

def __init__(self, name, fct, args):
"""
Arguments:
name: the name of the variable
fct: the function
args: the arguments of the function
"""
super().__init__(name)
self.fct = fct
self.args = args
Expand All @@ -66,6 +81,11 @@ class CodeVariable(Variable):
"""

def __init__(self, name, value):
"""
Arguments:
name: the name of the variable
value: the initial value of the variable
"""
super().__init__(name)
self.value = value

Expand All @@ -82,6 +102,11 @@ class SharedVariable:
"""

def __init__(self, name, value):
"""
Arguments:
name: the name of the variable
value: the initial value of the variable
"""
self.name = str(name)
self.value = value

Expand All @@ -96,6 +121,10 @@ class SideEffect:
def __init__(self, set_variable, get_variable):
"""
A side effect that can set variables.
Arguments:
set_variable: the variable to set (often a SharedVariable)
get_variable: the variable to get (e.g, the variable the set variable should be set to)
"""
self.set_variable = set_variable
self.get_variable = get_variable
Expand Down

0 comments on commit ad04eb3

Please sign in to comment.