
Commit

dry sampling added
mgonzs13 committed Oct 26, 2024
1 parent 2c1831b commit 4f89a52
Showing 2 changed files with 51 additions and 29 deletions.
llama_msgs/msg/SamplingConfig.msg (62 changes: 34 additions & 28 deletions)
@@ -1,37 +1,43 @@
 int32 n_prev 64                # number of previous tokens to remember
 int32 n_probs 1                # if greater than 0, output the probabilities of top n_probs tokens
 int32 min_keep 0               # 0 = disabled, otherwise samplers should return at least min_keep tokens

 bool ignore_eos false          # ignore end of stream token and continue generating (implies --logit-bias 2-inf)
 LogitBiasArray logit_bias      # logit bias for specific tokens

 float32 temp 0.80              # temperature
 float32 dynatemp_range 0.0     # 0.0 = disabled
 float32 dynatemp_exponent 1.0  # controls how entropy maps to temperature in dynamic temperature sampler

 int32 top_k 40                 # top-k sampling (0.0 = disabled)
 float32 top_p 0.95             # top-p sampling (1.0 = disabled)
+float32 min_p 0.05             # min-p sampling (0.0 = disabled)
 float32 xtc_probability 0.00   # xtc sampling (0.0 = disable)
 float32 xtc_threshold 0.10     # xtc sampling threshold (> 0.5 disables XTC)
-float32 min_p 0.05             # min-p sampling (0.0 = disabled)
 float32 tfs_z 1.00             # tail free sampling, parameter z (1.0 = disabled)
 float32 typical_p 1.00         # locally typical sampling, parameter p (1.0 = disabled)

 int32 penalty_last_n 64        # last n tokens consider for penalize (0 = disable penalty, -1 = context size)
 float32 penalty_repeat 1.00    # penalize repeat sequence of tokens (1.0 = disabled)
 float32 penalty_freq 0.00      # repeat alpha frequency penalty (0.0 = disable)
 float32 penalty_present 0.00   # repeat alpha presence penalty (0.0 = disabled)

+float32 dry_multiplier 0.0     # DRY repetition penalty for tokens extending repetition (0.0 = disabled)
+float32 dry_base 1.75          # multiplier * base ^ (length of sequence before token - allowed length) (0.0 = disabled)
+int32 dry_allowed_length 2     # tokens extending repetitions beyond this receive penalty
+int32 dry_penalty_last_n -1    # how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
+string[] dry_sequence_breakers ["\\n", ":", "\\\"", "*"] # default sequence breakers for DRY
+
 int32 mirostat 0               # Mirostart sampling (0 = disabled, 1 = mirostat, 2 = mirostat 2.0)
 float32 mirostat_eta 0.10      # Mirostat learning rate, parameter eta
 float32 mirostat_tau 5.0       # Mirostat target entropy, parameter tau

 bool penalize_nl false         # consider newlines as a repeatable token

-string samplers_sequence "kfypmxt"  # TOP_K, TFS_Z, TYPICAL_P, TOP_P, MIN_P, XTC, TEMP
+string samplers_sequence "dkfypmxt" # TOP_K, TFS_Z, TYPICAL_P, TOP_P, MIN_P, XTC, TEMP

 string grammar ""              # optional BNF-like grammar to constrain sampling
 string grammar_schema ""       # grammar schema that defines a JSON BNF grammar

 int32[] penalty_prompt_tokens            # list of tokens to penalize
 bool use_penalty_prompt_tokens false     # whether to penalize tokens
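The dry_base comment above encodes how the DRY penalty grows: a token that would extend a repeated sequence is penalized by multiplier * base ^ (length of the repeated sequence before the token - allowed length). The sketch below only illustrates that arithmetic with hypothetical values (dry_multiplier set to 0.8 instead of the disabled default 0.0); it is not the sampler implementation in llama.cpp or llama_ros.

# Illustrative arithmetic for the dry_* fields above; the values are assumptions,
# not the message defaults (dry_multiplier defaults to 0.0, i.e. DRY disabled).
def dry_penalty(seq_len: int, multiplier: float = 0.8, base: float = 1.75,
                allowed_length: int = 2) -> float:
    """Penalty for a token that would extend a repetition of seq_len tokens."""
    if multiplier <= 0.0 or seq_len < allowed_length:
        return 0.0  # DRY disabled, or repetition still within the allowed length
    return multiplier * base ** (seq_len - allowed_length)

# The penalty grows geometrically once the allowed length is reached:
# seq_len = 2, 3, 4, 5 -> 0.8, 1.4, 2.45, 4.29
for seq_len in range(2, 6):
    print(seq_len, round(dry_penalty(seq_len), 2))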
llama_ros/llama_ros/langchain/llama_ros_common.py (18 changes: 17 additions & 1 deletion)
@@ -56,6 +56,8 @@ class LlamaROSCommon(BaseLanguageModel, ABC):
     top_k: int = 40
     top_p: float = 0.95
     min_p: float = 0.05
+    xtc_probability: float = 0.0
+    xtc_threshold: float = 0.1
     tfs_z: float = 1.00
     typical_p: float = 1.00

@@ -64,13 +66,19 @@
     penalty_freq: float = 0.00
     penalty_present: float = 0.00

+    dry_multiplier: float = 0.0
+    dry_base: float = 1.75
+    dry_allowed_length: int = 2
+    dry_penalty_last_n: int = -1
+    dry_sequence_breakers: List[str] = ["\\n", ":", "\\\"", "*"]
+
     mirostat: int = 0
     mirostat_eta: float = 0.10
     mirostat_tau: float = 5.0

     penalize_nl: bool = False

-    samplers_sequence: str = "kfypmt"
+    samplers_sequence: str = "dkfypmxt"

     grammar: str = ""
     grammar_schema: str = ""

@@ -140,6 +148,8 @@ def _create_action_goal(
         goal.sampling_config.top_k = self.top_k
         goal.sampling_config.top_p = self.top_p
         goal.sampling_config.min_p = self.min_p
+        goal.sampling_config.xtc_probability = self.xtc_probability
+        goal.sampling_config.xtc_threshold = self.xtc_threshold
         goal.sampling_config.tfs_z = self.tfs_z
         goal.sampling_config.typical_p = self.typical_p

@@ -148,6 +158,12 @@
         goal.sampling_config.penalty_freq = self.penalty_freq
         goal.sampling_config.penalty_present = self.penalty_present

+        goal.sampling_config.dry_multiplier = self.dry_multiplier
+        goal.sampling_config.dry_base = self.dry_base
+        goal.sampling_config.dry_allowed_length = self.dry_allowed_length
+        goal.sampling_config.dry_penalty_last_n = self.dry_penalty_last_n
+        goal.sampling_config.dry_sequence_breakers = self.dry_sequence_breakers
+
         goal.sampling_config.mirostat = self.mirostat
         goal.sampling_config.mirostat_eta = self.mirostat_eta
         goal.sampling_config.mirostat_tau = self.mirostat_tau
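For context, a hedged usage sketch of the new fields from the LangChain side. The import path, constructor-style configuration, and prompt below are assumptions for illustration (they follow the repository's llama_ros.langchain module layout but are not part of this diff), and a running llama_ros server node is assumed.

# Usage sketch (assumptions: the LlamaROS LangChain wrapper is importable as
# shown, accepts these pydantic fields as constructor kwargs, and a llama_ros
# server node is already running).
import rclpy
from llama_ros.langchain import LlamaROS

rclpy.init()
llm = LlamaROS(
    temp=0.8,
    # XTC fields wired through in this commit
    xtc_probability=0.5,
    xtc_threshold=0.1,
    # DRY fields wired through in this commit; a multiplier > 0.0 enables DRY
    dry_multiplier=0.8,
    dry_base=1.75,
    dry_allowed_length=2,
    dry_penalty_last_n=-1,
    # the 'd' entry is presumably the DRY step added by this commit
    samplers_sequence="dkfypmxt",
)
print(llm.invoke("Write one sentence about DRY sampling."))
rclpy.shutdown()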
