diff --git a/pkg/coap/coap.go b/pkg/coap/coap.go index ba778ef..3569d5d 100644 --- a/pkg/coap/coap.go +++ b/pkg/coap/coap.go @@ -13,6 +13,7 @@ import ( "github.com/absmach/mproxy" "github.com/absmach/mproxy/pkg/session" "github.com/plgd-dev/go-coap/v3/dtls" + dtlsServer "github.com/plgd-dev/go-coap/v3/dtls/server" "github.com/plgd-dev/go-coap/v3/message" "github.com/plgd-dev/go-coap/v3/message/codes" "github.com/plgd-dev/go-coap/v3/message/pool" @@ -20,8 +21,11 @@ import ( "github.com/plgd-dev/go-coap/v3/net" "github.com/plgd-dev/go-coap/v3/options" "github.com/plgd-dev/go-coap/v3/udp" + udpServer "github.com/plgd-dev/go-coap/v3/udp/server" ) +const startObserve uint32 = 0 + var errUnsupportedMethod = errors.New("unsupported CoAP method") type Proxy struct { @@ -30,6 +34,30 @@ type Proxy struct { logger *slog.Logger } +type udpNilMonitor struct{} + +func (u *udpNilMonitor) UDPServerApply(cfg *udpServer.Config) { + cfg.CreateInactivityMonitor = nil +} + +func NewUDPNilMonitor() udpServer.Option { + return &udpNilMonitor{} +} + +var _ udpServer.Option = (*udpNilMonitor)(nil) + +type dtlsNilMonitor struct{} + +func (d *dtlsNilMonitor) DTLSServerApply(cfg *dtlsServer.Config) { + cfg.CreateInactivityMonitor = nil +} + +func NewDTLSNilMonitor() dtlsServer.Option { + return &dtlsNilMonitor{} +} + +var _ udpServer.Option = (*udpNilMonitor)(nil) + func NewProxy(config mproxy.Config, handler session.Handler, logger *slog.Logger) *Proxy { return &Proxy{ config: config, @@ -105,7 +133,21 @@ func (p *Proxy) observeUpstream(ctx context.Context, cc mux.Conn, opts []message defer outbound.Close() doneObserving := make(chan struct{}) - obs, err := outbound.Observe(ctx, path, func(req *pool.Message) { + pm := outbound.AcquireMessage(outbound.Context()) + defer outbound.ReleaseMessage(pm) + pm.SetToken(token) + pm.SetCode(codes.GET) + for _, opt := range opts { + pm.SetOptionBytes(opt.ID, opt.Value) + } + if err := pm.SetPath(path); err != nil { + if err := sendErrorMessage(cc, token, err, codes.BadOption); err != nil { + p.logger.Error(fmt.Sprintf("cannot send error response: %v", err)) + } + return + } + + obs, err := outbound.DoObserve(pm, func(req *pool.Message) { req.SetToken(token) if err := cc.WriteMessage(req); err != nil { if err := sendErrorMessage(cc, token, err, codes.BadGateway); err != nil { @@ -116,7 +158,7 @@ func (p *Proxy) observeUpstream(ctx context.Context, cc mux.Conn, opts []message if req.Code() == codes.NotFound { close(doneObserving) } - }, opts...) + }) if err != nil { if err := sendErrorMessage(cc, token, err, codes.BadGateway); err != nil { p.logger.Error(fmt.Sprintf("cannot send error response: %v", err)) @@ -133,6 +175,35 @@ func (p *Proxy) observeUpstream(ctx context.Context, cc mux.Conn, opts []message } } +func (p *Proxy) CancelObservation(cc mux.Conn, opts []message.Option, token []byte, path string) error { + outbound, err := udp.Dial(p.config.Target) + if err != nil { + if err := sendErrorMessage(cc, token, err, codes.BadGateway); err != nil { + p.logger.Error(fmt.Sprintf("cannot send error response: %v", err)) + } + } + defer outbound.Close() + + pm := outbound.AcquireMessage(outbound.Context()) + defer outbound.ReleaseMessage(pm) + pm.SetToken(token) + pm.SetCode(codes.GET) + for _, opt := range opts { + pm.SetOptionBytes(opt.ID, opt.Value) + } + if err := pm.SetPath(path); err != nil { + if err := sendErrorMessage(cc, token, err, codes.BadOption); err != nil { + p.logger.Error(fmt.Sprintf("cannot send error response: %v", err)) + } + return err + } + if err := outbound.WriteMessage(pm); err != nil { + return err + } + pm.SetCode(codes.Content) + return cc.WriteMessage(pm) +} + func (p *Proxy) handler(w mux.ResponseWriter, r *mux.Message) { tok, err := r.Options().GetBytes(message.URIQuery) if err != nil { @@ -157,14 +228,7 @@ func (p *Proxy) handler(w mux.ResponseWriter, r *mux.Message) { } switch r.Code() { case codes.GET: - obs, err := r.Options().Observe() - if err != nil { - if err := sendErrorMessage(w.Conn(), r.Token(), err, codes.BadRequest); err != nil { - p.logger.Error(err.Error()) - } - return - } - p.handleGet(ctx, path, w.Conn(), r.Token(), obs, r) + p.handleGet(ctx, path, w.Conn(), r.Token(), r) case codes.POST: body, err := r.ReadBody() @@ -182,7 +246,7 @@ func (p *Proxy) handler(w mux.ResponseWriter, r *mux.Message) { } } -func (p *Proxy) handleGet(ctx context.Context, path string, con mux.Conn, token []byte, obs uint32, r *mux.Message) { +func (p *Proxy) handleGet(ctx context.Context, path string, con mux.Conn, token []byte, r *mux.Message) { if err := p.session.AuthSubscribe(ctx, &[]string{path}); err != nil { if err := sendErrorMessage(con, token, err, codes.Unauthorized); err != nil { p.logger.Error(err.Error()) @@ -196,10 +260,26 @@ func (p *Proxy) handleGet(ctx context.Context, path string, con mux.Conn, token return } switch { - // obs == 0, start observe - case obs == 0: - go p.observeUpstream(ctx, con, r.Options(), token, path) - + case r.HasOption(message.Observe): + obs, err := r.Options().Observe() + if err != nil { + if err := sendErrorMessage(con, r.Token(), err, codes.BadRequest); err != nil { + p.logger.Error(err.Error()) + } + return + } + switch obs { + case startObserve: + go p.observeUpstream(ctx, con, r.Options(), token, path) + default: + if err := p.CancelObservation(con, r.Options(), token, path); err != nil { + p.logger.Error(fmt.Sprintf("error performing cancel observation: %v\n", err)) + if err := sendErrorMessage(con, token, err, codes.BadGateway); err != nil { + p.logger.Error(err.Error()) + } + return + } + } default: if err := p.getUpstream(con, r, token); err != nil { p.logger.Error(fmt.Sprintf("error performing get: %v\n", err)) @@ -242,7 +322,10 @@ func (p *Proxy) Listen(ctx context.Context) error { defer l.Close() p.logger.Info(fmt.Sprintf("CoAP proxy server started on port %s with DTLS", p.config.Address)) - s := dtls.NewServer(options.WithMux(mux.HandlerFunc(p.handler))) + var dialOpts []dtlsServer.Option + dialOpts = append(dialOpts, options.WithMux(mux.HandlerFunc(p.handler)), NewDTLSNilMonitor()) + + s := dtls.NewServer(dialOpts...) errCh := make(chan error) go func() { @@ -266,7 +349,10 @@ func (p *Proxy) Listen(ctx context.Context) error { defer l.Close() p.logger.Info(fmt.Sprintf("CoAP proxy server started at %s without DTLS", p.config.Address)) - s := udp.NewServer(options.WithMux(mux.HandlerFunc(p.handler))) + var dialOpts []udpServer.Option + dialOpts = append(dialOpts, options.WithMux(mux.HandlerFunc(p.handler)), NewUDPNilMonitor()) + + s := udp.NewServer(dialOpts...) errCh := make(chan error) go func() {