Dialektik prompting
JosefAlbers authored Aug 10, 2024
1 parent 5b6ab2e commit 052d661
Showing 14 changed files with 156 additions and 109 deletions.
65 changes: 5 additions & 60 deletions README.md
@@ -19,7 +19,6 @@ Phi-3-MLX is a versatile AI framework that leverages both the Phi-3-Vision multi
Phi-3-MLX is designed to run on Apple Silicon Macs. The minimum requirements are:

- Apple Silicon Mac (M1, M2, or later)
- macOS 11.0 or later
- 8GB RAM (with quantization using `quantize_model=True` option)

For optimal performance, especially when working with larger models or datasets, we recommend using a Mac with 16GB RAM or more.
@@ -85,7 +84,7 @@ prompts = [
]

# Define constraints for the generated text
constraints = [(0, ' The'), (100, ' The correct answer is'), (1, 'X.')]
constraints = [(0, '\nThe'), (100, ' The correct answer is'), (1, 'X.')]

# Apply constrained beam decoding
results = constrain(prompts, constraints, blind_model=True, quantize_model=True, use_beam=True)
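# A plausible reading of the constraint tuples above (inferred from this example, not stated here):
# each (n, text) pair lets the model generate at most n free tokens and then forces `text` to follow,
# i.e. start the answer with '\nThe', reason for up to 100 tokens before ' The correct answer is',
# and close within one more token ending in 'X.'.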
@@ -127,49 +126,13 @@ generate("Describe the potential applications of CRISPR gene editing in medicine
quantize_model=True,
use_adapter=True)

# Compare LoRA adapters
test_lora(adapter_path=None) # Without LoRA adapter
test_lora(adapter_path=True) # With default LoRA adapter
test_lora(adapter_path="/path/to/your/lora") # With specific adapter
# Test the performance of the trained LoRA adapter
test_lora()
```

![Alt text](https://raw.githubusercontent.com/JosefAlbers/Phi-3-Vision-MLX/main/assets/train_log.png)

## 2. HTTP Model Server

1. Start the server:

```
python server.py
```

2. Send POST requests to `http://localhost:8000/v1/completions` with a JSON body:

```bash
curl -X POST http://localhost:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"prompt": [
"Hello, world!",
"Guten Tag!"
],
"max_tokens": 50
}'
```

3. Receive JSON responses with generated text for each prompt:

```json
{
"model": "phi-3-vision",
"responses": [
"Hello! How can I help you today?<|end|>",
"Guten Tag! Wie kann ich Ihnen helfen?<|end|>"
]
}
```

## 3. Agent Interactions
## 2. Agent Interactions

### Multi-turn Conversation

@@ -218,7 +181,7 @@ agent.end()

![Alt text](https://raw.githubusercontent.com/JosefAlbers/Phi-3-Vision-MLX/main/assets/api_agent.png)

## 4. Custom Toolchains
## 3. Custom Toolchains

### In-Context Learning Agent

@@ -310,24 +273,6 @@ benchmark()

*(On M1 Max 64GB)*

## More Examples

For advanced examples and external library integration, see `examples.py` in the project root. Preview:

```python
# Multimodal Reddit Thread Summarizer
from rd2md import rd2md
from pathlib import Path
import json

filename, contents, images = rd2md()
prompt = 'Write an executive summary of above (max 200 words). The article should capture the diverse range of opinions and key points discussed in the thread, presenting a balanced view of the topic without quoting specific users or comments directly. Focus on organizing the information cohesively, highlighting major arguments, counterarguments, and any emerging consensus or unresolved issues within the community.'
prompts = [f'{s}\n\n{prompt}' for s in contents]
results = [generate(prompts[i], images[i], max_tokens=512, blind_model=False, quantize_model=True, quantize_cache=False, verbose=False) for i in range(len(prompts))]
with open(Path(filename).with_suffix('.json'), 'w') as f:
json.dump({'prompts':prompts, 'images':images, 'results':results}, f, indent=4)
```

## Documentation

API references and additional information are available at:
40 changes: 5 additions & 35 deletions api.py
@@ -3,65 +3,35 @@

from huggingface_hub import InferenceClient

# def mistral_api(prompt, history):
# """
# Example:
# --------
# agent = Agent(toolchain = "responses, history = mistral_api(prompt, history)")
# agent('Write a neurology ICU admission note')
# """
# history = '<s>' if history is None else history
# history += f"[INST] {prompt} [/INST]"
# client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3", token = os.environ.get('HF_READ_TOKEN', False))
# generate_kwargs = dict(
# temperature=0.9,
# max_new_tokens=1024,
# top_p=0.95,
# repetition_penalty=1.0,
# do_sample=True,
# seed=42,
# stream=False,
# details=False,
# # details=True,
# return_full_text=False,
# )
# result = client.text_generation(history, **generate_kwargs)
# result = result.strip()
# # result = result.generated_text.strip() # if details=True
# history += f" {result}</s> "
# print(f'### Prompt ###\n{prompt}\n### Output ###\n{result}')
# return {'responses':result, 'history':history}

def mistral_api(prompt, history, verbose=True, api_model="mistralai/Mistral-Nemo-Instruct-2407"):
def mistral_api(prompt, history, verbose=True, return_dict=True, api_model="mistralai/Mistral-Nemo-Instruct-2407"):
"""
Example:
--------
agent = Agent(toolchain = "responses, history = mistral_api(prompt, history)")
agent('Write a neurology ICU admission note')
"""
# "mistralai/Mistral-Nemo-Instruct-2407" "mistralai/Mistral-7B-Instruct-v0.3"
history = '<s>' if history is None else history
history += f"[INST] {prompt} [/INST]"
client = InferenceClient(api_model, token = os.environ.get('HF_READ_TOKEN', False))
generate_kwargs = dict(
temperature=0.9,
max_new_tokens=1024,
max_new_tokens=8192,
top_p=0.95,
repetition_penalty=1.0,
do_sample=True,
seed=42,
stream=False,
details=False,
# details=True,
return_full_text=False,
)
result = client.text_generation(history, **generate_kwargs)
result = result.strip()
# result = result.generated_text.strip() # if details=True
history += f" {result}</s> "
if verbose:
print(f'### Prompt ###\n{prompt}\n### Output ###\n{result}')
return {'responses':result, 'history':history}
if return_dict:
return {'responses':result, 'history':history}
return result
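# With return_dict=False the call returns just the generated string, so mistral_api can stand in
# for a plain text-in/text-out generator such as pv.generate (assets/dialektik.py relies on this
# via functools.partial in its synthesize() function).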

def bark_api(prompt):
"""
Binary file modified assets/ACB.pdf
Binary file modified assets/agent_toolchain.pdf
Binary file modified assets/dialektik.pdf
128 changes: 128 additions & 0 deletions assets/dialektik.py
@@ -0,0 +1,128 @@
from pathlib import Path
from datasets import load_dataset, concatenate_datasets
import random
import json
import os
from datetime import datetime
from huggingface_hub import InferenceClient
import phi_3_vision_mlx as pv
import mlx.core as mx
from functools import partial
import fire

PATH_DS = 'JosefAlbers/StampyAI-alignment-research-dataset'
PROMPT_THESIS = "Based on the above bullet points, create a detailed and engaging article that explores the main themes and insights. For each bullet point, provide context, elaborate on the key ideas, and discuss their implications. Ensure the article flows logically, connects related concepts, and presents a coherent narrative."
PROMPT_ANTITHESIS = "Read through the article and write a response that challenges its main ideas. Offer different viewpoints, suggest alternative explanations, and propose new approaches. Keep your response well-structured and relevant to the original content."
PROMPT_SYNTHESIS = """You have an initial article and a response to it:
**Article:**
{thesis}
**Response:**
{antithesis}
Create an improved version of the article that incorporates insights from both the original and the response. Address conflicting ideas and present a more comprehensive view. Add new insights based on this broader perspective. Your final article should be clear, balanced, and offer a deeper understanding of the topic."""

def setup(instruction="\n<|end|>\n<|user|>\nTLDR: Summarize the following text into concise, stand-alone bullet points (max 3-5 bullet points). Each bullet point should be self-contained and provide a clear and complete idea without referencing other bullet points or the original text.", list_source=['agentmodels', 'distill', 'arbital', 'blogs', 'lesswrong', 'youtube', 'arxiv', 'special_docs'], quantize_model=False, batch_size=4, path_ds=PATH_DS):
model, processor = pv.load(blind_model=True, quantize_model=quantize_model, quantize_cache=False, use_adapter=False)
def aggregate(example):
str_md = f"# {example['title']}\n\n{example['text']}"
example['str_md'] = str_md
example['len_md'] = processor(str_md)['input_ids'].size
return example
def summarize(example):
markdowns = example['str_md']
prompts = [f'{m}{instruction}' for m in markdowns]
summaries = pv.generate(prompts, preload=(model, processor), stream=False, verbose=False, max_tokens=512)
example['sum_md'] = summaries
return example
list_ds = []
try:
_ds_prev = load_dataset(path_ds, token=os.getenv("HF_WRITE_TOKEN"), split='train')
list_source = [i for i in list_source if i not in _ds_prev['source']]
list_ds.append(_ds_prev)
except:
print('Dataset not found.')
for src in list_source:
ds = load_dataset('StampyAI/alignment-research-dataset', src, trust_remote_code=True, split='train')
ds = ds.select_columns(['id', 'source', 'title', 'text', 'url', 'date_published', 'authors', 'summary', 'source_type'])
ds = ds.map(aggregate)
ds = ds.filter(lambda example: 600 < example["len_md"] < 6000)
if batch_size > 1:
ds = ds.sort('len_md')
ds = ds.map(summarize, batched=True, batch_size=batch_size)
ds = ds.filter(lambda example: ('<unk>' not in example['sum_md']) and ('<|end|>' in example['sum_md']))
list_ds.append(ds)
ds = concatenate_datasets(list_ds)
ds.push_to_hub(path_ds, token=os.getenv("HF_WRITE_TOKEN"), private=True)

def load_books(list_source=None, list_exclude=None, path_ds=PATH_DS):
ds = load_dataset(path_ds, token=os.getenv("HF_READ_TOKEN", None), split='train')
if list_source:
list_source = [list_source] if isinstance(list_source, str) else list_source
ds = ds.filter(lambda example: example['source'] in list_source)
if list_exclude:
list_exclude = [list_exclude] if isinstance(list_exclude, str) else list_exclude
ds = ds.filter(lambda example: not any(word in example['sum_md'] for word in list_exclude))
print(f"Loaded {len(ds)} from {', '.join(set(ds['source']))}")
books = ds['sum_md']
books = [i.split('\n- ') for i in books]
clean_str = lambda s: s[2:] if s.startswith('- ') else s[:-7] if s.endswith('<|end|>') else s
books = [[clean_str(s).strip() for s in book] for book in books]
return books

def pick_books(topic, list_idx, list_books, num_book=3):
if topic is None:
return random.sample(range(len(list_books)), num_book)
list_rand = list_idx if list_idx else random.sample(range(len(list_books)), 100)
list_text = [list_books[i][0] for i in list_rand]
embed = pv.GteModel()
l = embed(list_text)
q = embed(topic)
scores = mx.matmul(q, l.T)
list_idx = mx.argsort(scores)[:,:-1-num_book:-1].tolist()
list_idx = list_idx[0]
return [list_rand[i] for i in list_idx]

def get_bullets(topic='AI agents', list_source=None, list_exclude=['MIRI', 'Machine Intelligence Research Institute'], list_idx=None, num_book=3, per_book=3):
books = load_books(list_source, list_exclude)
list_idx = pick_books(topic, list_idx, books, num_book)
print(f"Picked {list_idx}")
picks = [books[i] for i in list_idx]
bullets = ''
for pick in picks:
pick=pick[:per_book]
bullets += '- ' + '\n - '.join(pick) + '\n'
bullets = bullets.strip()
print(f'Bullets:\n{bullets}')
return bullets, list_idx

def save_output(output, file_suffix=None, base_folder='syntheses'):
file_suffix = f'_{file_suffix}' if file_suffix else ''
os.makedirs(base_folder, exist_ok=True)
date_str = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
filename = os.path.join(base_folder, f'{date_str}{file_suffix}.md')
with open(filename, 'w') as f:
f.write(output)

def synthesize(topic=None, prompt_thesis=PROMPT_THESIS, prompt_antithesis=PROMPT_ANTITHESIS, prompt_synthesis=PROMPT_SYNTHESIS,
list_source=None, list_exclude=['MIRI', 'Machine Intelligence Research Institute'],
list_idx=None, num_book=3, per_book=3, llm_model=None):
if llm_model is None:
preload = pv.load(blind_model=True, quantize_model=True)
generate = partial(pv.generate, preload=preload)
else:
generate = partial(pv.mistral_api, api_model=llm_model, history=None, return_dict=False, verbose=False)
bullets, list_idx = get_bullets(topic, list_source, list_exclude, list_idx, num_book, per_book)
prompt = f"{bullets}\n\n{prompt_thesis}"
thesis_output = generate(prompt)
prompt_anti = f'{thesis_output}\n\n{prompt_antithesis}'
antithesis_output = generate(prompt_anti)
prompt_synth = prompt_synthesis.format(thesis=thesis_output, antithesis=antithesis_output)
synthesis_output = generate(prompt_synth)
all_output = f'Thesis:\n---\n\n{thesis_output}\n\nAntithesis:\n---\n\n{antithesis_output}\n\nSynthesis:\n---\n\n{synthesis_output}\n\nArguments:\n---\n\ndialektik.synthesize({list_source=}, {list_exclude=},{list_idx=}, {per_book=}, {llm_model=})\n\n{bullets}'
save_output(all_output)
return thesis_output, antithesis_output, synthesis_output

if __name__ == "__main__":
fire.Fire(synthesize)
Binary file modified assets/mlx_porting_guide.pdf
File renamed without changes.
2 changes: 1 addition & 1 deletion assets/tutorial_0.md
@@ -46,7 +46,7 @@ We'll start by comparing the original Hugging Face implementation with our MLX p

### 2. Implementing SuRoPE for 128K Context

We'll explore the Surrogate Rotary Position Embedding (SuRoPE) implementation that enables Phi-3-Vision to handle impressive 128K token contexts.
We'll explore the Su-scaled Rotary Position Embedding (SuRoPE) implementation that enables Phi-3-Vision to handle impressive 128K token contexts.

### 3. Optimizing Text Generation in MLX: From Batching to Advanced Techniques

4 changes: 2 additions & 2 deletions assets/tutorial_6.md
@@ -103,7 +103,7 @@ It's especially useful in the context of AI agents and function calling. Constra

In multi-agent systems, constrained decoding maintains consistent interfaces between components, allowing outputs from one model to serve reliably as inputs for another. This consistency is key for building robust, multi-step AI workflows and seamlessly integrating AI-generated code into larger systems.

## 2. Guided Reasoning in Complex Decision-Making
### 2. Guided Reasoning in Complex Decision-Making

Constrained decoding can also guide the model's reasoning process in complex scenarios like medical diagnosis. Let's look at an example:

@@ -154,4 +154,4 @@ This method of constrained decoding is analogous to asking a student to "show th

By implementing constrained decoding in complex decision-making scenarios, we can create more reliable and interpretable AI systems. This is important in high-stakes domains like medical diagnosis, legal reasoning, or financial analysis, where understanding the reasoning behind a decision is as important as the decision itself.

In the next part of our series, we'll explore techniques for fine-tuning our model on custom datasets, allowing us to adapt Phi-3-Vision for specific tasks or domains.
In the next part of our series, we'll explore techniques for fine-tuning our model on custom datasets, allowing us to adapt Phi-3-Vision for specific tasks or domains.
4 changes: 2 additions & 2 deletions assets/tutorial_8.md
@@ -1,4 +1,4 @@
# Part 8: Implementing the Agent Class and Toolchain System
# Part 8: Implementing the Agent and Toolchain System

## Introduction

@@ -23,7 +23,7 @@ class Agent:
self.reset()

def __call__(self, prompt:str, images=None):
# Implementation details
# ...
```

The class is designed with a default toolchain and an initializer that sets up the agent's configuration.
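For a concrete sense of how the agent and toolchain fit together, a minimal sketch based on the `mistral_api` docstring and the `agent.end()` call shown in the README diff above (the `pv.Agent` import path is assumed):

```python
# Minimal Agent usage sketch, mirroring the mistral_api docstring above.
import phi_3_vision_mlx as pv

# Route each turn through the Mistral HTTP API instead of the local model.
agent = pv.Agent(toolchain="responses, history = mistral_api(prompt, history)")
agent('Write a neurology ICU admission note')
agent.end()
```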
4 changes: 3 additions & 1 deletion examples.py
@@ -40,7 +40,9 @@
)

## Test
pv.test_lora()
test_lora(adapter_path=None) # Without LoRA adapter
test_lora(adapter_path=True) # With default LoRA adapter
test_lora(adapter_path="adapters/phi3_mini_128k_Q") # With specific adapter

# Agent
