Skip to content

Commit

Permalink
Fix stream close (#2)
Browse files Browse the repository at this point in the history
Signed-off-by: Sasha Savchuk <[email protected]>
Co-authored-by: Oleg Kovalov <[email protected]>
  • Loading branch information
mr-linch and cristaloleg authored Oct 5, 2023
1 parent e64b660 commit 253c117
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 3 deletions.
3 changes: 2 additions & 1 deletion sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func (u Upgrader) UpgradeHTTP(r *http.Request, w http.ResponseWriter) (*Stream,
return nil, ErrNotHijacker
}

_, bw, err := hj.Hijack()
nc, bw, err := hj.Hijack()
if err != nil {
http.Error(w, http.ErrHijacked.Error(), http.StatusInternalServerError)
return nil, http.ErrHijacked
Expand All @@ -40,6 +40,7 @@ func (u Upgrader) UpgradeHTTP(r *http.Request, w http.ResponseWriter) (*Stream,
}

s := &Stream{
nc: nc,
bw: bw,
w: w,
}
Expand Down
8 changes: 6 additions & 2 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@ import (
"encoding/json"
"fmt"
"io"
"net"
"strconv"
"time"
)

type Stream struct {
bw *bufio.ReadWriter
w io.Writer
nc net.Conn
}

type BinaryMarshaler interface {
Expand All @@ -26,9 +28,11 @@ func (s *Stream) Flush() error {
return s.bw.Flush()
}

// Close sends close event with empth data.
// Close sends close event with empty data and closes underlying connection.
func (s *Stream) Close() error {
_, err := s.w.Write([]byte("event:close\ndata:\n\n"))
defer s.nc.Close()

_, err := s.bw.Write([]byte("event:close\ndata:\n\n"))
if err != nil {
return err
}
Expand Down
36 changes: 36 additions & 0 deletions stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package sse
import (
"bufio"
"bytes"
"io"
"net/http"
"testing"
"time"
)
Expand Down Expand Up @@ -38,3 +40,37 @@ func newStream() (*Stream, *bytes.Buffer) {
}
return s, buf
}

func TestStream_Close(t *testing.T) {
server := newServer(func(w http.ResponseWriter, r *http.Request) {
u := Upgrader{}

stream, err := u.UpgradeHTTP(r, w)
if err != nil {
t.Fatal(err)
}

if err := stream.Close(); err != nil {
t.Fatalf("stream.Close() = %v, want nil", err)
}
})

client := http.DefaultClient

resp, err := client.Do(newStreamRequest(server.URL))
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}

want := "event:close\ndata:\n\n"

if got := string(body); got != want {
t.Fatalf("got %#v, want %#v", got, want)
}
}

0 comments on commit 253c117

Please sign in to comment.