diff --git a/config/config.yaml b/config/config.yaml index 30f52f4..05f8aa5 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -31,11 +31,15 @@ server: # Public-facing URL (used for generating links) base_url: http://localhost:3000 - # Maximum upload size in bytes (default: 5MB) - max_upload_size: 100MB - api_upload_size: 50MB - default_upload_size: 10MB - + # Maximum upload size in bytes (hard server limit) + max_upload_size: 104857600 # 100 MB + + # Maximum upload size in bytes (normal users) + default_upload_size: 10485760 # 10 MB + + # Maximim upload size in bytes (users with API key) + api_upload_size: 52428800 # 50 MB + # Preforking prefork: false diff --git a/internal/config/config.go b/internal/config/config.go index 18a9711..6295006 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -5,7 +5,6 @@ import ( "time" "github.com/spf13/viper" - "github.com/watzon/0x45/internal/utils/bytesize" ) type StorageConfig struct { @@ -56,19 +55,19 @@ type RateLimitConfig struct { } type ServerConfig struct { - Address string `mapstructure:"address"` - BaseURL string `mapstructure:"base_url"` - MaxUploadSize bytesize.ByteSize `mapstructure:"max_upload_size"` - DefaultUploadSize bytesize.ByteSize `mapstructure:"default_upload_size"` - APIUploadSize bytesize.ByteSize `mapstructure:"api_upload_size"` - Prefork bool `mapstructure:"prefork"` - ServerHeader string `mapstructure:"server_header"` - AppName string `mapstructure:"app_name"` - Cleanup CleanupConfig `mapstructure:"cleanup"` - RateLimit RateLimitConfig `mapstructure:"rate_limit"` - CORSOrigins []string `mapstructure:"cors_origins"` - ViewsDirectory string `mapstructure:"views_directory"` - PublicDirectory string `mapstructure:"public_directory"` + Address string `mapstructure:"address"` + BaseURL string `mapstructure:"base_url"` + MaxUploadSize int `mapstructure:"max_upload_size"` + DefaultUploadSize int `mapstructure:"default_upload_size"` + APIUploadSize int `mapstructure:"api_upload_size"` + Prefork bool `mapstructure:"prefork"` + ServerHeader string `mapstructure:"server_header"` + AppName string `mapstructure:"app_name"` + Cleanup CleanupConfig `mapstructure:"cleanup"` + RateLimit RateLimitConfig `mapstructure:"rate_limit"` + CORSOrigins []string `mapstructure:"cors_origins"` + ViewsDirectory string `mapstructure:"views_directory"` + PublicDirectory string `mapstructure:"public_directory"` } type SMTPConfig struct { @@ -128,6 +127,8 @@ func Load() (*Config, error) { _ = viper.BindEnv("server.address", "0X_SERVER_ADDRESS") _ = viper.BindEnv("server.base_url", "0X_SERVER_BASE_URL") _ = viper.BindEnv("server.max_upload_size", "0X_SERVER_MAX_UPLOAD_SIZE") + _ = viper.BindEnv("server.default_upload_size", "0X_SERVER_DEFAULT_UPLOAD_SIZE") + _ = viper.BindEnv("server.api_upload_size", "0X_SERVER_API_UPLOAD_SIZE") _ = viper.BindEnv("server.prefork", "0X_SERVER_PREFORK") _ = viper.BindEnv("server.server_header", "0X_SERVER_SERVER_HEADER") _ = viper.BindEnv("server.app_name", "0X_SERVER_APP_NAME") diff --git a/internal/models/api_key.go b/internal/models/api_key.go index c443230..e470972 100644 --- a/internal/models/api_key.go +++ b/internal/models/api_key.go @@ -14,13 +14,13 @@ type APIKey struct { DeletedAt gorm.DeletedAt `gorm:"index"` // Paste-related limits and permissions - MaxFileSize int64 `gorm:"default:10485760"` // 10MB default - RateLimit int `gorm:"default:100"` // Requests per hour + MaxFileSize int64 // 10MB default + RateLimit int // Requests per hour AllowPrivate bool `gorm:"default:true"` - AllowUpdates bool `gorm:"default:false"` + AllowUpdates bool `gorm:"default:true"` // URL shortening permissions - AllowShortlinks bool `gorm:"default:false"` // Whether this key can create shortlinks + AllowShortlinks bool `gorm:"default:true"` // Whether this key can create shortlinks ShortlinkQuota int `gorm:"default:0"` // 0 = unlimited ShortlinkPrefix string `gorm:"type:varchar(16)"` // Optional custom prefix for shortened URLs @@ -51,17 +51,6 @@ func (k *APIKey) BeforeCreate(tx *gorm.DB) error { k.Key = GenerateAPIKey() } - // Set defaults if not specified - if k.MaxFileSize == 0 { - k.MaxFileSize = 10485760 // 10MB - } - if k.RateLimit == 0 { - k.RateLimit = 100 - } - if !k.AllowPrivate && !k.AllowUpdates && !k.AllowShortlinks { - k.AllowPrivate = true // Default to allowing private pastes - } - return nil } diff --git a/internal/server/middleware/middleware.go b/internal/server/middleware/middleware.go index 36e16ef..583122d 100644 --- a/internal/server/middleware/middleware.go +++ b/internal/server/middleware/middleware.go @@ -110,7 +110,7 @@ func (m *Middleware) ETag() fiber.Handler { func (m *Middleware) GetMiddleware() []fiber.Handler { return []fiber.Handler{ m.RequestID(), - m.Logger(), + // m.Logger(), m.Recover(), m.CORS(), m.Compression(), diff --git a/internal/server/server.go b/internal/server/server.go index b9a76fb..2056cc6 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -15,7 +15,6 @@ import ( "github.com/watzon/hdur" "go.uber.org/zap" "gorm.io/gorm" - "moul.io/zapgorm2" ) type Server struct { @@ -30,8 +29,8 @@ type Server struct { } func New(config *config.Config, logger *zap.Logger) *Server { - gormLogger := zapgorm2.New(logger) - gormLogger.SetAsDefault() + // gormLogger := zapgorm2.New(logger) + // gormLogger.SetAsDefault() // Custom parsers for fiber fiber.SetParserDecoder(fiber.ParserConfig{ @@ -47,7 +46,7 @@ func New(config *config.Config, logger *zap.Logger) *Server { // Initialize database db, err := database.New(config, &gorm.Config{ - Logger: gormLogger, + // Logger: gormLogger, }) if err != nil { logger.Fatal("Error connecting to database", zap.Error(err)) @@ -79,7 +78,7 @@ func New(config *config.Config, logger *zap.Logger) *Server { // Initialize Fiber app app := fiber.New(fiber.Config{ ErrorHandler: errorHandler, - BodyLimit: int(config.Server.MaxUploadSize.Int64()), + BodyLimit: int(config.Server.MaxUploadSize), Views: engine, Prefork: config.Server.Prefork, ServerHeader: config.Server.ServerHeader, diff --git a/internal/server/services/apikey.go b/internal/server/services/apikey.go index e0dab80..6058b15 100644 --- a/internal/server/services/apikey.go +++ b/internal/server/services/apikey.go @@ -69,6 +69,8 @@ func (s *APIKeyService) RequestKey(c *fiber.Ctx) error { apiKey.Name = req.Name apiKey.VerifyToken = token apiKey.VerifyExpiry = time.Now().Add(24 * time.Hour) + apiKey.MaxFileSize = int64(s.config.Server.APIUploadSize) + apiKey.RateLimit = int(s.config.Server.RateLimit.Global.Rate) if err := s.db.Create(apiKey).Error; err != nil { s.logger.Error("failed to create API key", zap.Error(err)) diff --git a/internal/server/services/paste.go b/internal/server/services/paste.go index 827090a..536af0f 100644 --- a/internal/server/services/paste.go +++ b/internal/server/services/paste.go @@ -397,18 +397,18 @@ func (s *PasteService) CleanupExpired() (int64, error) { // validateFileSize checks if the file size is within the allowed limits func (s *PasteService) validateFileSize(size int64, apiKey *models.APIKey) error { // First check against absolute maximum size for security - if size > s.config.Server.MaxUploadSize.Int64() { - return fiber.NewError(fiber.StatusBadRequest, fmt.Sprintf("File exceeds maximum allowed size of %s", s.config.Server.MaxUploadSize)) + if size > int64(s.config.Server.MaxUploadSize) { + return fiber.NewError(fiber.StatusBadRequest, fmt.Sprintf("File exceeds maximum allowed size of %d bytes", s.config.Server.MaxUploadSize)) } // Then check against the appropriate tier limit if apiKey != nil { - if size > s.config.Server.APIUploadSize.Int64() { - return fiber.NewError(fiber.StatusBadRequest, fmt.Sprintf("File exceeds API upload limit of %s", s.config.Server.APIUploadSize)) + if size > int64(s.config.Server.APIUploadSize) { + return fiber.NewError(fiber.StatusBadRequest, fmt.Sprintf("File exceeds API upload limit of %d bytes", s.config.Server.APIUploadSize)) } } else { - if size > s.config.Server.DefaultUploadSize.Int64() { - return fiber.NewError(fiber.StatusBadRequest, fmt.Sprintf("File exceeds default upload limit of %s", s.config.Server.DefaultUploadSize)) + if size > int64(s.config.Server.DefaultUploadSize) { + return fiber.NewError(fiber.StatusBadRequest, fmt.Sprintf("File exceeds default upload limit of %d bytes", s.config.Server.DefaultUploadSize)) } } diff --git a/internal/server/tests/paste_test.go b/internal/server/tests/paste_test.go index f8378c4..1ab73fa 100644 --- a/internal/server/tests/paste_test.go +++ b/internal/server/tests/paste_test.go @@ -87,7 +87,7 @@ func TestMultipartPasteUpload(t *testing.T) { }, { name: "Large content within API limit", - content: strings.Repeat("a", 1024*1024*9), // 9MB + content: strings.Repeat("a", 1024*1024*7), // 7MB private: false, mimeType: "text/plain; charset=utf-8", expectedStatus: 200, @@ -96,7 +96,7 @@ func TestMultipartPasteUpload(t *testing.T) { }, { name: "Large content exceeding API limit", - content: strings.Repeat("a", 1024*1024*11), // 11MB + content: strings.Repeat("a", 1024*1024*9), // 9MB private: false, mimeType: "text/plain; charset=utf-8", expectedStatus: 400, diff --git a/internal/server/tests/testutils/setup.go b/internal/server/tests/testutils/setup.go index ff175df..177f2b5 100644 --- a/internal/server/tests/testutils/setup.go +++ b/internal/server/tests/testutils/setup.go @@ -11,7 +11,6 @@ import ( "github.com/watzon/0x45/internal/models" "github.com/watzon/0x45/internal/server" "github.com/watzon/0x45/internal/storage" - "github.com/watzon/0x45/internal/utils/bytesize" "go.uber.org/zap" ) @@ -63,9 +62,9 @@ func SetupTestEnv(t *testing.T) *TestEnv { }, }, Server: config.ServerConfig{ - MaxUploadSize: bytesize.ByteSize(10 * 1024 * 1024), // 10MB - DefaultUploadSize: bytesize.ByteSize(5 * 1024 * 1024), // 5MB - APIUploadSize: bytesize.ByteSize(10 * 1024 * 1024), // 10MB + MaxUploadSize: 100 * 1024 * 1024, // 10MB + DefaultUploadSize: 5 * 1024 * 1024, // 5MB + APIUploadSize: 8 * 1024 * 1024, // 8MB AppName: "0x45-test", ServerHeader: "0x45-test", ViewsDirectory: viewsDir, @@ -92,8 +91,8 @@ func SetupTestEnv(t *testing.T) *TestEnv { defer func() { _ = logger.Sync() }() // Create server instance with modified config - origCfg := *cfg // Make a copy of the original config - cfg.Server.MaxUploadSize = bytesize.ByteSize(10 * 1024 * 1024) // 10MB + origCfg := *cfg // Make a copy of the original config + cfg.Server.MaxUploadSize = 10 * 1024 * 1024 // 10MB cfg.Server.AppName = "0x45-test" cfg.Server.ServerHeader = "0x45-test" diff --git a/internal/utils/bytesize/bytesize.go b/internal/utils/bytesize/bytesize.go deleted file mode 100644 index e3bada1..0000000 --- a/internal/utils/bytesize/bytesize.go +++ /dev/null @@ -1,163 +0,0 @@ -package bytesize - -import ( - "errors" - "fmt" - "regexp" - "strconv" - "strings" -) - -// ByteSize represents a size in bytes with string parsing and formatting -type ByteSize int64 - -// Common byte sizes for IEC (binary) units -const ( - _ = iota - KiB ByteSize = 1 << (10 * iota) - MiB - GiB - TiB - PiB -) - -// Common byte sizes for SI (decimal) units -const ( - KB ByteSize = 1000 - MB ByteSize = KB * 1000 - GB ByteSize = MB * 1000 - TB ByteSize = GB * 1000 - PB ByteSize = TB * 1000 -) - -var ( - ErrInvalidByteSize = errors.New("invalid byte size") - // Support both IEC and SI units, with optional space and case insensitive - byteSizeRegex = regexp.MustCompile(`^(\d+(?:\.\d+)?)\s*(?i:([KMGTP]I?B|[KMGTP]|B(?:YTE(?:S)?)?)?)\s*$`) -) - -// String returns a human-readable representation of the byte size using IEC units -func (b ByteSize) String() string { - return b.Format(true) -} - -// Format returns a human-readable representation of the byte size -// If useIEC is true, uses binary units (KiB, MiB, etc.) -// If useIEC is false, uses decimal units (KB, MB, etc.) -func (b ByteSize) Format(useIEC bool) string { - abs := b - if b < 0 { - abs = -b - } - - if useIEC { - switch { - case abs >= PiB: - return fmt.Sprintf("%.2fPiB", float64(b)/float64(PiB)) - case abs >= TiB: - return fmt.Sprintf("%.2fTiB", float64(b)/float64(TiB)) - case abs >= GiB: - return fmt.Sprintf("%.2fGiB", float64(b)/float64(GiB)) - case abs >= MiB: - return fmt.Sprintf("%.2fMiB", float64(b)/float64(MiB)) - case abs >= KiB: - return fmt.Sprintf("%.2fKiB", float64(b)/float64(KiB)) - default: - return fmt.Sprintf("%dB", b) - } - } else { - switch { - case abs >= PB: - return fmt.Sprintf("%.2fPB", float64(b)/float64(PB)) - case abs >= TB: - return fmt.Sprintf("%.2fTB", float64(b)/float64(TB)) - case abs >= GB: - return fmt.Sprintf("%.2fGB", float64(b)/float64(GB)) - case abs >= MB: - return fmt.Sprintf("%.2fMB", float64(b)/float64(MB)) - case abs >= KB: - return fmt.Sprintf("%.2fKB", float64(b)/float64(KB)) - default: - return fmt.Sprintf("%dB", b) - } - } -} - -// Int64 returns the size as an int64 -func (b ByteSize) Int64() int64 { - return int64(b) -} - -// ParseByteSize parses a string representation of bytes into a ByteSize value -func ParseByteSize(s string) (ByteSize, error) { - if s == "" { - return 0, ErrInvalidByteSize - } - - matches := byteSizeRegex.FindStringSubmatch(strings.ToUpper(s)) - if matches == nil { - return 0, ErrInvalidByteSize - } - - value, err := strconv.ParseFloat(matches[1], 64) - if err != nil { - return 0, ErrInvalidByteSize - } - - unit := matches[2] - if unit == "" || unit == "B" || unit == "BYTE" || unit == "BYTES" { - return ByteSize(value), nil - } - - // Check if it's an IEC unit (has 'I' in it) - isIEC := strings.Contains(unit, "I") - unitChar := rune(unit[0]) - - var multiplier ByteSize - switch unitChar { - case 'K': - multiplier = KiB - if !isIEC { - multiplier = KB - } - case 'M': - multiplier = MiB - if !isIEC { - multiplier = MB - } - case 'G': - multiplier = GiB - if !isIEC { - multiplier = GB - } - case 'T': - multiplier = TiB - if !isIEC { - multiplier = TB - } - case 'P': - multiplier = PiB - if !isIEC { - multiplier = PB - } - default: - return 0, ErrInvalidByteSize - } - - return ByteSize(value * float64(multiplier)), nil -} - -// MarshalText implements the encoding.TextMarshaler interface -func (b ByteSize) MarshalText() ([]byte, error) { - return []byte(b.String()), nil -} - -// UnmarshalText implements the encoding.TextUnmarshaler interface -func (b *ByteSize) UnmarshalText(text []byte) error { - size, err := ParseByteSize(string(text)) - if err != nil { - return err - } - *b = size - return nil -} diff --git a/internal/utils/bytesize/bytesize_test.go b/internal/utils/bytesize/bytesize_test.go deleted file mode 100644 index 8a6a2ac..0000000 --- a/internal/utils/bytesize/bytesize_test.go +++ /dev/null @@ -1,192 +0,0 @@ -package bytesize - -import ( - "testing" -) - -func TestParseByteSizeValid(t *testing.T) { - tests := []struct { - input string - expected ByteSize - }{ - // Basic byte values - {"0", 0}, - {"1024", 1024}, - {"1024B", 1024}, - {"1024 B", 1024}, - {"1024 BYTES", 1024}, - {"1024 BYTE", 1024}, - - // IEC units (binary) - {"1KiB", KiB}, - {"1 KiB", KiB}, - {"1.5KiB", ByteSize(float64(KiB) * 1.5)}, - {"1MiB", MiB}, - {"1.5MiB", ByteSize(float64(MiB) * 1.5)}, - {"1GiB", GiB}, - {"1TiB", TiB}, - {"1PiB", PiB}, - - // SI units (decimal) - {"1KB", KB}, - {"1 KB", KB}, - {"1.5KB", ByteSize(float64(KB) * 1.5)}, - {"1MB", MB}, - {"1.5MB", ByteSize(float64(MB) * 1.5)}, - {"1GB", GB}, - {"1TB", TB}, - {"1PB", PB}, - - // Short forms (default to SI units) - {"1K", KB}, // Defaults to SI - {"1M", MB}, - {"1G", GB}, - {"1T", TB}, - {"1P", PB}, - - // Case insensitivity - {"1kb", KB}, - {"1kib", KiB}, - {"1mB", MB}, - {"1mIb", MiB}, - {"1Kb", KB}, - {"1KiB", KiB}, - - // With spaces - {"1 KB", KB}, - {"1 KiB", KiB}, - {"1 MB", MB}, - {"1 MiB", MiB}, - } - - for _, test := range tests { - t.Run(test.input, func(t *testing.T) { - result, err := ParseByteSize(test.input) - if err != nil { - t.Errorf("ParseByteSize(%q) returned error: %v", test.input, err) - } - if result != test.expected { - t.Errorf("ParseByteSize(%q) = %v, want %v", test.input, result, test.expected) - } - }) - } -} - -func TestParseByteSizeInvalid(t *testing.T) { - tests := []string{ - "", - "abc", - "1XB", - "1.5.5MB", - "-KB", - "KB", - "1KB1", - "1.KB", - ".5KB", - } - - for _, test := range tests { - t.Run(test, func(t *testing.T) { - _, err := ParseByteSize(test) - if err == nil { - t.Errorf("ParseByteSize(%q) should have returned an error", test) - } - }) - } -} - -func TestByteSizeString(t *testing.T) { - tests := []struct { - input ByteSize - expected string - }{ - {0, "0B"}, - {512, "512B"}, - {KiB, "1.00KiB"}, - {ByteSize(float64(KiB) * 1.5), "1.50KiB"}, - {MiB, "1.00MiB"}, - {ByteSize(float64(MiB) * 2.25), "2.25MiB"}, - {GiB, "1.00GiB"}, - {TiB, "1.00TiB"}, - {PiB, "1.00PiB"}, - } - - for _, test := range tests { - t.Run(test.expected, func(t *testing.T) { - result := test.input.String() - if result != test.expected { - t.Errorf("ByteSize(%d).String() = %q, want %q", test.input, result, test.expected) - } - }) - } -} - -func TestByteSizeFormat(t *testing.T) { - tests := []struct { - input ByteSize - useIEC bool - expected string - }{ - // IEC (binary) format - {KiB, true, "1.00KiB"}, - {MiB, true, "1.00MiB"}, - {GiB, true, "1.00GiB"}, - {ByteSize(float64(KiB) * 1.5), true, "1.50KiB"}, - - // SI (decimal) format - {KB, false, "1.00KB"}, - {MB, false, "1.00MB"}, - {GB, false, "1.00GB"}, - {ByteSize(float64(KB) * 1.5), false, "1.50KB"}, - - // Edge cases - {0, true, "0B"}, - {0, false, "0B"}, - {512, true, "512B"}, - {512, false, "512B"}, - } - - for _, test := range tests { - t.Run(test.expected, func(t *testing.T) { - result := test.input.Format(test.useIEC) - if result != test.expected { - t.Errorf("ByteSize(%d).Format(%v) = %q, want %q", - test.input, test.useIEC, result, test.expected) - } - }) - } -} - -func TestTextMarshaling(t *testing.T) { - tests := []struct { - size ByteSize - expected string - }{ - {KiB, "1.00KiB"}, - {MiB, "1.00MiB"}, - {ByteSize(float64(GiB) * 1.5), "1.50GiB"}, - } - - for _, test := range tests { - t.Run(test.expected, func(t *testing.T) { - // Test marshaling - bytes, err := test.size.MarshalText() - if err != nil { - t.Errorf("MarshalText() returned error: %v", err) - } - if string(bytes) != test.expected { - t.Errorf("MarshalText() = %q, want %q", string(bytes), test.expected) - } - - // Test unmarshaling - var size ByteSize - err = size.UnmarshalText([]byte(test.expected)) - if err != nil { - t.Errorf("UnmarshalText(%q) returned error: %v", test.expected, err) - } - if size != test.size { - t.Errorf("UnmarshalText(%q) = %v, want %v", test.expected, size, test.size) - } - }) - } -}