diff --git a/api/handler/audio.go b/api/handler/audio.go index a82f9d9..27795c7 100644 --- a/api/handler/audio.go +++ b/api/handler/audio.go @@ -2,60 +2,61 @@ package handler import ( "context" + "fmt" + "github.com/wittano/komputer/pkgs/settings" "github.com/wittano/komputer/pkgs/voice" "net/http" - "time" + "os" + "path/filepath" ) -// TODO export property to config file/environment variable -const maxFileSize = 8 * 1024 * 1024 // 8MB in bytes +const oneMegaByte = 1 << 20 // 8MB in bytes func UploadNewAudio(res http.ResponseWriter, req *http.Request) (err error) { - err = req.ParseMultipartForm(maxFileSize) + err = req.ParseMultipartForm(settings.Config.Upload.MaxFileSize * oneMegaByte) if err != nil { return newInternalApiError(err) } - // TODO export to property how much user can upload files - if counts := len(req.MultipartForm.Value); counts < 1 || counts > 5 { + filesCount := len(req.MultipartForm.File) + if !settings.Config.CheckFileCountLimit(filesCount) { return apiError{ Status: http.StatusBadRequest, Msg: "illegal uploaded files count", } } - // TODO export uploading timeout to external properties - const uploadingTimeout = time.Second * 2 - ctx, cancel := context.WithTimeout(req.Context(), uploadingTimeout) - defer cancel() - - filesCount := len(req.MultipartForm.File) var ( - errCh = make(chan error) - successSigCh = make(chan struct{}, filesCount) + errCh = make(chan error) + successCh = make(chan struct{}, filesCount) ) defer close(errCh) - defer close(successSigCh) + defer close(successCh) for k := range req.MultipartForm.File { - go uploadFile(ctx, *req, k, errCh, successSigCh) + err = validRequestedFile(k, *req) + if err != nil { + return apiError{ + Status: http.StatusBadRequest, + Msg: fmt.Sprintf("invalid '%s' file", k), + Err: err, + } + } + + go uploadRequestedFile(req.Context(), k, req, errCh, successCh) } - var ( - resError error - successCounter = filesCount - ) + successCounter := filesCount for { select { - case <-ctx.Done(): - resError = context.Canceled - break + case <-req.Context().Done(): + return context.Canceled case err = <-errCh: - resError = err - break - case <-successSigCh: + return err + case <-successCh: successCounter -= 1 + break } if successCounter <= 0 { @@ -65,24 +66,54 @@ func UploadNewAudio(res http.ResponseWriter, req *http.Request) (err error) { } } - return resError + return nil } -func uploadFile(ctx context.Context, req http.Request, name string, errCh chan<- error, successSig chan<- struct{}) { - file, fileHeader, err := req.FormFile(name) +func validRequestedFile(filename string, req http.Request) error { + _, fileHeader, err := req.FormFile(filename) + if err != nil { + return err + } + + if err = voice.ValidMp3File(fileHeader); err != nil { + return err + } + + destFile := filepath.Join(settings.Config.AssetDir, filename) + if _, err = os.Stat(destFile); err == nil { + return os.ErrExist + } + + return nil +} + +func uploadRequestedFile(ctx context.Context, filename string, req *http.Request, errCh chan<- error, successSig chan<- struct{}) { + select { + case <-ctx.Done(): + errCh <- context.Canceled + return + default: + } + + f, _, err := req.FormFile(filename) if err != nil { errCh <- newInternalApiError(err) + return } - defer file.Close() + defer f.Close() - if err := voice.ValidMp3File(fileHeader); err != nil { + dest, err := os.Create(filepath.Join(settings.Config.AssetDir, filename)) + if err != nil { errCh <- newInternalApiError(err) + return } + defer dest.Close() - if err = voice.UploadFile(ctx, fileHeader.Filename, file); err != nil { + if err = voice.UploadFile(ctx, f, dest); err != nil { errCh <- newInternalApiError(err) + os.Remove(dest.Name()) return } diff --git a/api/handler/audio_test.go b/api/handler/audio_test.go new file mode 100644 index 0000000..a735df2 --- /dev/null +++ b/api/handler/audio_test.go @@ -0,0 +1,159 @@ +package handler + +import ( + "bytes" + "context" + "errors" + "github.com/wittano/komputer/pkgs/settings" + "io" + "mime/multipart" + "net/http" + "os" + "path/filepath" + "testing" + "time" +) + +const testFileName = "test.mp3" + +func createTempAudioFiles(t *testing.T) (string, error) { + dir := t.TempDir() + f, err := os.CreateTemp(dir, "test.*.mp3") + if err != nil { + return "", err + } + defer f.Close() + + _, err = f.Write([]byte{0xff, 0xfb}) + if err != nil { + return "", err + } + + return f.Name(), nil +} + +func createMultipartFileHeader(filename string) (*multipart.FileHeader, error) { + f, err := os.Open(filename) + if err != nil { + return nil, err + } + defer f.Close() + + var buf bytes.Buffer + + formWriter := multipart.NewWriter(&buf) + formPart, err := formWriter.CreateFormFile(testFileName, filepath.Base(filename)) + if err != nil { + return nil, err + } + + if _, err = io.Copy(formPart, f); err != nil { + return nil, err + } + + err = formWriter.Close() + if err != nil { + return nil, err + } + + reader := bytes.NewReader(buf.Bytes()) + formReader := multipart.NewReader(reader, formWriter.Boundary()) + + multipartForm, err := formReader.ReadForm(1 << 20) + if err != nil { + return nil, err + } + + if file, ok := multipartForm.File[testFileName]; !ok || len(file) <= 0 { + return nil, errors.New("failed create multipart file") + } else { + return file[0], nil + } +} + +func loadDefaultConfig(t *testing.T) error { + configFile := filepath.Join(t.TempDir(), "config.yml") + + if err := settings.Load(configFile); err != nil { + return err + } + + return settings.Config.Update(settings.Settings{AssetDir: filepath.Join(t.TempDir(), "assets")}) +} + +func TestValidRequestedFile(t *testing.T) { + if err := loadDefaultConfig(t); err != nil { + t.Fatal(err) + } + + filePath, err := createTempAudioFiles(t) + if err != nil { + t.Fatal(err) + } + + multipartFileHeader, err := createMultipartFileHeader(filePath) + if err != nil { + t.Fatal(err) + } + + req := http.Request{ + MultipartForm: &multipart.Form{ + File: map[string][]*multipart.FileHeader{ + testFileName: { + multipartFileHeader, + }, + }, + }, + } + + err = validRequestedFile(testFileName, req) + if err != nil { + t.Fatal(err) + } +} + +func TestUploadRequestedFile(t *testing.T) { + if err := loadDefaultConfig(t); err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + filePath, err := createTempAudioFiles(t) + if err != nil { + t.Fatal(err) + } + + multipartFileHeader, err := createMultipartFileHeader(filePath) + if err != nil { + t.Fatal(err) + } + + req := &http.Request{ + MultipartForm: &multipart.Form{ + File: map[string][]*multipart.FileHeader{ + testFileName: { + multipartFileHeader, + }, + }, + }, + } + + successCh := make(chan struct{}) + errCh := make(chan error) + defer close(successCh) + defer close(errCh) + + go uploadRequestedFile(ctx, testFileName, req, errCh, successCh) + + for { + select { + case <-ctx.Done(): + t.Fatal(context.Canceled) + case err = <-errCh: + t.Fatal(err) + case <-successCh: + return + } + } +} diff --git a/pkgs/settings/types.go b/pkgs/settings/types.go index 748f037..58af624 100644 --- a/pkgs/settings/types.go +++ b/pkgs/settings/types.go @@ -3,10 +3,10 @@ package settings import ( "errors" "github.com/mitchellh/go-homedir" + "github.com/wittano/komputer/internal/assets" "gopkg.in/yaml.v3" "os" "path/filepath" - "sync" ) const ( @@ -14,11 +14,11 @@ const ( DefaultSettingsPath = ".config/komputer/settings.yml" ) -const maxFileSize = 8 * (1 << 20) // 8MB in bytes +const defaultMaxFileSize = 8 * (1 << 20) // 8MB in bytes type UploadSettings struct { - MaxFileCount uint `yaml:"max_file_count" json:"max_file_count"` - MaxFileSize uint `yaml:"max_file_size" json:"max_file_size"` + MaxFileCount int64 `yaml:"max_file_count" json:"max_file_count"` + MaxFileSize int64 `yaml:"max_file_size" json:"max_file_size"` } type Settings struct { @@ -32,7 +32,7 @@ func (s *Settings) Update(new Settings) error { return err } - err := moveAssets(s.AssetDir, new.AssetDir) + err := assets.Move(s.AssetDir, new.AssetDir) if err != nil { return err } @@ -51,6 +51,10 @@ func (s *Settings) Update(new Settings) error { return nil } +func (s Settings) CheckFileCountLimit(count int) bool { + return count >= 1 && int64(count) <= s.Upload.MaxFileCount +} + var Config *Settings func Load(path string) error { @@ -98,7 +102,7 @@ func defaultSettings(path string) (*Settings, error) { AssetDir: DefaultAssertDir, Upload: UploadSettings{ MaxFileCount: 5, - MaxFileSize: maxFileSize, + MaxFileSize: defaultMaxFileSize, }, } @@ -110,33 +114,3 @@ func defaultSettings(path string) (*Settings, error) { return &defaultSettings, nil } - -func moveAssets(oldSrc string, path string) (err error) { - dirs, err := os.ReadDir(oldSrc) - if err != nil { - return err - } - - var wg sync.WaitGroup - wg.Add(len(dirs)) - - for _, dir := range dirs { - go func(wg *sync.WaitGroup, oldSrc string, file os.DirEntry) { - defer wg.Done() - - if err != nil { - return - } - - filename := filepath.Join(oldSrc, file.Name()) - newPath := filepath.Join(path, filepath.Base(filename)) - if err = os.Rename(filename, newPath); err != nil { - return - } - }(&wg, oldSrc, dir) - } - - wg.Wait() - - return -} diff --git a/pkgs/voice/file.go b/pkgs/voice/file.go index 0dc7137..aeb785a 100644 --- a/pkgs/voice/file.go +++ b/pkgs/voice/file.go @@ -1,48 +1,66 @@ package voice import ( + "bytes" "context" "errors" "io" "mime/multipart" "os" - "path/filepath" + "strings" ) -// TODO added external property for uploading audio path -const uploadDir = "" - -func UploadFile(ctx context.Context, filename string, file multipart.File) (err error) { - path := filepath.Join(uploadDir, filename) - if _, err := os.Stat(path); err == nil { - return os.ErrExist - } - - destFile, err := os.Create(path) - if err != nil { - return - } - defer destFile.Close() - +func UploadFile(ctx context.Context, src io.Reader, dest *os.File) error { for { select { case <-ctx.Done(): - destFile.Close() - - if err = os.Remove(path); err == nil { - err = context.Canceled - } - - return + return context.Canceled default: - const bufSize = 1024 * 1024 + const bufSize = 1 << 20 // 1MB buffer size - _, err = io.CopyN(destFile, file, bufSize) + _, err := io.CopyN(dest, src, bufSize) if errors.Is(err, io.EOF) { return nil - } else { + } else if err != nil { return err } } } } + +func ValidMp3File(file *multipart.FileHeader) (err error) { + if !strings.HasSuffix(file.Filename, "mp3") { + return errors.New("invalid file extension") + } + + f, err := file.Open() + if err != nil { + return + } + defer f.Close() + + if err = checkAudioFileBinary(f); err != nil { + return + } + return nil +} + +func checkAudioFileBinary(f multipart.File) (err error) { + const headerBytesSize = 2 + err = errors.New("invalid file") + + buf := make([]byte, headerBytesSize) + n, err := f.Read(buf) + if err != nil { + return + } else if n != headerBytesSize { + return + } + + mp3MagicNumbersHeader := []byte{0xff, 0xfb} + if len(buf) != headerBytesSize && bytes.Equal(buf, mp3MagicNumbersHeader) { + return + } + + return nil +} diff --git a/pkgs/voice/validation.go b/pkgs/voice/validation.go deleted file mode 100644 index a8e3344..0000000 --- a/pkgs/voice/validation.go +++ /dev/null @@ -1,45 +0,0 @@ -package voice - -import ( - "bytes" - "errors" - "mime/multipart" - "strings" -) - -func ValidMp3File(file *multipart.FileHeader) (err error) { - if !strings.HasSuffix("mp3", file.Filename) { - return errors.New("invalid file extension") - } - - f, err := file.Open() - if err != nil { - return - } - defer f.Close() - - if err = checkAudioFileBinary(f); err != nil { - return - } - return nil -} - -func checkAudioFileBinary(f multipart.File) (err error) { - const headerBytesSize = 2 - err = errors.New("invalid file") - - buf := make([]byte, headerBytesSize) - n, err := f.Read(buf) - if err != nil { - return - } else if n != headerBytesSize { - return - } - - mp3MagicNumbersHeader := []byte{0xff, 0xfb} - if len(buf) != headerBytesSize && bytes.Equal(buf, mp3MagicNumbersHeader) { - return - } - - return nil -}