Skip to content

Commit

Permalink
cleaned code and increased logging, set timeout on opened websockets …
Browse files Browse the repository at this point in the history
…that do not auth
  • Loading branch information
purehyperbole committed Feb 5, 2019
1 parent d044eed commit ea2c65f
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 91 deletions.
13 changes: 11 additions & 2 deletions Gopkg.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

95 changes: 70 additions & 25 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,63 +9,108 @@ import (
"errors"
"fmt"
"log"
"net/http"
"time"

"github.com/dgrijalva/jwt-go"
"github.com/gorilla/websocket"
"github.com/r3labs/broadcast"
)

// Session : stores authentication data
type Session struct {
Token string `json:"token"`
Stream *string `json:"stream"`
EventID *string `json:"event_id"`
Username string
Authenticated bool
Username string `json:"-"`
authenticated bool
subscriber *broadcast.Subscriber
channel chan *broadcast.Event
}

func unauthorized(w http.ResponseWriter) error {
log.Println("Unauthorized")
func unauthorized(c *websocket.Conn, err error) error {
if err != nil {
log.Println("Unauthorized:", err.Error())
} else {
log.Println("Unauthorized")
}
_ = c.WriteMessage(websocket.CloseMessage, []byte(`{"status": "unauthorized"}`))
return errors.New("Unauthorized")
}

func authenticate(w http.ResponseWriter, c *websocket.Conn) (*Session, error) {
var s Session
func getAuthMessage(c *websocket.Conn, s *Session) error {
// timeout after 2 seconds if no request is sent
c.SetReadDeadline(time.Now().Add(time.Second * 2))

mt, message, err := c.ReadMessage()
_, message, err := c.ReadMessage()
if err != nil {
log.Println(string(message))
return nil, badrequest(w)
return err
}

err = json.Unmarshal(message, &s)
if err != nil {
log.Println(string(message))
return nil, badrequest(w)
return json.Unmarshal(message, &s)
}

func register(stream *string, username, requestID string) (*broadcast.Subscriber, chan *broadcast.Event, error) {
if stream == nil {
return nil, nil, errors.New("no stream specified")
}

token, err := jwt.Parse(s.Token, func(t *jwt.Token) (interface{}, error) {
if t.Method.Alg() != jwt.SigningMethodHS256.Alg() {
return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
}
return []byte(secret), nil
})
log.Printf("[%s] subscribing to stream: %s\n", requestID, *stream)

if !bc.StreamExists(*stream) && !bc.AutoStream {
return nil, nil, errors.New("stream does not exist")
} else if !bc.StreamExists(*stream) && bc.AutoStream {
bc.CreateStream(*stream)
}

sub := bc.GetStreamSubscriber(*stream, username)
if sub == nil {
sub = broadcast.NewSubscriber(username)
bc.Register(*stream, sub)
}

return sub, sub.Connect(), nil
}

func authenticate(c *websocket.Conn, requestID string) (*Session, error) {
var s Session

log.Printf("[%s] authenticating user\n", requestID)

err := getAuthMessage(c, &s)
if err != nil {
return nil, err
}

token, err := jwt.Parse(s.Token, jwtVerify)
if err != nil || !token.Valid {
_ = c.WriteMessage(mt, []byte(`{"status": "unauthorized"}`))
return nil, unauthorized(w)
return nil, errors.New("invalid token")
}

s.Authenticated = true
s.authenticated = true

claims, ok := token.Claims.(jwt.MapClaims)
if ok {
s.Username = claims["username"].(string)
}

err = c.WriteMessage(mt, []byte(`{"status": "ok"}`))
log.Printf("[%s] user authenticated\n", requestID)
err = c.WriteMessage(websocket.TextMessage, []byte(`{"status": "ok"}`))
if err != nil {
return nil, internalerror(w)
return nil, err
}

// register to stream
s.subscriber, s.channel, err = register(s.Stream, s.Username, requestID)
if err != nil {
return nil, err
}

return &s, nil
}

func jwtVerify(t *jwt.Token) (interface{}, error) {
if t.Method.Alg() != jwt.SigningMethodHS256.Alg() {
return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
}
return []byte(secret), nil
}
98 changes: 34 additions & 64 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ package main

import (
//
"errors"

"log"
"net/http"

"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/r3labs/broadcast"
)

var upgrader = websocket.Upgrader{
Expand All @@ -21,91 +21,61 @@ var upgrader = websocket.Upgrader{
}

func handler(w http.ResponseWriter, r *http.Request) {
var session *Session

reqid := uuid.New().String()

log.Println("client connected:", reqid)

c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
upgradefail(w)
upgradefail(w, err)
return
}

var authorized bool
var areq *Session
var ch chan *broadcast.Event
var sub *broadcast.Subscriber

defer func() {
log.Println("client disconnected:", reqid)
_ = c.Close()

if ch != nil && sub != nil {
sub.Disconnect(ch)
if session.channel != nil && session.subscriber != nil {
session.subscriber.Disconnect(
session.channel,
)
}
}()

for {
if !authorized {
areq, err = authenticate(w, c)
if err != nil {
return
}

sub, ch, err = register(w, areq)
if !session.authenticated {
session, err = authenticate(c, reqid)
if err != nil {
return
}

authorized = true
} else {
msg, ok := <-ch
if !ok {
return
}

log.Println("Sending Message to ", areq.Stream)
err := c.WriteMessage(websocket.TextMessage, msg.Data)
if err != nil {
log.Println("failed to write to connection")
_ = internalerror(w)
badrequest(c, reqid, err)
return
}
}
}
}

func register(w http.ResponseWriter, s *Session) (*broadcast.Subscriber, chan *broadcast.Event, error) {
if s.Stream == nil {
return nil, nil, badstream(w)
}
msg, ok := <-session.channel
if !ok {
log.Printf("[%s] event channel closed: %s\n", reqid, *session.Stream)
return
}

if !bc.StreamExists(*s.Stream) && !bc.AutoStream {
return nil, nil, badstream(w)
} else if !bc.StreamExists(*s.Stream) && bc.AutoStream {
bc.CreateStream(*s.Stream)
}
log.Println("sending message to:", *session.Stream)

sub := bc.GetStreamSubscriber(*s.Stream, s.Username)
if sub == nil {
sub = broadcast.NewSubscriber(s.Username)
bc.Register(*s.Stream, sub)
err := c.WriteMessage(websocket.TextMessage, msg.Data)
if err != nil {
badrequest(c, reqid, err)
return
}
}

return sub, sub.Connect(), nil
}

func upgradefail(w http.ResponseWriter) {
log.Println("Unable to upgrade to websocket connection")
func upgradefail(w http.ResponseWriter, err error) {
log.Println("Unable to upgrade to websocket connection:", err.Error())
http.Error(w, "Unable to upgrade to websocket connection", http.StatusBadRequest)
}

func badrequest(w http.ResponseWriter) error {
log.Println("Could not process sent data")
return errors.New("Could not process sent data")
}

func badstream(w http.ResponseWriter) error {
log.Println("Please specify a valid stream")
return errors.New("Please specify a valid stream")
}

func internalerror(w http.ResponseWriter) error {
log.Println("Internal server error")
return errors.New("Internal server error")
func badrequest(c *websocket.Conn, reqid string, err error) {
log.Printf("[%s] bad request: %s\n", reqid, err.Error())
_ = c.WriteMessage(websocket.CloseUnsupportedData, []byte(`{"error": "bad request"}`))
c.Close()
}

0 comments on commit ea2c65f

Please sign in to comment.