Skip to content

Commit

Permalink
feat(silero): add Silero-vad backend (#4204)
Browse files Browse the repository at this point in the history
* feat(vad): add silero-vad backend (WIP)

Signed-off-by: Ettore Di Giacinto <[email protected]>

* feat(vad): add API endpoint

Signed-off-by: Ettore Di Giacinto <[email protected]>

* fix(vad): correctly place the onnxruntime libs

Signed-off-by: Ettore Di Giacinto <[email protected]>

* chore(vad): hook silero-vad to binary and container builds

Signed-off-by: Ettore Di Giacinto <[email protected]>

* feat(gRPC): register VAD Server

Signed-off-by: Ettore Di Giacinto <[email protected]>

* fix(Makefile): consume ONNX_OS consistently

Signed-off-by: Ettore Di Giacinto <[email protected]>

* fix(Makefile): handle macOS

Signed-off-by: Ettore Di Giacinto <[email protected]>

---------

Signed-off-by: Ettore Di Giacinto <[email protected]>
Signed-off-by: Ettore Di Giacinto <[email protected]>
  • Loading branch information
mudler authored Nov 20, 2024
1 parent 9892d7d commit b1ea931
Show file tree
Hide file tree
Showing 15 changed files with 255 additions and 1 deletion.
39 changes: 39 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ STABLEDIFFUSION_VERSION?=4a3cd6aeae6f66ee57eae9a0075f8c58c3a6a38f
TINYDREAM_REPO?=https://github.com/M0Rf30/go-tiny-dream
TINYDREAM_VERSION?=c04fa463ace9d9a6464313aa5f9cd0f953b6c057

ONNX_VERSION?=1.20.0
ONNX_ARCH?=x64
ONNX_OS?=linux

export BUILD_TYPE?=
export STABLE_BUILD_TYPE?=$(BUILD_TYPE)
export CMAKE_ARGS?=
Expand Down Expand Up @@ -89,7 +93,20 @@ ifeq ($(NATIVE),false)
CMAKE_ARGS+=-DGGML_NATIVE=OFF
endif

# Detect if we are running on arm64
ifneq (,$(findstring aarch64,$(shell uname -m)))
ONNX_ARCH=aarch64
endif

ifeq ($(OS),Darwin)
ONNX_OS=osx
ifneq (,$(findstring aarch64,$(shell uname -m)))
ONNX_ARCH=arm64
else ifneq (,$(findstring arm64,$(shell uname -m)))
ONNX_ARCH=arm64
else
ONNX_ARCH=x86_64
endif

ifeq ($(OSX_SIGNING_IDENTITY),)
OSX_SIGNING_IDENTITY := $(shell security find-identity -v -p codesigning | grep '"' | head -n 1 | sed -E 's/.*"(.*)"/\1/')
Expand Down Expand Up @@ -195,6 +212,7 @@ ALL_GRPC_BACKENDS+=backend-assets/util/llama-cpp-rpc-server
ALL_GRPC_BACKENDS+=backend-assets/grpc/rwkv
ALL_GRPC_BACKENDS+=backend-assets/grpc/whisper
ALL_GRPC_BACKENDS+=backend-assets/grpc/local-store
ALL_GRPC_BACKENDS+=backend-assets/grpc/silero-vad
ALL_GRPC_BACKENDS+=$(OPTIONAL_GRPC)
# Use filter-out to remove the specified backends
ALL_GRPC_BACKENDS := $(filter-out $(SKIP_GRPC_BACKEND),$(ALL_GRPC_BACKENDS))
Expand Down Expand Up @@ -281,6 +299,20 @@ sources/go-stable-diffusion:
sources/go-stable-diffusion/libstablediffusion.a: sources/go-stable-diffusion
CPATH="$(CPATH):/usr/include/opencv4" $(MAKE) -C sources/go-stable-diffusion libstablediffusion.a

sources/onnxruntime:
mkdir -p sources/onnxruntime
curl -L https://github.com/microsoft/onnxruntime/releases/download/v$(ONNX_VERSION)/onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION).tgz -o sources/onnxruntime/onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION).tgz
cd sources/onnxruntime && tar -xvf onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION).tgz && rm onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION).tgz
cd sources/onnxruntime && mv onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION)/* ./

backend-assets/lib/libonnxruntime.so.1: backend-assets/lib sources/onnxruntime
cp -rfv sources/onnxruntime/lib/* backend-assets/lib/
ifeq ($(OS),Darwin)
mv backend-assets/lib/libonnxruntime.$(ONNX_VERSION).dylib backend-assets/lib/libonnxruntime.dylib
else
mv backend-assets/lib/libonnxruntime.so.$(ONNX_VERSION) backend-assets/lib/libonnxruntime.so.1
endif

## tiny-dream
sources/go-tiny-dream:
mkdir -p sources/go-tiny-dream
Expand Down Expand Up @@ -837,6 +869,13 @@ ifneq ($(UPX),)
$(UPX) backend-assets/grpc/stablediffusion
endif

backend-assets/grpc/silero-vad: backend-assets/grpc backend-assets/lib/libonnxruntime.so.1
CGO_LDFLAGS="$(CGO_LDFLAGS)" CPATH="$(CPATH):$(CURDIR)/sources/onnxruntime/include/" LIBRARY_PATH=$(CURDIR)/backend-assets/lib \
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/silero-vad ./backend/go/vad/silero
ifneq ($(UPX),)
$(UPX) backend-assets/grpc/silero-vad
endif

backend-assets/grpc/tinydream: sources/go-tiny-dream sources/go-tiny-dream/libtinydream.a backend-assets/grpc
CGO_LDFLAGS="$(CGO_LDFLAGS)" LIBRARY_PATH=$(CURDIR)/go-tiny-dream \
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/tinydream ./backend/go/image/tinydream
Expand Down
15 changes: 15 additions & 0 deletions backend/backend.proto
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ service Backend {
rpc Rerank(RerankRequest) returns (RerankResult) {}

rpc GetMetrics(MetricsRequest) returns (MetricsResponse);

rpc VAD(VADRequest) returns (VADResponse) {}
}

// Define the empty request
Expand Down Expand Up @@ -293,6 +295,19 @@ message TTSRequest {
optional string language = 5;
}

message VADRequest {
repeated float audio = 1;
}

message VADSegment {
float start = 1;
float end = 2;
}

message VADResponse {
repeated VADSegment segments = 1;
}

message SoundGenerationRequest {
string text = 1;
string model = 2;
Expand Down
21 changes: 21 additions & 0 deletions backend/go/vad/silero/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package main

// Note: this is started internally by LocalAI and a server is allocated for each model

import (
"flag"

grpc "github.com/mudler/LocalAI/pkg/grpc"
)

var (
addr = flag.String("addr", "localhost:50051", "the address to connect to")
)

func main() {
flag.Parse()

if err := grpc.StartServer(*addr, &VAD{}); err != nil {
panic(err)
}
}
54 changes: 54 additions & 0 deletions backend/go/vad/silero/vad.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package main

// This is a wrapper to statisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
"fmt"

"github.com/mudler/LocalAI/pkg/grpc/base"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/streamer45/silero-vad-go/speech"
)

type VAD struct {
base.SingleThread
detector *speech.Detector
}

func (vad *VAD) Load(opts *pb.ModelOptions) error {
v, err := speech.NewDetector(speech.DetectorConfig{
ModelPath: opts.ModelFile,
SampleRate: 16000,
//WindowSize: 1024,
Threshold: 0.5,
MinSilenceDurationMs: 0,
SpeechPadMs: 0,
})
if err != nil {
return fmt.Errorf("create silero detector: %w", err)
}

vad.detector = v
return err
}

func (vad *VAD) VAD(req *pb.VADRequest) (pb.VADResponse, error) {
audio := req.Audio

segments, err := vad.detector.Detect(audio)
if err != nil {
return pb.VADResponse{}, fmt.Errorf("detect: %w", err)
}

vadSegments := []*pb.VADSegment{}
for _, s := range segments {
vadSegments = append(vadSegments, &pb.VADSegment{
Start: float32(s.SpeechStartAt),
End: float32(s.SpeechEndAt),
})
}

return pb.VADResponse{
Segments: vadSegments,
}, nil
}
68 changes: 68 additions & 0 deletions core/http/endpoints/localai/vad.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package localai

import (
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
)

// VADEndpoint is Voice-Activation-Detection endpoint
// @Summary Detect voice fragments in an audio stream
// @Accept json
// @Param request body schema.VADRequest true "query params"
// @Success 200 {object} proto.VADResponse "Response"
// @Router /vad [post]
func VADEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {

input := new(schema.VADRequest)

// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}

modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
if err != nil {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
}

cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
config.LoadOptionDebug(appConfig.Debug),
config.LoadOptionThreads(appConfig.Threads),
config.LoadOptionContextSize(appConfig.ContextSize),
config.LoadOptionF16(appConfig.F16),
)

if err != nil {
log.Err(err)
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
} else {
modelFile = cfg.Model
}
log.Debug().Msgf("Request for model: %s", modelFile)

opts := backend.ModelOptions(*cfg, appConfig, model.WithBackendString(cfg.Backend), model.WithModel(modelFile))

vadModel, err := ml.Load(opts...)
if err != nil {
return err
}
req := proto.VADRequest{
Audio: input.Audio,
}
resp, err := vadModel.VAD(c.Context(), &req)
if err != nil {
return err
}

return c.JSON(resp)
}
}
1 change: 1 addition & 0 deletions core/http/routes/localai.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ func RegisterLocalAIRoutes(app *fiber.App,
}

app.Post("/tts", localai.TTSEndpoint(cl, ml, appConfig))
app.Post("/vad", localai.VADEndpoint(cl, ml, appConfig))

// Stores
sl := model.NewModelLoader("")
Expand Down
8 changes: 7 additions & 1 deletion core/schema/localai.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,16 @@ type TTSRequest struct {
Input string `json:"input" yaml:"input"` // text input
Voice string `json:"voice" yaml:"voice"` // voice audio file or speaker id
Backend string `json:"backend" yaml:"backend"`
Language string `json:"language,omitempty" yaml:"language,omitempty"` // (optional) language to use with TTS model
Language string `json:"language,omitempty" yaml:"language,omitempty"` // (optional) language to use with TTS model
Format string `json:"response_format,omitempty" yaml:"response_format,omitempty"` // (optional) output format
}

// @Description VAD request body
type VADRequest struct {
Model string `json:"model" yaml:"model"` // model name or full path
Audio []float32 `json:"audio" yaml:"audio"` // model name or full path
}

type StoresSet struct {
Store string `json:"store,omitempty" yaml:"store,omitempty"`

Expand Down
3 changes: 3 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ require (
github.com/pion/turn/v2 v2.1.6 // indirect
github.com/pion/webrtc/v3 v3.3.0 // indirect
github.com/shirou/gopsutil/v4 v4.24.7 // indirect
github.com/streamer45/silero-vad-go v0.2.1 // indirect
github.com/urfave/cli/v2 v2.27.4 // indirect
github.com/valyala/fasttemplate v1.2.2 // indirect
github.com/wlynxg/anet v0.0.4 // indirect
go.uber.org/mock v0.4.0 // indirect
)
Expand Down
5 changes: 5 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,11 @@ github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:Udh
github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA=
github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI=
github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
github.com/spf13/cast v1.5.0 h1:rj3WzYc11XZaIZMPKmwP96zkFEnnAmV8s6XbB2aY32w=
github.com/spf13/cast v1.5.0/go.mod h1:SpXXQ5YoyJw6s3/6cMTQuxvgRl3PCJiyaX9p6b155UU=
github.com/streamer45/silero-vad-go v0.2.1 h1:Li1/tTC4H/3cyw6q4weX+U8GWwEL3lTekK/nYa1Cvuk=
github.com/streamer45/silero-vad-go v0.2.1/go.mod h1:B+2FXs/5fZ6pzl6unUZYhZqkYdOB+3saBVzjOzdZnUs=
github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w=
github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
Expand Down
2 changes: 2 additions & 0 deletions pkg/grpc/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,6 @@ type Backend interface {
Rerank(ctx context.Context, in *pb.RerankRequest, opts ...grpc.CallOption) (*pb.RerankResult, error)

GetTokenMetrics(ctx context.Context, in *pb.MetricsRequest, opts ...grpc.CallOption) (*pb.MetricsResponse, error)

VAD(ctx context.Context, in *pb.VADRequest, opts ...grpc.CallOption) (*pb.VADResponse, error)
}
4 changes: 4 additions & 0 deletions pkg/grpc/base/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ func (llm *Base) StoresFind(*pb.StoresFindOptions) (pb.StoresFindResult, error)
return pb.StoresFindResult{}, fmt.Errorf("unimplemented")
}

func (llm *Base) VAD(*pb.VADRequest) (pb.VADResponse, error) {
return pb.VADResponse{}, fmt.Errorf("unimplemented")
}

func memoryUsage() *pb.MemoryUsageData {
mud := pb.MemoryUsageData{
Breakdown: make(map[string]uint64),
Expand Down
18 changes: 18 additions & 0 deletions pkg/grpc/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -392,3 +392,21 @@ func (c *Client) GetTokenMetrics(ctx context.Context, in *pb.MetricsRequest, opt
client := pb.NewBackendClient(conn)
return client.GetMetrics(ctx, in, opts...)
}

func (c *Client) VAD(ctx context.Context, in *pb.VADRequest, opts ...grpc.CallOption) (*pb.VADResponse, error) {
if !c.parallel {
c.opMutex.Lock()
defer c.opMutex.Unlock()
}
c.setBusy(true)
defer c.setBusy(false)
c.wdMark()
defer c.wdUnMark()
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
}
defer conn.Close()
client := pb.NewBackendClient(conn)
return client.VAD(ctx, in, opts...)
}
4 changes: 4 additions & 0 deletions pkg/grpc/embed.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ func (e *embedBackend) Rerank(ctx context.Context, in *pb.RerankRequest, opts ..
return e.s.Rerank(ctx, in)
}

func (e *embedBackend) VAD(ctx context.Context, in *pb.VADRequest, opts ...grpc.CallOption) (*pb.VADResponse, error) {
return e.s.VAD(ctx, in)
}

func (e *embedBackend) GetTokenMetrics(ctx context.Context, in *pb.MetricsRequest, opts ...grpc.CallOption) (*pb.MetricsResponse, error) {
return e.s.GetMetrics(ctx, in)
}
Expand Down
2 changes: 2 additions & 0 deletions pkg/grpc/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ type LLM interface {
StoresDelete(*pb.StoresDeleteOptions) error
StoresGet(*pb.StoresGetOptions) (pb.StoresGetResult, error)
StoresFind(*pb.StoresFindOptions) (pb.StoresFindResult, error)

VAD(*pb.VADRequest) (pb.VADResponse, error)
}

func newReply(s string) *pb.Reply {
Expand Down
12 changes: 12 additions & 0 deletions pkg/grpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,18 @@ func (s *server) StoresFind(ctx context.Context, in *pb.StoresFindOptions) (*pb.
return &res, nil
}

func (s *server) VAD(ctx context.Context, in *pb.VADRequest) (*pb.VADResponse, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
res, err := s.llm.VAD(in)
if err != nil {
return nil, err
}
return &res, nil
}

func StartServer(address string, model LLM) error {
lis, err := net.Listen("tcp", address)
if err != nil {
Expand Down

0 comments on commit b1ea931

Please sign in to comment.