Skip to content

Commit

Permalink
feat: add cancel observation
Browse files Browse the repository at this point in the history
Signed-off-by: 1998-felix <[email protected]>
  • Loading branch information
felixgateru committed Jul 29, 2024
1 parent 6a9c7e3 commit 7f0cbb2
Showing 1 changed file with 103 additions and 17 deletions.
120 changes: 103 additions & 17 deletions pkg/coap/coap.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,19 @@ import (
"github.com/absmach/mproxy"
"github.com/absmach/mproxy/pkg/session"
"github.com/plgd-dev/go-coap/v3/dtls"
dtlsServer "github.com/plgd-dev/go-coap/v3/dtls/server"
"github.com/plgd-dev/go-coap/v3/message"
"github.com/plgd-dev/go-coap/v3/message/codes"
"github.com/plgd-dev/go-coap/v3/message/pool"
"github.com/plgd-dev/go-coap/v3/mux"
"github.com/plgd-dev/go-coap/v3/net"
"github.com/plgd-dev/go-coap/v3/options"
"github.com/plgd-dev/go-coap/v3/udp"
udpServer "github.com/plgd-dev/go-coap/v3/udp/server"
)

const startObserve uint32 = 0

var errUnsupportedMethod = errors.New("unsupported CoAP method")

type Proxy struct {
Expand All @@ -30,6 +34,30 @@ type Proxy struct {
logger *slog.Logger
}

type udpNilMonitor struct{}

func (u *udpNilMonitor) UDPServerApply(cfg *udpServer.Config) {
cfg.CreateInactivityMonitor = nil
}

func NewUDPNilMonitor() udpServer.Option {
return &udpNilMonitor{}
}

var _ udpServer.Option = (*udpNilMonitor)(nil)

type dtlsNilMonitor struct{}

func (d *dtlsNilMonitor) DTLSServerApply(cfg *dtlsServer.Config) {
cfg.CreateInactivityMonitor = nil
}

func NewDTLSNilMonitor() dtlsServer.Option {
return &dtlsNilMonitor{}
}

var _ udpServer.Option = (*udpNilMonitor)(nil)

func NewProxy(config mproxy.Config, handler session.Handler, logger *slog.Logger) *Proxy {
return &Proxy{
config: config,
Expand Down Expand Up @@ -105,7 +133,21 @@ func (p *Proxy) observeUpstream(ctx context.Context, cc mux.Conn, opts []message
defer outbound.Close()
doneObserving := make(chan struct{})

obs, err := outbound.Observe(ctx, path, func(req *pool.Message) {
pm := outbound.AcquireMessage(outbound.Context())
defer outbound.ReleaseMessage(pm)
pm.SetToken(token)
pm.SetCode(codes.GET)
for _, opt := range opts {
pm.SetOptionBytes(opt.ID, opt.Value)
}
if err := pm.SetPath(path); err != nil {
if err := sendErrorMessage(cc, token, err, codes.BadOption); err != nil {
p.logger.Error(fmt.Sprintf("cannot send error response: %v", err))
}
return
}

obs, err := outbound.DoObserve(pm, func(req *pool.Message) {
req.SetToken(token)
if err := cc.WriteMessage(req); err != nil {
if err := sendErrorMessage(cc, token, err, codes.BadGateway); err != nil {
Expand All @@ -116,7 +158,7 @@ func (p *Proxy) observeUpstream(ctx context.Context, cc mux.Conn, opts []message
if req.Code() == codes.NotFound {
close(doneObserving)
}
}, opts...)
})
if err != nil {
if err := sendErrorMessage(cc, token, err, codes.BadGateway); err != nil {
p.logger.Error(fmt.Sprintf("cannot send error response: %v", err))
Expand All @@ -133,6 +175,35 @@ func (p *Proxy) observeUpstream(ctx context.Context, cc mux.Conn, opts []message
}
}

func (p *Proxy) CancelObservation(cc mux.Conn, opts []message.Option, token []byte, path string) error {
outbound, err := udp.Dial(p.config.Target)
if err != nil {
if err := sendErrorMessage(cc, token, err, codes.BadGateway); err != nil {
p.logger.Error(fmt.Sprintf("cannot send error response: %v", err))
}
}
defer outbound.Close()

pm := outbound.AcquireMessage(outbound.Context())
defer outbound.ReleaseMessage(pm)
pm.SetToken(token)
pm.SetCode(codes.GET)
for _, opt := range opts {
pm.SetOptionBytes(opt.ID, opt.Value)
}
if err := pm.SetPath(path); err != nil {
if err := sendErrorMessage(cc, token, err, codes.BadOption); err != nil {
p.logger.Error(fmt.Sprintf("cannot send error response: %v", err))
}
return err
}
if err := outbound.WriteMessage(pm); err != nil {
return err
}
pm.SetCode(codes.Content)
return cc.WriteMessage(pm)
}

func (p *Proxy) handler(w mux.ResponseWriter, r *mux.Message) {
tok, err := r.Options().GetBytes(message.URIQuery)
if err != nil {
Expand All @@ -157,14 +228,7 @@ func (p *Proxy) handler(w mux.ResponseWriter, r *mux.Message) {
}
switch r.Code() {
case codes.GET:
obs, err := r.Options().Observe()
if err != nil {
if err := sendErrorMessage(w.Conn(), r.Token(), err, codes.BadRequest); err != nil {
p.logger.Error(err.Error())
}
return
}
p.handleGet(ctx, path, w.Conn(), r.Token(), obs, r)
p.handleGet(ctx, path, w.Conn(), r.Token(), r)

case codes.POST:
body, err := r.ReadBody()
Expand All @@ -182,7 +246,7 @@ func (p *Proxy) handler(w mux.ResponseWriter, r *mux.Message) {
}
}

func (p *Proxy) handleGet(ctx context.Context, path string, con mux.Conn, token []byte, obs uint32, r *mux.Message) {
func (p *Proxy) handleGet(ctx context.Context, path string, con mux.Conn, token []byte, r *mux.Message) {
if err := p.session.AuthSubscribe(ctx, &[]string{path}); err != nil {
if err := sendErrorMessage(con, token, err, codes.Unauthorized); err != nil {
p.logger.Error(err.Error())
Expand All @@ -196,10 +260,26 @@ func (p *Proxy) handleGet(ctx context.Context, path string, con mux.Conn, token
return
}
switch {
// obs == 0, start observe
case obs == 0:
go p.observeUpstream(ctx, con, r.Options(), token, path)

case r.HasOption(message.Observe):
obs, err := r.Options().Observe()
if err != nil {
if err := sendErrorMessage(con, r.Token(), err, codes.BadRequest); err != nil {
p.logger.Error(err.Error())
}
return
}
switch obs {
case startObserve:
go p.observeUpstream(ctx, con, r.Options(), token, path)
default:
if err := p.CancelObservation(con, r.Options(), token, path); err != nil {
p.logger.Error(fmt.Sprintf("error performing cancel observation: %v\n", err))
if err := sendErrorMessage(con, token, err, codes.BadGateway); err != nil {
p.logger.Error(err.Error())
}
return
}
}
default:
if err := p.getUpstream(con, r, token); err != nil {
p.logger.Error(fmt.Sprintf("error performing get: %v\n", err))
Expand Down Expand Up @@ -242,7 +322,10 @@ func (p *Proxy) Listen(ctx context.Context) error {
defer l.Close()

p.logger.Info(fmt.Sprintf("CoAP proxy server started on port %s with DTLS", p.config.Address))
s := dtls.NewServer(options.WithMux(mux.HandlerFunc(p.handler)))
var dialOpts []dtlsServer.Option
dialOpts = append(dialOpts, options.WithMux(mux.HandlerFunc(p.handler)), NewDTLSNilMonitor())

s := dtls.NewServer(dialOpts...)

errCh := make(chan error)
go func() {
Expand All @@ -266,7 +349,10 @@ func (p *Proxy) Listen(ctx context.Context) error {
defer l.Close()

p.logger.Info(fmt.Sprintf("CoAP proxy server started at %s without DTLS", p.config.Address))
s := udp.NewServer(options.WithMux(mux.HandlerFunc(p.handler)))
var dialOpts []udpServer.Option
dialOpts = append(dialOpts, options.WithMux(mux.HandlerFunc(p.handler)), NewUDPNilMonitor())

s := udp.NewServer(dialOpts...)

errCh := make(chan error)
go func() {
Expand Down

0 comments on commit 7f0cbb2

Please sign in to comment.