Skip to content

Commit

Permalink
feat(diffusers): allow multiple lora adapters (#4081)
Browse files Browse the repository at this point in the history
Signed-off-by: Ettore Di Giacinto <[email protected]>
  • Loading branch information
mudler authored Nov 5, 2024
1 parent 20cd881 commit 947224b
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 18 deletions.
3 changes: 3 additions & 0 deletions backend/backend.proto
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,9 @@ message ModelOptions {
bool NoKVOffload = 57;

string ModelPath = 59;

repeated string LoraAdapters = 60;
repeated float LoraScales = 61;
}

message Result {
Expand Down
16 changes: 15 additions & 1 deletion backend/python/diffusers/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,10 +311,24 @@ def LoadModel(self, request, context):
if request.LoraAdapter:
# Check if its a local file and not a directory ( we load lora differently for a safetensor file )
if os.path.exists(request.LoraAdapter) and not os.path.isdir(request.LoraAdapter):
# self.load_lora_weights(request.LoraAdapter, 1, device, torchType)
self.pipe.load_lora_weights(request.LoraAdapter)
else:
self.pipe.unet.load_attn_procs(request.LoraAdapter)
if len(request.LoraAdapters) > 0:
i = 0
adapters_name = []
adapters_weights = []
for adapter in request.LoraAdapters:
if not os.path.isabs(adapter):
adapter = os.path.join(request.ModelPath, adapter)
self.pipe.load_lora_weights(adapter, adapter_name=f"adapter_{i}")
adapters_name.append(f"adapter_{i}")
i += 1

for adapters_weight in request.LoraScales:
adapters_weights.append(adapters_weight)

self.pipe.set_adapters(adapters_name, adapter_weights=adapters_weights)

if request.CUDA:
self.pipe.to('cuda')
Expand Down
2 changes: 2 additions & 0 deletions core/backend/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions {
CFGScale: c.Diffusers.CFGScale,
LoraAdapter: c.LoraAdapter,
LoraScale: c.LoraScale,
LoraAdapters: c.LoraAdapters,
LoraScales: c.LoraScales,
F16Memory: f16,
LoraBase: c.LoraBase,
IMG2IMG: c.Diffusers.IMG2IMG,
Expand Down
36 changes: 19 additions & 17 deletions core/config/backend_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,23 +134,25 @@ type LLMConfig struct {
TrimSpace []string `yaml:"trimspace"`
TrimSuffix []string `yaml:"trimsuffix"`

ContextSize *int `yaml:"context_size"`
NUMA bool `yaml:"numa"`
LoraAdapter string `yaml:"lora_adapter"`
LoraBase string `yaml:"lora_base"`
LoraScale float32 `yaml:"lora_scale"`
NoMulMatQ bool `yaml:"no_mulmatq"`
DraftModel string `yaml:"draft_model"`
NDraft int32 `yaml:"n_draft"`
Quantization string `yaml:"quantization"`
LoadFormat string `yaml:"load_format"`
GPUMemoryUtilization float32 `yaml:"gpu_memory_utilization"` // vLLM
TrustRemoteCode bool `yaml:"trust_remote_code"` // vLLM
EnforceEager bool `yaml:"enforce_eager"` // vLLM
SwapSpace int `yaml:"swap_space"` // vLLM
MaxModelLen int `yaml:"max_model_len"` // vLLM
TensorParallelSize int `yaml:"tensor_parallel_size"` // vLLM
MMProj string `yaml:"mmproj"`
ContextSize *int `yaml:"context_size"`
NUMA bool `yaml:"numa"`
LoraAdapter string `yaml:"lora_adapter"`
LoraBase string `yaml:"lora_base"`
LoraAdapters []string `yaml:"lora_adapters"`
LoraScales []float32 `yaml:"lora_scales"`
LoraScale float32 `yaml:"lora_scale"`
NoMulMatQ bool `yaml:"no_mulmatq"`
DraftModel string `yaml:"draft_model"`
NDraft int32 `yaml:"n_draft"`
Quantization string `yaml:"quantization"`
LoadFormat string `yaml:"load_format"`
GPUMemoryUtilization float32 `yaml:"gpu_memory_utilization"` // vLLM
TrustRemoteCode bool `yaml:"trust_remote_code"` // vLLM
EnforceEager bool `yaml:"enforce_eager"` // vLLM
SwapSpace int `yaml:"swap_space"` // vLLM
MaxModelLen int `yaml:"max_model_len"` // vLLM
TensorParallelSize int `yaml:"tensor_parallel_size"` // vLLM
MMProj string `yaml:"mmproj"`

FlashAttention bool `yaml:"flash_attention"`
NoKVOffloading bool `yaml:"no_kv_offloading"`
Expand Down

0 comments on commit 947224b

Please sign in to comment.