From b86dc1dffd56fea7e523e7c653a9439e90bfaef9 Mon Sep 17 00:00:00 2001 From: Christian Ege Date: Mon, 10 Jun 2024 17:45:04 +0200 Subject: [PATCH] fix: upload the image as multipart This is quite memory hungry because the whole file is copied to memory first --- cmd/ovp8xx/cmd/swupdate.go | 60 ++++++++++++++++++++++++++++ pkg/swupdater/swupdater.go | 80 +++++++++++++++++++++++++------------- 2 files changed, 113 insertions(+), 27 deletions(-) create mode 100644 cmd/ovp8xx/cmd/swupdate.go diff --git a/cmd/ovp8xx/cmd/swupdate.go b/cmd/ovp8xx/cmd/swupdate.go new file mode 100644 index 0000000..04fbac8 --- /dev/null +++ b/cmd/ovp8xx/cmd/swupdate.go @@ -0,0 +1,60 @@ +/* +Copyright © 2023 Christian Ege +*/ +package cmd + +import ( + "fmt" + "time" + + "github.com/graugans/go-ovp8xx/pkg/swupdater" + "github.com/spf13/cobra" +) + +func swupdateCommand(cmd *cobra.Command, args []string) error { + var err error + host, err := rootCmd.PersistentFlags().GetString("ip") + if err != nil { + return fmt.Errorf("cannot get host: %w", err) + } + + port, err := cmd.Flags().GetUint16("port") + if err != nil { + return fmt.Errorf("cannot get port: %w", err) + } + + filename, err := cmd.Flags().GetString("file") + if err != nil { + return fmt.Errorf("cannot get filename: %w", err) + } + + timeout, err := cmd.Flags().GetDuration("timeout") + if err != nil { + return fmt.Errorf("cannot get timeout: %w", err) + } + + fmt.Printf("Updating firmware on %s:%d with file %s (%v)\n", host, port, filename, timeout) + + swu := swupdater.NewSWUpdater(host, port) + + err = swu.Update(filename, timeout) + if err != nil { + return fmt.Errorf("software update failed: %w", err) + } + + return nil +} + +// swupdateCmd represents the swupdate command +var swupdateCmd = &cobra.Command{ + Use: "swupdate", + Short: "Update the firmware on the device", + RunE: swupdateCommand, +} + +func init() { + rootCmd.AddCommand(swupdateCmd) + swupdateCmd.Flags().String("file", "", "A file conatining the firmware image") + swupdateCmd.Flags().Uint16("port", 8080, "Port number for SWUpdate") + swupdateCmd.Flags().Duration("timeout", 5*time.Minute, "The timeout for the upload") +} diff --git a/pkg/swupdater/swupdater.go b/pkg/swupdater/swupdater.go index 554fc31..10ab820 100644 --- a/pkg/swupdater/swupdater.go +++ b/pkg/swupdater/swupdater.go @@ -1,11 +1,15 @@ package swupdater import ( + "bytes" "encoding/json" "errors" "fmt" "io" + "mime/multipart" "net/http" + "os" + "strconv" "time" "github.com/gorilla/websocket" @@ -13,30 +17,52 @@ import ( type SWUpdater struct { hostName string - port int - path string + port uint16 urlUpload string urlStatus string - done chan error } -func NewSWUpdater(hostName, path string, port int) *SWUpdater { +func NewSWUpdater(hostName string, port uint16) *SWUpdater { return &SWUpdater{ hostName: hostName, port: port, - path: path, - urlUpload: fmt.Sprintf("http://%s:%d%s/upload", hostName, port, path), - urlStatus: fmt.Sprintf("ws://%s:%d%s/ws", hostName, port, path), - done: make(chan error), + urlUpload: fmt.Sprintf("http://%s:%d/upload", hostName, port), + urlStatus: fmt.Sprintf("ws://%s:%d/ws", hostName, port), } } +func (s *SWUpdater) upload(filename string, timeout time.Duration) error { + image, err := os.Open(filename) + if err != nil { + return fmt.Errorf("cannot open file: %w", err) + } + fmt.Printf("Uploading software image to %s\n", s.urlUpload) + + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + + part, err := writer.CreateFormFile("file", filename) + if err != nil { + return fmt.Errorf("cannot create form file: %w", err) + } -func (s *SWUpdater) upload(image io.Reader, timeout time.Duration) error { - req, err := http.NewRequest("POST", s.urlUpload, image) + _, err = io.Copy(part, image) + if err != nil { + return fmt.Errorf("cannot write to form file: %w", err) + } + + err = writer.Close() + if err != nil { + return fmt.Errorf("cannot close multipart writer: %w", err) + } + + req, err := http.NewRequest("POST", s.urlUpload, bytes.NewReader(body.Bytes())) if err != nil { return fmt.Errorf("cannot create request: %w", err) } + req.Header.Set("Content-Type", writer.FormDataContentType()) + req.Header.Set("Content-Length", strconv.Itoa(body.Len())) + client := &http.Client{Timeout: timeout} resp, err := client.Do(req) if err != nil { @@ -51,10 +77,10 @@ func (s *SWUpdater) upload(image io.Reader, timeout time.Duration) error { return nil } -func (s *SWUpdater) waitForFinished() { +func (s *SWUpdater) waitForFinished(done chan error) { c, _, err := websocket.DefaultDialer.Dial(s.urlStatus, nil) if err != nil { - s.done <- fmt.Errorf("cannot connect to websocket: %w", err) + done <- fmt.Errorf("cannot connect to websocket: %w", err) return } defer c.Close() @@ -62,44 +88,44 @@ func (s *SWUpdater) waitForFinished() { for { _, message, err := c.ReadMessage() if err != nil { - s.done <- fmt.Errorf("cannot read message from websocket: %w", err) + done <- fmt.Errorf("cannot read message from websocket: %w", err) return } data := make(map[string]string) err = json.Unmarshal(message, &data) if err != nil { - continue + done <- fmt.Errorf("cannot unmarshal message: %w", err) + return } - + fmt.Println("Raw JSON: ", data) if data["type"] != "message" { continue } if data["text"] == "SWUPDATE successful" { - s.done <- nil + done <- nil return } if data["text"] == "Installation failed" { - s.done <- errors.New("installation failed") + done <- errors.New("installation failed") return } } } -func (s *SWUpdater) Update(image io.Reader, timeout time.Duration) error { - go s.waitForFinished() - go func() { - err := s.upload(image, timeout) - if err != nil { - s.done <- err - } - }() +func (s *SWUpdater) Update(filename string, timeout time.Duration) error { + done := make(chan error) + go s.waitForFinished(done) + err := s.upload(filename, timeout) + if err != nil { + return fmt.Errorf("cannot upload software image: %w", err) + } select { - case err := <-s.done: + case err := <-done: if err != nil { - return err + return fmt.Errorf("update failed: %w", err) } return nil case <-time.After(timeout):