diff --git a/bot/internal/joke/database_test.go b/bot/internal/joke/database_test.go index 66ddaeb..2e7d42a 100644 --- a/bot/internal/joke/database_test.go +++ b/bot/internal/joke/database_test.go @@ -3,9 +3,9 @@ package joke import ( "context" "github.com/wittano/komputer/db" + "github.com/wittano/komputer/test" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" - "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/integration/mtest" "testing" "time" @@ -25,19 +25,6 @@ var ( } ) -type testMongodbService struct { - client *mongo.Client - ctx context.Context -} - -func (t testMongodbService) Close() error { - return t.client.Disconnect(t.ctx) -} - -func (t testMongodbService) Client(_ context.Context) (*mongo.Client, error) { - return t.client, nil -} - func createMTest(t *testing.T) *mtest.T { return mtest.New(t, mtest.NewOptions(). ClientType(mtest.Mock). @@ -54,12 +41,7 @@ func TestJokeService_Add(t *testing.T) { ctx := context.Background() - mongodbService := testMongodbService{ - t.Client, - ctx, - } - - service := DatabaseJokeService{&mongodbService} + service := DatabaseJokeService{test.NewMockedMognodbService(ctx, t.Client)} if _, err := service.Add(ctx, testJoke); err != nil { mt.Fatal(err) @@ -71,12 +53,7 @@ func TestJokeService_AddButContextCancelled(t *testing.T) { ctx, cancel := context.WithDeadline(context.Background(), time.Now()) cancel() - mongodbService := testMongodbService{ - nil, - ctx, - } - - service := DatabaseJokeService{&mongodbService} + service := DatabaseJokeService{test.NewMockedMognodbService(ctx, nil)} if _, err := service.Add(ctx, testJoke); err == nil { t.Fatal("Context wasn't cancelled") @@ -87,12 +64,7 @@ func TestJokeService_SearchButContextWasCancelled(t *testing.T) { ctx, cancel := context.WithDeadline(context.Background(), time.Now()) cancel() - mongodbService := testMongodbService{ - nil, - ctx, - } - - service := DatabaseJokeService{&mongodbService} + service := DatabaseJokeService{test.NewMockedMognodbService(ctx, nil)} if _, err := service.Get(ctx, testJokeSearch); err == nil { t.Fatal("Context wasn't cancelled") @@ -107,12 +79,7 @@ func TestJokeService_SearchButNotingFound(t *testing.T) { ctx := context.Background() - mongodbService := testMongodbService{ - t.Client, - ctx, - } - - service := DatabaseJokeService{&mongodbService} + service := DatabaseJokeService{test.NewMockedMognodbService(ctx, t.Client)} if _, err := service.Get(ctx, testJokeSearch); err == nil { mt.Fatal("Something was found in database, but it shouldn't") @@ -137,12 +104,7 @@ func TestJokeService_SearchButFindRandomJoke(t *testing.T) { ctx := context.Background() - mongodbService := testMongodbService{ - t.Client, - ctx, - } - - service := DatabaseJokeService{&mongodbService} + service := DatabaseJokeService{test.NewMockedMognodbService(ctx, t.Client)} joke, err := service.Get(ctx, SearchParameters{}) if err != nil { @@ -173,12 +135,7 @@ func TestDatabaseJokeService_ActiveButContextCancelled(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() - mongodbService := testMongodbService{ - t.Client, - ctx, - } - - service := DatabaseJokeService{&mongodbService} + service := DatabaseJokeService{test.NewMockedMognodbService(ctx, t.Client)} if service.Active(ctx) { t.Fatal("service can still running and handle new requests") @@ -194,12 +151,7 @@ func TestDatabaseJokeService_Active(t *testing.T) { ctx := context.Background() - mongodbService := testMongodbService{ - t.Client, - ctx, - } - - service := DatabaseJokeService{&mongodbService} + service := DatabaseJokeService{test.NewMockedMognodbService(ctx, t.Client)} if !service.Active(ctx) { t.Fatal("service isn't responding") diff --git a/test/mongodb.go b/test/mongodb.go new file mode 100644 index 0000000..f90f625 --- /dev/null +++ b/test/mongodb.go @@ -0,0 +1,25 @@ +package test + +import ( + "context" + "github.com/wittano/komputer/db" + "go.mongodb.org/mongo-driver/mongo" +) + +// TestMongodbService This service is mock. It shouldn't use in production code +type TestMongodbService struct { + client *mongo.Client + ctx context.Context +} + +func (t TestMongodbService) Close() error { + return t.client.Disconnect(t.ctx) +} + +func (t TestMongodbService) Client(_ context.Context) (*mongo.Client, error) { + return t.client, nil +} + +func NewMockedMognodbService(ctx context.Context, client *mongo.Client) db.MongodbService { + return &TestMongodbService{client, ctx} +} diff --git a/test/multipart.go b/test/multipart.go new file mode 100644 index 0000000..f187ee6 --- /dev/null +++ b/test/multipart.go @@ -0,0 +1,67 @@ +package test + +import ( + "bytes" + "errors" + "io" + "mime/multipart" + "os" + "path/filepath" + "testing" +) + +func CreateTempAudioFiles(t *testing.T) (string, error) { + dir := t.TempDir() + f, err := os.CreateTemp(dir, "test.*.mp3") + if err != nil { + return "", err + } + defer f.Close() + + _, err = f.Write([]byte{0xff, 0xfb}) + if err != nil { + return "", err + } + + return f.Name(), nil +} + +func CreateMultipartFileHeader(path string) (*multipart.FileHeader, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + + var buf bytes.Buffer + + formWriter := multipart.NewWriter(&buf) + filename := filepath.Base(path) + formPart, err := formWriter.CreateFormFile(filename, filepath.Base(path)) + if err != nil { + return nil, err + } + + if _, err = io.Copy(formPart, f); err != nil { + return nil, err + } + + err = formWriter.Close() + if err != nil { + return nil, err + } + + reader := bytes.NewReader(buf.Bytes()) + formReader := multipart.NewReader(reader, formWriter.Boundary()) + + multipartForm, err := formReader.ReadForm(1 << 20) + if err != nil { + return nil, err + } + + if file, ok := multipartForm.File[filename]; !ok || len(file) <= 0 { + return nil, errors.New("failed create multipart audio") + } else { + return file[0], nil + } +} diff --git a/test/settings.go b/test/settings.go new file mode 100644 index 0000000..d045ba8 --- /dev/null +++ b/test/settings.go @@ -0,0 +1,18 @@ +package test + +import ( + "github.com/wittano/komputer/web/settings" + "path/filepath" + "testing" +) + +func LoadDefaultConfig(t *testing.T) error { + const defaultConfigFileName = "config.yml" + configFile := filepath.Join(t.TempDir(), defaultConfigFileName) + + if err := settings.Load(configFile); err != nil { + return err + } + + return settings.Config.Update(settings.Settings{AssetDir: filepath.Join(t.TempDir(), "assets")}) +} diff --git a/web/internal/audio/database.go b/web/internal/audio/database.go index 8fec394..762e931 100644 --- a/web/internal/audio/database.go +++ b/web/internal/audio/database.go @@ -9,8 +9,12 @@ import ( const audioCollectionName = "audio" -func saveFileDataInDatabase(ctx context.Context, filename string) error { - client, err := db.Mongodb(ctx).Client(ctx) +type DatabaseService struct { + Database db.MongodbService +} + +func (a DatabaseService) save(ctx context.Context, filename string) error { + client, err := a.Database.Client(ctx) if err != nil { return err } @@ -24,13 +28,13 @@ func saveFileDataInDatabase(ctx context.Context, filename string) error { return err } -func GetAudioInfo(ctx context.Context, id string) (result db.AudioInfo, err error) { +func (a DatabaseService) Get(ctx context.Context, id string) (result db.AudioInfo, err error) { hex, err := primitive.ObjectIDFromHex(id) if err != nil { return } - client, err := db.Mongodb(ctx).Client(ctx) + client, err := a.Database.Client(ctx) if err != nil { return db.AudioInfo{}, err } diff --git a/web/internal/audio/service.go b/web/internal/audio/service.go new file mode 100644 index 0000000..52a8987 --- /dev/null +++ b/web/internal/audio/service.go @@ -0,0 +1,129 @@ +package audio + +import ( + "context" + "errors" + "fmt" + "github.com/labstack/echo/v4" + "github.com/wittano/komputer/db" + "github.com/wittano/komputer/web/settings" + "io" + "mime/multipart" + "net/http" + "os" + "path/filepath" +) + +type UploadService struct { + Db db.MongodbService +} + +func (u UploadService) Upload(ctx context.Context, files []*multipart.FileHeader) error { + var ( + errCh = make(chan error) + successCh = make(chan struct{}) + ) + defer close(errCh) + defer close(successCh) + + filesCount := len(files) + for _, f := range files { + if err := validRequestedFile(*f); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("invalid '%s' audio", f.Filename)) + } + + go u.save(ctx, f, errCh, successCh) + } + + for { + select { + case <-ctx.Done(): + return context.Canceled + case err := <-errCh: + return err + case <-successCh: + filesCount -= 1 + break + } + + if filesCount <= 0 { + break + } + } + + return nil +} + +func (u UploadService) save(ctx context.Context, file *multipart.FileHeader, errCh chan<- error, successSig chan<- struct{}) { + select { + case <-ctx.Done(): + errCh <- context.Canceled + return + default: + } + + src, err := file.Open() + if err != nil { + errCh <- err + + return + } + defer src.Close() + + destPath := filepath.Join(settings.Config.AssetDir, file.Filename) + dest, err := os.Create(destPath) + if err != nil { + errCh <- err + + return + } + defer dest.Close() + + for { + select { + case <-ctx.Done(): + errCh <- context.Canceled + + return + default: + const bufSize = 1 << 20 // 1MB buffer size + + _, err := io.CopyN(dest, src, bufSize) + if errors.Is(err, io.EOF) { + audioService := DatabaseService{u.Db} + + err = audioService.save(ctx, dest.Name()) + if err != nil { + errCh <- err + } else { + successSig <- struct{}{} + } + + return + } else if err != nil { + errCh <- err + + os.Remove(destPath) + + return + } + } + } +} + +func validRequestedFile(file multipart.FileHeader) error { + if file.Size >= settings.Config.Upload.MaxFileSize { + return fmt.Errorf("audio '%s' is too big", file.Filename) + } + + if err := ValidMp3File(&file); err != nil { + return err + } + + destFile := filepath.Join(settings.Config.AssetDir, file.Filename) + if _, err := os.Stat(destFile); err == nil { + return os.ErrExist + } + + return nil +} diff --git a/web/internal/audio/service_test.go b/web/internal/audio/service_test.go new file mode 100644 index 0000000..ac3b9f9 --- /dev/null +++ b/web/internal/audio/service_test.go @@ -0,0 +1,46 @@ +package audio + +import ( + "context" + "github.com/wittano/komputer/db" + "github.com/wittano/komputer/test" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo/integration/mtest" + "mime/multipart" + "testing" + "time" +) + +func TestUploadRequestedFile(t *testing.T) { + if err := test.LoadDefaultConfig(t); err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + filePath, err := test.CreateTempAudioFiles(t) + if err != nil { + t.Fatal(err) + } + + multipartFileHeader, err := test.CreateMultipartFileHeader(filePath) + if err != nil { + t.Fatal(err) + } + + mt := mtest.New(t, mtest.NewOptions(). + ClientType(mtest.Mock). + CollectionName("audio"). + DatabaseName(db.DatabaseName)) + mt.Run("upload requested file", func(t *mtest.T) { + t.AddMockResponses(mtest.CreateSuccessResponse(bson.E{Key: "ok", Value: "1"}, + bson.E{Key: "_id", Value: primitive.NewObjectID()})) + + service := UploadService{Db: test.NewMockedMognodbService(ctx, t.Client)} + + if err := service.Upload(ctx, []*multipart.FileHeader{multipartFileHeader}); err != nil { + t.Fatal(err) + } + }) +} diff --git a/web/internal/audio/file.go b/web/internal/audio/validation.go similarity index 65% rename from web/internal/audio/file.go rename to web/internal/audio/validation.go index 89ee99a..46baf0b 100644 --- a/web/internal/audio/file.go +++ b/web/internal/audio/validation.go @@ -2,32 +2,11 @@ package audio import ( "bytes" - "context" "errors" - "io" "mime/multipart" - "os" "strings" ) -func UploadFile(ctx context.Context, src io.Reader, dest *os.File) error { - for { - select { - case <-ctx.Done(): - return context.Canceled - default: - const bufSize = 1 << 20 // 1MB buffer size - - _, err := io.CopyN(dest, src, bufSize) - if errors.Is(err, io.EOF) { - return saveFileDataInDatabase(ctx, dest.Name()) - } else if err != nil { - return err - } - } - } -} - func ValidMp3File(file *multipart.FileHeader) (err error) { if !strings.HasSuffix(file.Filename, "mp3") { return errors.New("invalid audio extension") diff --git a/web/internal/audio/validation_test.go b/web/internal/audio/validation_test.go new file mode 100644 index 0000000..fe252e3 --- /dev/null +++ b/web/internal/audio/validation_test.go @@ -0,0 +1,27 @@ +package audio + +import ( + "github.com/wittano/komputer/test" + "testing" +) + +func TestValidRequestedFile(t *testing.T) { + if err := test.LoadDefaultConfig(t); err != nil { + t.Fatal(err) + } + + path, err := test.CreateTempAudioFiles(t) + if err != nil { + t.Fatal(err) + } + + multipartFileHeader, err := test.CreateMultipartFileHeader(path) + if err != nil { + t.Fatal(err) + } + + err = validRequestedFile(*multipartFileHeader) + if err != nil { + t.Fatal(err) + } +} diff --git a/web/internal/handler/audio.go b/web/internal/handler/audio.go index df8a5a7..2bdfba7 100644 --- a/web/internal/handler/audio.go +++ b/web/internal/handler/audio.go @@ -1,20 +1,21 @@ package handler import ( - "context" - "fmt" "github.com/labstack/echo/v4" + "github.com/wittano/komputer/db" "github.com/wittano/komputer/web/internal/audio" - "github.com/wittano/komputer/web/internal/settings" + "github.com/wittano/komputer/web/settings" + "mime/multipart" "net/http" - "os" - "path/filepath" ) func GetAudio(c echo.Context) error { id := c.Param("id") - info, err := audio.GetAudioInfo(c.Request().Context(), id) + ctx := c.Request().Context() + service := audio.DatabaseService{Database: db.Mongodb(ctx)} + + info, err := service.Get(ctx, id) if err != nil { return err } @@ -28,100 +29,23 @@ func UploadNewAudio(c echo.Context) (err error) { return err } - filesCount := len(multipartForm.File) - if !settings.Config.CheckFileCountLimit(filesCount) { - return echo.NewHTTPError(http.StatusBadRequest, "invalid number of uploaded files") - } - - var ( - errCh = make(chan error) - successCh = make(chan struct{}) - ) - defer close(errCh) - defer close(successCh) + var files []*multipart.FileHeader - for k := range multipartForm.File { - if err = validRequestedFile(k, *c.Request()); err != nil { - c.Logger().Error(err) - return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("invalid '%s' audio", k)) - } - - go uploadRequestedFile(c.Request().Context(), k, c.Request(), errCh, successCh) + for _, v := range multipartForm.File { + files = append(files, v...) } - for { - select { - case <-c.Request().Context().Done(): - return context.Canceled - case err = <-errCh: - return err - case <-successCh: - filesCount -= 1 - break - } - - if filesCount <= 0 { - c.Response().WriteHeader(http.StatusCreated) - - break - } - } - - return nil -} - -func validRequestedFile(filename string, req http.Request) error { - _, fileHeader, err := req.FormFile(filename) - if err != nil { - return err + filesCount := len(files) + if !settings.Config.CheckFileCountLimit(filesCount) { + return echo.NewHTTPError(http.StatusBadRequest, "invalid number of uploaded files") } - if fileHeader.Size >= settings.Config.Upload.MaxFileSize { - return fmt.Errorf("audio '%s' is too big", filename) - } + ctx := c.Request().Context() + service := audio.UploadService{Db: db.Mongodb(ctx)} - if err = audio.ValidMp3File(fileHeader); err != nil { + if err := service.Upload(ctx, files); err != nil { return err } - destFile := filepath.Join(settings.Config.AssetDir, filename) - if _, err = os.Stat(destFile); err == nil { - return os.ErrExist - } - - return nil -} - -func uploadRequestedFile(ctx context.Context, filename string, req *http.Request, errCh chan<- error, successSig chan<- struct{}) { - select { - case <-ctx.Done(): - errCh <- context.Canceled - return - default: - } - - f, _, err := req.FormFile(filename) - if err != nil { - errCh <- err - - return - } - defer f.Close() - - dest, err := os.Create(filepath.Join(settings.Config.AssetDir, filename)) - if err != nil { - errCh <- err - - return - } - defer dest.Close() - - if err = audio.UploadFile(ctx, f, dest); err != nil { - errCh <- err - os.Remove(dest.Name()) - - return - } - - successSig <- struct{}{} + return c.String(http.StatusCreated, "") } diff --git a/web/internal/handler/audio_test.go b/web/internal/handler/audio_test.go deleted file mode 100644 index 4ef03f2..0000000 --- a/web/internal/handler/audio_test.go +++ /dev/null @@ -1,159 +0,0 @@ -package handler - -import ( - "bytes" - "context" - "errors" - "github.com/wittano/komputer/web/internal/settings" - "io" - "mime/multipart" - "net/http" - "os" - "path/filepath" - "testing" - "time" -) - -const testFileName = "test.mp3" - -func createTempAudioFiles(t *testing.T) (string, error) { - dir := t.TempDir() - f, err := os.CreateTemp(dir, "test.*.mp3") - if err != nil { - return "", err - } - defer f.Close() - - _, err = f.Write([]byte{0xff, 0xfb}) - if err != nil { - return "", err - } - - return f.Name(), nil -} - -func createMultipartFileHeader(filename string) (*multipart.FileHeader, error) { - f, err := os.Open(filename) - if err != nil { - return nil, err - } - defer f.Close() - - var buf bytes.Buffer - - formWriter := multipart.NewWriter(&buf) - formPart, err := formWriter.CreateFormFile(testFileName, filepath.Base(filename)) - if err != nil { - return nil, err - } - - if _, err = io.Copy(formPart, f); err != nil { - return nil, err - } - - err = formWriter.Close() - if err != nil { - return nil, err - } - - reader := bytes.NewReader(buf.Bytes()) - formReader := multipart.NewReader(reader, formWriter.Boundary()) - - multipartForm, err := formReader.ReadForm(1 << 20) - if err != nil { - return nil, err - } - - if file, ok := multipartForm.File[testFileName]; !ok || len(file) <= 0 { - return nil, errors.New("failed create multipart audio") - } else { - return file[0], nil - } -} - -func loadDefaultConfig(t *testing.T) error { - configFile := filepath.Join(t.TempDir(), "config.yml") - - if err := settings.Load(configFile); err != nil { - return err - } - - return settings.Config.Update(settings.Settings{AssetDir: filepath.Join(t.TempDir(), "assets")}) -} - -func TestValidRequestedFile(t *testing.T) { - if err := loadDefaultConfig(t); err != nil { - t.Fatal(err) - } - - filePath, err := createTempAudioFiles(t) - if err != nil { - t.Fatal(err) - } - - multipartFileHeader, err := createMultipartFileHeader(filePath) - if err != nil { - t.Fatal(err) - } - - req := http.Request{ - MultipartForm: &multipart.Form{ - File: map[string][]*multipart.FileHeader{ - testFileName: { - multipartFileHeader, - }, - }, - }, - } - - err = validRequestedFile(testFileName, req) - if err != nil { - t.Fatal(err) - } -} - -func TestUploadRequestedFile(t *testing.T) { - if err := loadDefaultConfig(t); err != nil { - t.Fatal(err) - } - - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - filePath, err := createTempAudioFiles(t) - if err != nil { - t.Fatal(err) - } - - multipartFileHeader, err := createMultipartFileHeader(filePath) - if err != nil { - t.Fatal(err) - } - - req := &http.Request{ - MultipartForm: &multipart.Form{ - File: map[string][]*multipart.FileHeader{ - testFileName: { - multipartFileHeader, - }, - }, - }, - } - - successCh := make(chan struct{}) - errCh := make(chan error) - defer close(successCh) - defer close(errCh) - - go uploadRequestedFile(ctx, testFileName, req, errCh, successCh) - - for { - select { - case <-ctx.Done(): - t.Fatal(context.Canceled) - case err = <-errCh: - t.Fatal(err) - case <-successCh: - return - } - } -} diff --git a/web/internal/handler/settings.go b/web/internal/handler/settings.go index 2cfcdb0..b5cf093 100644 --- a/web/internal/handler/settings.go +++ b/web/internal/handler/settings.go @@ -2,15 +2,14 @@ package handler import ( "github.com/labstack/echo/v4" - "github.com/wittano/komputer/web/internal/settings" + "github.com/wittano/komputer/web/settings" "net/http" ) func UpdateSettings(c echo.Context) error { var newSetting settings.Settings - err := c.Bind(&newSetting) - if err != nil { + if err := c.Bind(&newSetting); err != nil { return err } diff --git a/web/server.go b/web/server.go index 62223f5..0791982 100644 --- a/web/server.go +++ b/web/server.go @@ -4,7 +4,7 @@ import ( "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" "github.com/wittano/komputer/web/internal/handler" - "github.com/wittano/komputer/web/internal/settings" + "github.com/wittano/komputer/web/settings" ) func NewWebConsoleServer(configPath string) (*echo.Echo, error) { diff --git a/web/internal/settings/file.go b/web/settings/file.go similarity index 100% rename from web/internal/settings/file.go rename to web/settings/file.go diff --git a/web/internal/settings/file_test.go b/web/settings/file_test.go similarity index 100% rename from web/internal/settings/file_test.go rename to web/settings/file_test.go diff --git a/web/internal/settings/types.go b/web/settings/types.go similarity index 100% rename from web/internal/settings/types.go rename to web/settings/types.go diff --git a/web/internal/settings/types_test.go b/web/settings/types_test.go similarity index 100% rename from web/internal/settings/types_test.go rename to web/settings/types_test.go