diff --git a/cmd/ovp8xx/cmd/swupdate.go b/cmd/ovp8xx/cmd/swupdate.go index 04fbac8..80a50ce 100644 --- a/cmd/ovp8xx/cmd/swupdate.go +++ b/cmd/ovp8xx/cmd/swupdate.go @@ -5,6 +5,7 @@ package cmd import ( "fmt" + "path/filepath" "time" "github.com/graugans/go-ovp8xx/pkg/swupdater" @@ -33,7 +34,12 @@ func swupdateCommand(cmd *cobra.Command, args []string) error { return fmt.Errorf("cannot get timeout: %w", err) } - fmt.Printf("Updating firmware on %s:%d with file %s (%v)\n", host, port, filename, timeout) + fmt.Printf("Updating firmware on %s:%d with file %s (%v)\n", + host, + port, + filepath.Base(filename), + timeout, + ) swu := swupdater.NewSWUpdater(host, port) diff --git a/go.mod b/go.mod index 5164633..2e7859e 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/pflag v1.0.5 // indirect + github.com/technoweenie/multipartstreamer v1.0.1 // indirect golang.org/x/net v0.17.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index bad1fcc..491a028 100644 --- a/go.sum +++ b/go.sum @@ -16,6 +16,8 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/technoweenie/multipartstreamer v1.0.1 h1:XRztA5MXiR1TIRHxH2uNxXxaIkKQDeX7m2XsSOlQEnM= +github.com/technoweenie/multipartstreamer v1.0.1/go.mod h1:jNVxdtShOxzAsukZwTSw6MDx5eUJoiEBsSvzDU9uzog= golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= diff --git a/pkg/swupdater/swupdater.go b/pkg/swupdater/swupdater.go index 0c854c9..696ebd1 100644 --- a/pkg/swupdater/swupdater.go +++ b/pkg/swupdater/swupdater.go @@ -1,28 +1,27 @@ package swupdater import ( - "bytes" "encoding/json" "errors" "fmt" - "io" - "mime/multipart" "net/http" "os" - "path/filepath" - "strconv" + "strings" "time" "github.com/gorilla/websocket" + "github.com/technoweenie/multipartstreamer" ) +// SWUpdater represents a software updater. type SWUpdater struct { - hostName string - port uint16 - urlUpload string - urlStatus string + hostName string // The hostname of the updater. + port uint16 // The port number of the updater. + urlUpload string // The URL for uploading software updates. + urlStatus string // The URL for checking the status of software updates. } +// NewSWUpdater creates a new instance of SWUpdater with the specified host name and port. func NewSWUpdater(hostName string, port uint16) *SWUpdater { return &SWUpdater{ hostName: hostName, @@ -31,53 +30,61 @@ func NewSWUpdater(hostName string, port uint16) *SWUpdater { urlStatus: fmt.Sprintf("ws://%s:%d/ws", hostName, port), } } -func (s *SWUpdater) upload(filename string, timeout time.Duration) error { - image, err := os.Open(filename) - if err != nil { - return fmt.Errorf("cannot open file: %w", err) - } - fmt.Printf("Uploading software image to %s\n", s.urlUpload) - body := &bytes.Buffer{} - writer := multipart.NewWriter(body) - - part, err := writer.CreateFormFile("file", filepath.Base(filename)) - if err != nil { - return fmt.Errorf("cannot create form file: %w", err) - } +// Upload performs the upload of the specified file. +// The filename parameter specifies the name of the file to be uploaded. +// Returns an error if the upload fails. +func (s *SWUpdater) upload(filename string) error { + fmt.Printf("Uploading software image to %s\n", s.urlUpload) + const fieldname string = "file" - _, err = io.Copy(part, image) + file, err := os.Open(filename) if err != nil { - return fmt.Errorf("cannot write to form file: %w", err) + return fmt.Errorf("cannot open file: %w", err) } + defer file.Close() - err = writer.Close() + fileInfo, err := file.Stat() if err != nil { - return fmt.Errorf("cannot close multipart writer: %w", err) + return fmt.Errorf("cannot get file info: %w", err) } - req, err := http.NewRequest("POST", s.urlUpload, bytes.NewReader(body.Bytes())) - if err != nil { - return fmt.Errorf("cannot create request: %w", err) - } + ms := multipartstreamer.New() + ms.WriteReader(fieldname, filename, fileInfo.Size(), file) - req.Header.Set("Content-Type", writer.FormDataContentType()) - req.Header.Set("Content-Length", strconv.Itoa(body.Len())) + req, _ := http.NewRequest("POST", s.urlUpload, nil) + ms.SetupRequest(req) - client := &http.Client{Timeout: timeout} - resp, err := client.Do(req) + resp, err := http.DefaultClient.Do(req) if err != nil { - return fmt.Errorf("cannot upload software image: %w", err) + return fmt.Errorf("cannot send request: %w", err) } defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("cannot upload software image: status code %d", resp.StatusCode) - } - - return nil + return err } +// waitForFinished waits for the SWUpdater process to finish by listening to a WebSocket connection. +// It continuously reads messages from the WebSocket and checks for specific conditions to determine +// if the SWUpdater process has completed successfully or has failed. +// +// Parameters: +// - done: A channel used to signal the completion of the SWUpdater process. If the process finishes +// successfully, nil is sent to the channel. If the process fails, an error is sent to the channel. +// +// Returns: +// +// None +// +// Example usage: +// +// done := make(chan error) +// go s.waitForFinished(done) +// err := <-done +// if err != nil { +// // Handle error +// } else { +// // SWUpdater process completed successfully +// } func (s *SWUpdater) waitForFinished(done chan error) { c, _, err := websocket.DefaultDialer.Dial(s.urlStatus, nil) if err != nil { @@ -104,21 +111,24 @@ func (s *SWUpdater) waitForFinished(done chan error) { continue } - if data["text"] == "SWUPDATE successful" { + if strings.Contains(data["text"], "SWUPDATE successful") { done <- nil return } - if data["text"] == "Installation failed" { + if strings.Contains(data["text"], "Installation failed") { done <- errors.New("installation failed") return } } } +// Update uploads a software image and waits for the update process to finish. +// It takes a filename string and a timeout duration as parameters. +// It returns an error if the upload fails, or if the operation times out. func (s *SWUpdater) Update(filename string, timeout time.Duration) error { done := make(chan error) go s.waitForFinished(done) - err := s.upload(filename, timeout) + err := s.upload(filename) if err != nil { return fmt.Errorf("cannot upload software image: %w", err) }