diff --git a/_examples/golang-sse/main.go b/_examples/golang-sse/main.go index 113b120c..0ec190b4 100644 --- a/_examples/golang-sse/main.go +++ b/_examples/golang-sse/main.go @@ -59,7 +59,8 @@ func router() http.Handler { w.Write([]byte(".")) }) - webrpcHandler := proto.NewChatServer(&ChatServer{}) + rpc := NewChatServer() + webrpcHandler := proto.NewChatServer(rpc) r.Handle("/*", webrpcHandler) return r diff --git a/_examples/golang-sse/rpc.go b/_examples/golang-sse/rpc.go index 73bc3e82..08d50ba4 100644 --- a/_examples/golang-sse/rpc.go +++ b/_examples/golang-sse/rpc.go @@ -3,25 +3,52 @@ package main import ( "context" "fmt" + "log/slog" "math/rand" - "time" + "sync" "github.com/webrpc/webrpc/_example/golang-sse/proto" ) type ChatServer struct { + mu sync.Mutex + lastId uint64 + subscriptions map[uint64]chan *proto.Message +} + +func NewChatServer() *ChatServer { + return &ChatServer{ + subscriptions: map[uint64]chan *proto.Message{}, + } } func (s *ChatServer) SendMessage(ctx context.Context, authorName string, text string) error { + msg := &proto.Message{ + Id: uint64(rand.Uint64()), + AuthorName: authorName, + Text: text, + } + + slog.Info("broadcasting message", + "author", msg.AuthorName, + "text", msg.Text, + "subscribers", len(s.subscriptions), + ) + for _, sub := range s.subscriptions { + sub := sub + go func() { + sub <- msg + }() + } + return nil } func (s *ChatServer) SubscribeMessages(ctx context.Context, serverTimeoutSec int, stream proto.SubscribeMessagesStreamWriter) error { - if serverTimeoutSec > 0 { - ctx, _ = context.WithTimeout(ctx, time.Duration(serverTimeoutSec)*time.Second) - } + msgs := make(chan *proto.Message, 10) + defer s.unsubscribe(s.subscribe(msgs)) - for i := 0; i < 10; i++ { + for { select { case <-ctx.Done(): switch err := ctx.Err(); err { @@ -31,20 +58,28 @@ func (s *ChatServer) SubscribeMessages(ctx context.Context, serverTimeoutSec int return proto.ErrConnectionTooLong.WithCause(fmt.Errorf("timed out after %vs", serverTimeoutSec)) } - case <-time.After(time.Duration(rand.Intn(1000)) * time.Millisecond): - // Simulate work. Delay each message by 0-1000ms. + case msg := <-msgs: + if err := stream.Write(msg); err != nil { + return err + } } + } +} - msg := &proto.Message{ - Id: rand.Uint64(), - AuthorName: "Alice", - Text: fmt.Sprintf("Message %v", i), - } +func (s *ChatServer) subscribe(c chan *proto.Message) uint64 { + s.mu.Lock() + defer s.mu.Unlock() - if err := stream.Write(msg); err != nil { - return err - } - } + id := s.lastId + s.subscriptions[id] = c + s.lastId++ - return nil + return id +} + +func (s *ChatServer) unsubscribe(subscriptionId uint64) { + s.mu.Lock() + defer s.mu.Unlock() + + delete(s.subscriptions, subscriptionId) }