
revisiting the load/save state #28

Open
HarvieKrumpet opened this issue Nov 13, 2024 · 2 comments
Labels
enhancement New feature or request

Comments

@HarvieKrumpet

I am a bit of a broken record on this. I have indeed tried a few times to do this on my own within your framework. I got this far with it.

I get stuck because this is not all that is needed for a true state capture. I am supposed to be saving the

llama_decode(ctx, batch);
n_past += batch.n_tokens;

I would be almost duplicating your inner loop to get at these values, and I'm sure you could stuff away those n_past values in an array should I want to save the state without doing this as a second pass. And you're using functions that I have not seen other implementations use, so at my level I am unsure what I am doing here.
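
To put it concretely, the bookkeeping I mean is roughly the following. This is just a sketch, not your actual loop; tokenChunks, BuildBatch, and evaluatedTokens are illustrative names and don't exist in your code:

// Rough sketch only: keep the evaluated tokens and a running n_past alongside
// each llama_decode call, so they can be persisted together with the context state.
var evaluatedTokens = new List<int>();
int n_past = 0;

foreach (var chunk in tokenChunks) {       // tokenChunks: prompt tokens split into batches (illustrative)
    var batch = BuildBatch(chunk);         // BuildBatch: stand-in for however the llama_batch gets filled
    llama_decode(ctx, batch);
    n_past += batch.n_tokens;
    evaluatedTokens.AddRange(chunk);       // remember exactly what has been evaluated
}
// n_past and evaluatedTokens are what a state save would also need to capture.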

Maybe you're at a comfortable point here to want to tackle it. I know you're expecting Gerganov to roll this into some unified system, but I am not holding my breath for it. I use up nearly 4k tokens for just my system messages nowadays, so I have a large inherent delay that I want to cache away with the load/save state system. I am not even sure if I implemented the pure load/save state properly.

Thanks,

public void LoadState(User _User, string _Name) {
    if (_User?.Mo?.ModelName is string Model) {
        LLM? llm = GetModel(Model);
        if (llm != null) {
            // Read the saved state into a buffer and pass it to llama_state_set_data (the memory-mapped variant is left commented out below)
            string filename = IAmandaServer.StateBase + _Name + ".State";
            if (File.Exists(filename)) {
                //using var file = MemoryMappedFile.CreateFromFile(filename, FileMode.Open, null);
                //using var view = file.CreateViewAccessor();
                unsafe {
                    //byte* ptr = null;
                    //view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr);
                    try {
                        byte[] State = File.ReadAllBytes(filename);
                        Native.llama_state_set_data(llm.Engine.ContextNativeHandle, State, (nuint)State.Length);
                    } finally {
                        //view.SafeMemoryMappedViewHandle.ReleasePointer();
                    }
                }
            }
        }
    }
}

public void SaveState(User _User, string _Name) {
    if (_User?.Mo?.ModelName is string Model) {
        LLM? llm = GetModel(Model);
        if (llm != null) {
            nint x = llm.Engine.ContextNativeHandle;
            // Query the size of the serialized context state and copy it into a managed buffer
            uint StateSize = (uint)Native.llama_state_get_size(x);
            byte[] data = new byte[StateSize];
            Native.llama_state_get_data(x, data, StateSize);
            string filename = IAmandaServer.StateBase + _Name + ".State";
            File.WriteAllBytes(filename, data);
        }
    }
}
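
The piece I believe is still missing from the pair above is getting n_past back after a load. Purely as a sketch, I picture something like the following next to LoadState; it assumes a companion ".Tokens" file was written beside the ".State" file at save time, and RestoreEvaluatedTokens is a made-up hook standing in for however the engine would take the token list back:

// Sketch only: restore the context state and the token bookkeeping together.
// RestoreEvaluatedTokens is hypothetical and not part of the current engine.
public void LoadStateWithTokens(User _User, string _Name) {
    if (_User?.Mo?.ModelName is string Model && GetModel(Model) is LLM llm) {
        string stateFile = IAmandaServer.StateBase + _Name + ".State";
        string tokenFile = IAmandaServer.StateBase + _Name + ".Tokens";
        if (!File.Exists(stateFile) || !File.Exists(tokenFile))
            return;

        // Restore the raw context state, same call as in LoadState above
        byte[] state = File.ReadAllBytes(stateFile);
        Native.llama_state_set_data(llm.Engine.ContextNativeHandle, state, (nuint)state.Length);

        // Restore the evaluated tokens so the engine can rebuild n_past
        byte[] tokenBytes = File.ReadAllBytes(tokenFile);
        int[] tokens = new int[tokenBytes.Length / sizeof(int)];
        Buffer.BlockCopy(tokenBytes, 0, tokens, 0, tokenBytes.Length);
        llm.Engine.RestoreEvaluatedTokens(tokens);   // hypothetical: n_past = tokens.Length
    }
}
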
@dranger003
Owner

Ah, I see what you mean. For now, I added a raw sample to test state load/save. I will need to think about this; maybe I'll introduce a new LlmPrompt ctor signature to support state loading - not sure yet.

@dranger003 dranger003 added the enhancement New feature or request label Nov 13, 2024
@HarvieKrumpet
Author

Here is how I envision getting this to work: an extra optional field in the LlmPrompt that triggers a state save and also tells your run thread the basename and path to save it to. This can be slipped in without breaking any client code, nor does it hurt the asynchronous nature of the server thread.

... Client side ...
public class LlmPrompt {
  public string? saveStatePath;   // Path plus basename for a state save
  // Just an extra optional parameter on your existing constructor
  public LlmPrompt(List<LlmMessage> messages, SamplingOptions samplingOptions, string? savePath = null)
  {
    // ... existing assignments for messages and samplingOptions ...
    this.saveStatePath = savePath;
  }
}
  

... Server side, in LLMEngine ...
private unsafe void _Run() {
    ...
    var prompt = _prompts.Dequeue(cancellationToken);
    var ctxLength = (int)llama_n_ctx(_context.Handle);
    var tokens = Tokenize(prompt.Messages, llama_add_bos_token(_model.Handle) > 0, true);

    if (prompt.saveStatePath != null) { // If a path is defined, save the state and the last n tokens
        string llmstate = prompt.saveStatePath + ".state";
        string llmtokens = prompt.saveStatePath + ".tokens";

        // ... State Save ...
        // ... Token Save ...
    }
    ...

    // To restore a state, the client side could issue a simple operation with just the same base path name.
}
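
For the two elided steps, something along these lines is what I picture. Again just a sketch: it reuses the same Native state calls as my SaveState in the earlier comment, and assumes Tokenize returns a List<int> of the prompt tokens:

// Sketch of the "State Save" / "Token Save" placeholders, not library code.
nint ctx = _context.Handle;

// State save: dump the whole context state to <basename>.state
uint stateSize = (uint)Native.llama_state_get_size(ctx);
byte[] stateData = new byte[stateSize];
Native.llama_state_get_data(ctx, stateData, stateSize);
File.WriteAllBytes(llmstate, stateData);

// Token save: record which tokens produced this state, so a later restore
// knows what has already been evaluated (i.e. what n_past should be).
byte[] tokenBytes = new byte[tokens.Count * sizeof(int)];
Buffer.BlockCopy(tokens.ToArray(), 0, tokenBytes, 0, tokenBytes.Length);
File.WriteAllBytes(llmtokens, tokenBytes);

Restoring would then be the reverse: read both files back, call Native.llama_state_set_data, and seed n_past from the token count before continuing the conversation.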
