Skip to content

Commit

Permalink
Merge pull request #276 from griptape-ai:131-add-min_p-and-top_k-for-…
Browse files Browse the repository at this point in the history
…lmstudio-and-ollama

131-add-min_p-and-top_k-for-lmstudio-and-ollama
  • Loading branch information
shhlife authored Feb 27, 2025
2 parents a591cad + 597ca8b commit 2a11d50
Show file tree
Hide file tree
Showing 16 changed files with 84 additions and 10 deletions.
16 changes: 16 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,22 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
### Security -->

## [2.2.11] - 2025-02-28
### Added
- Added `min_p` and/or `top_k` to prompt drivers that support them.

- Top-k: Controls the variety of words the AI can choose from. A lower number (like 10) makes responses more focused and predictable, while a higher number (like 50) allows for more creativity and surprise.

- Min-p: Sets a quality threshold for word choices. Only words with at least a certain percentage of confidence compared to the best option are considered. Lower values (like 0.05) allow more variety, while higher values (like 0.3) stick closer to the most predictable options.

For example, if you ask the question "Give me one unexpected use for a paperclip":
- With low top_k/high min_p: You'll get common answers (holding papers, makeshift hook)

- With high top_k/low min_p: You'll get more creative or unusual answers (lockpick, tiny sculpture material)

ℹ️ Note: Some models take `top_p` instead of `min_p`. To keep the parameters consistent across nodes, I'm still exposing `min_p` in the node, but setting `top_p` to `1 - min_p`, so it _acts_ like `top_p`.


## [2.2.10] - 2025-02-27
### Added
- Added ability to use key value pair replacement with Griptape Cloud Assistant as well.
Expand Down
3 changes: 2 additions & 1 deletion nodes/drivers/gtUIAmazonBedrockPromptDriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ def build_params(self, **kwargs):
params["aws_secret_access_key"] = secret_access_key
if api_key:
params["aws_access_key_id"] = api_key

params["min_p"] = kwargs.get("min_p", 0.1)
params["top_k"] = kwargs.get("top_k", 40)
return params

def create(self, **kwargs):
Expand Down
3 changes: 3 additions & 0 deletions nodes/drivers/gtUIAmazonSageMakerJumpstartPromptDriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ def INPUT_TYPES(cls):
inputs = super().INPUT_TYPES()

inputs["required"].update()
del inputs["optional"]["min_p"]
del inputs["optional"]["top_k"]
inputs["optional"].update(
{
"model": (
Expand Down Expand Up @@ -100,6 +102,7 @@ def create(self, **kwargs):
params["endpoint"] = endpoint
if use_native_tools:
params["use_native_tools"] = use_native_tools

try:
driver = AmazonSageMakerJumpstartPromptDriver(**params)
return (driver,)
Expand Down
2 changes: 1 addition & 1 deletion nodes/drivers/gtUIAnthropicPromptDriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def INPUT_TYPES(cls):

# Add the optional inputs
inputs["optional"].update(base_optional_inputs)

del inputs["optional"]["min_p"]
# Set model default
inputs["optional"]["model"] = (
models,
Expand Down
6 changes: 6 additions & 0 deletions nodes/drivers/gtUIAzureOpenAiChatPromptDriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ def INPUT_TYPES(cls):

# Add the optional inputs
inputs["optional"].update(base_optional_inputs)
del inputs["optional"]["min_p"]
del inputs["optional"]["top_k"]
inputs["optional"].update(
{
"model": (models, {"default": models[0]}),
Expand Down Expand Up @@ -78,6 +80,10 @@ def build_params(self, **kwargs):
"temperature": temperature,
"max_attempts": max_attempts_on_fail,
"use_native_tools": use_native_tools,
# "extra_params": {
# # "min_p": min_p,
# "top_k": top_k,
# },
}
if response_format == "json_object":
response_format = {"type": "json_object"}
Expand Down
17 changes: 17 additions & 0 deletions nodes/drivers/gtUIBasePromptDriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,23 @@ def INPUT_TYPES(cls):
"tooltip": "Maximum tokens to generate. If <=0, it will use the default based on the tokenizer.",
},
),
"min_p": (
"FLOAT",
{
"default": 0.1,
"min": 0.0,
"max": 1.0,
"step": 0.01,
"tooltip": "Minimum probability for sampling. Lower values will be more random.",
},
),
"top_k": (
"INT",
{
"default": 40,
"tooltip": "Top k for sampling. Lower values are more deterministic.",
},
),
},
)
return inputs
Expand Down
6 changes: 6 additions & 0 deletions nodes/drivers/gtUICoherePromptDriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,17 @@ def build_params(self, **kwargs):
max_attempts = kwargs.get("max_attempts_on_fail", None)
use_native_tools = kwargs.get("use_native_tools", False)
max_tokens = kwargs.get("max_tokens", None)
min_p = kwargs.get("min_p", None)
top_k = kwargs.get("top_k", None)
params = {
"api_key": api_key,
"model": model,
"max_attempts": max_attempts,
"use_native_tools": use_native_tools,
"extra_params": {
"p": 1 - min_p,
"k": top_k,
},
}
if max_tokens > 0:
params["max_tokens"] = max_tokens
Expand Down
2 changes: 1 addition & 1 deletion nodes/drivers/gtUIGooglePromptDriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class gtUIGooglePromptDriver(gtUIBasePromptDriver):
@classmethod
def INPUT_TYPES(cls):
inputs = super().INPUT_TYPES()

del inputs["optional"]["min_p"]
inputs["optional"].update(
{
"model": (models, {"default": models[0]}),
Expand Down
3 changes: 2 additions & 1 deletion nodes/drivers/gtUIGriptapeCloudPromptDriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def INPUT_TYPES(cls):

# Add the optional inputs
inputs["optional"].update(base_optional_inputs)

del inputs["optional"]["top_k"]
# Set model default
inputs["optional"]["model"] = (models, {"default": DEFAULT_MODEL})
inputs["optional"].update(
Expand Down Expand Up @@ -83,6 +83,7 @@ def build_params(self, **kwargs):
params["use_native_tools"] = use_native_tools
if max_tokens > 0:
params["max_tokens"] = max_tokens
params["extra_params"]["top_p"] = 1 - kwargs.get("min_p", None)

return params

Expand Down
3 changes: 3 additions & 0 deletions nodes/drivers/gtUIGrokPromptDriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ def INPUT_TYPES(cls):
# Add the optional inputs
inputs["optional"].update(base_optional_inputs)

del inputs["optional"]["top_k"]

# Set model default
inputs["optional"]["model"] = (models, {"default": DEFAULT_MODEL})
inputs["optional"].update(
Expand Down Expand Up @@ -83,6 +85,7 @@ def build_params(self, **kwargs):
params["use_native_tools"] = use_native_tools
if max_tokens > 0:
params["max_tokens"] = max_tokens
params["extra_params"]["top_p"] = 1 - kwargs.get("min_p", None)

return params

Expand Down
4 changes: 4 additions & 0 deletions nodes/drivers/gtUIGroqChatPromptDriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ def build_params(self, **kwargs):
"temperature": temperature,
"use_native_tools": use_native_tools,
"max_attempts": max_attempts,
"modalities": [],
"extra_params": {
"top_p": 1 - kwargs.get("min_p", None),
},
}
if response_format == "json_object":
params["response_format"] = {"type": response_format}
Expand Down
2 changes: 2 additions & 0 deletions nodes/drivers/gtUIHuggingFaceHubPromptDriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ def INPUT_TYPES(cls):
inputs = super().INPUT_TYPES()

inputs["required"].update({})
del inputs["optional"]["min_p"]
del inputs["optional"]["top_k"]
inputs["optional"].update(
{
"model": (
Expand Down
6 changes: 5 additions & 1 deletion nodes/drivers/gtUILMStudioChatPromptDriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,18 @@ def build_params(self, **kwargs):
max_attempts = kwargs.get("max_attempts_on_fail", None)
use_native_tools = kwargs.get("use_native_tools", False)
max_tokens = kwargs.get("max_tokens", None)

top_p = 1 - kwargs.get("min_p", None)
params = {
"model": model,
"base_url": f"{base_url}:{port}/v1",
"api_key": api_key,
"temperature": temperature,
"use_native_tools": use_native_tools,
"max_attempts": max_attempts,
"extra_params": {
"top_p": top_p,
# "top_k": top_k,
},
}
if response_format == "json_object":
params["response_format"] = {"type": response_format}
Expand Down
11 changes: 9 additions & 2 deletions nodes/drivers/gtUIOllamaPromptDriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ def build_params(self, **kwargs):
use_native_tools = kwargs.get("use_native_tools", False)
max_tokens = kwargs.get("max_tokens", None)
keep_alive = kwargs.get("keep_alive", 240)

min_p = kwargs.get("min_p")
top_k = kwargs.get("top_k")
params = {
"model": model,
"temperature": temperature,
Expand All @@ -96,7 +97,13 @@ def build_params(self, **kwargs):
params["host"] = f"{base_url}:{port}"
if max_tokens > 0:
params["max_tokens"] = max_tokens
params["extra_params"] = {"keep_alive": int(keep_alive)}
params["extra_params"] = {
"keep_alive": int(keep_alive),
"options": {
"min_p": min_p,
"top_k": top_k,
},
}
return params

def create(self, **kwargs):
Expand Down
6 changes: 5 additions & 1 deletion nodes/drivers/gtUIOpenAiChatPromptDriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ def INPUT_TYPES(cls):
# Add the optional inputs
inputs["optional"].update(base_optional_inputs)

del inputs["optional"]["top_k"]

# Set model default
inputs["optional"]["model"] = (models, {"default": DEFAULT_MODEL})
inputs["optional"].update(
Expand Down Expand Up @@ -85,7 +87,9 @@ def build_params(self, **kwargs):
params["use_native_tools"] = use_native_tools
if max_tokens > 0:
params["max_tokens"] = max_tokens

params["extra_params"] = {
"top_p": 1 - kwargs.get("min_p", None),
}
return params

def create(self, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "comfyui-griptape"
version = "2.2.10"
version = "2.2.11"
description = "Griptape LLM(Large Language Model) Nodes for ComfyUI."
authors = ["Jason Schleifer <[email protected]>"]
readme = "README.md"
Expand All @@ -9,7 +9,7 @@ readme = "README.md"
[project]
name = "comfyui-griptape"
description = "Griptape LLM(Large Language Model) Nodes for ComfyUI."
version = "2.2.10"
version = "2.2.11"
license = {file = "LICENSE"}
dependencies = ["attrs>=24.3.0,<26.0.0", "openai>=1.58.1,<2.0.0", "griptape[all]>=1.4.0", "python-dotenv", "poetry==1.8.5", "griptape-black-forest @ git+https://github.com/griptape-ai/griptape-black-forest.git", "griptape_serper_driver_extension @ git+https://github.com/mertdeveci5/griptape-serper-driver-extension.git"]

Expand Down

0 comments on commit 2a11d50

Please sign in to comment.