Skip to content

Commit

Permalink
fix: upload the image as multipart
Browse files Browse the repository at this point in the history
This is quite memory hungry because the whole file is copied to memory first
  • Loading branch information
graugans committed Jun 10, 2024
1 parent e7d85d4 commit b86dc1d
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 27 deletions.
60 changes: 60 additions & 0 deletions cmd/ovp8xx/cmd/swupdate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
Copyright © 2023 Christian Ege <[email protected]>
*/
package cmd

import (
"fmt"
"time"

"github.com/graugans/go-ovp8xx/pkg/swupdater"
"github.com/spf13/cobra"
)

func swupdateCommand(cmd *cobra.Command, args []string) error {
var err error
host, err := rootCmd.PersistentFlags().GetString("ip")
if err != nil {
return fmt.Errorf("cannot get host: %w", err)
}

port, err := cmd.Flags().GetUint16("port")
if err != nil {
return fmt.Errorf("cannot get port: %w", err)
}

filename, err := cmd.Flags().GetString("file")
if err != nil {
return fmt.Errorf("cannot get filename: %w", err)
}

timeout, err := cmd.Flags().GetDuration("timeout")
if err != nil {
return fmt.Errorf("cannot get timeout: %w", err)
}

fmt.Printf("Updating firmware on %s:%d with file %s (%v)\n", host, port, filename, timeout)

swu := swupdater.NewSWUpdater(host, port)

err = swu.Update(filename, timeout)
if err != nil {
return fmt.Errorf("software update failed: %w", err)
}

return nil
}

// swupdateCmd represents the swupdate command
var swupdateCmd = &cobra.Command{
Use: "swupdate",
Short: "Update the firmware on the device",
RunE: swupdateCommand,
}

func init() {
rootCmd.AddCommand(swupdateCmd)
swupdateCmd.Flags().String("file", "", "A file conatining the firmware image")
swupdateCmd.Flags().Uint16("port", 8080, "Port number for SWUpdate")
swupdateCmd.Flags().Duration("timeout", 5*time.Minute, "The timeout for the upload")
}
80 changes: 53 additions & 27 deletions pkg/swupdater/swupdater.go
Original file line number Diff line number Diff line change
@@ -1,42 +1,68 @@
package swupdater

import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"mime/multipart"
"net/http"
"os"
"strconv"
"time"

"github.com/gorilla/websocket"
)

type SWUpdater struct {
hostName string
port int
path string
port uint16
urlUpload string
urlStatus string
done chan error
}

func NewSWUpdater(hostName, path string, port int) *SWUpdater {
func NewSWUpdater(hostName string, port uint16) *SWUpdater {
return &SWUpdater{
hostName: hostName,
port: port,
path: path,
urlUpload: fmt.Sprintf("http://%s:%d%s/upload", hostName, port, path),
urlStatus: fmt.Sprintf("ws://%s:%d%s/ws", hostName, port, path),
done: make(chan error),
urlUpload: fmt.Sprintf("http://%s:%d/upload", hostName, port),
urlStatus: fmt.Sprintf("ws://%s:%d/ws", hostName, port),
}
}
func (s *SWUpdater) upload(filename string, timeout time.Duration) error {
image, err := os.Open(filename)
if err != nil {
return fmt.Errorf("cannot open file: %w", err)
}
fmt.Printf("Uploading software image to %s\n", s.urlUpload)

body := &bytes.Buffer{}
writer := multipart.NewWriter(body)

part, err := writer.CreateFormFile("file", filename)
if err != nil {
return fmt.Errorf("cannot create form file: %w", err)
}

func (s *SWUpdater) upload(image io.Reader, timeout time.Duration) error {
req, err := http.NewRequest("POST", s.urlUpload, image)
_, err = io.Copy(part, image)
if err != nil {
return fmt.Errorf("cannot write to form file: %w", err)
}

err = writer.Close()
if err != nil {
return fmt.Errorf("cannot close multipart writer: %w", err)
}

req, err := http.NewRequest("POST", s.urlUpload, bytes.NewReader(body.Bytes()))
if err != nil {
return fmt.Errorf("cannot create request: %w", err)
}

req.Header.Set("Content-Type", writer.FormDataContentType())
req.Header.Set("Content-Length", strconv.Itoa(body.Len()))

client := &http.Client{Timeout: timeout}
resp, err := client.Do(req)
if err != nil {
Expand All @@ -51,55 +77,55 @@ func (s *SWUpdater) upload(image io.Reader, timeout time.Duration) error {
return nil
}

func (s *SWUpdater) waitForFinished() {
func (s *SWUpdater) waitForFinished(done chan error) {
c, _, err := websocket.DefaultDialer.Dial(s.urlStatus, nil)
if err != nil {
s.done <- fmt.Errorf("cannot connect to websocket: %w", err)
done <- fmt.Errorf("cannot connect to websocket: %w", err)
return
}
defer c.Close()

for {
_, message, err := c.ReadMessage()
if err != nil {
s.done <- fmt.Errorf("cannot read message from websocket: %w", err)
done <- fmt.Errorf("cannot read message from websocket: %w", err)
return
}

data := make(map[string]string)
err = json.Unmarshal(message, &data)
if err != nil {
continue
done <- fmt.Errorf("cannot unmarshal message: %w", err)
return
}

fmt.Println("Raw JSON: ", data)
if data["type"] != "message" {
continue
}

if data["text"] == "SWUPDATE successful" {
s.done <- nil
done <- nil
return
}
if data["text"] == "Installation failed" {
s.done <- errors.New("installation failed")
done <- errors.New("installation failed")
return
}
}
}

func (s *SWUpdater) Update(image io.Reader, timeout time.Duration) error {
go s.waitForFinished()
go func() {
err := s.upload(image, timeout)
if err != nil {
s.done <- err
}
}()
func (s *SWUpdater) Update(filename string, timeout time.Duration) error {
done := make(chan error)
go s.waitForFinished(done)
err := s.upload(filename, timeout)
if err != nil {
return fmt.Errorf("cannot upload software image: %w", err)
}

select {
case err := <-s.done:
case err := <-done:
if err != nil {
return err
return fmt.Errorf("update failed: %w", err)
}
return nil
case <-time.After(timeout):
Expand Down

0 comments on commit b86dc1d

Please sign in to comment.