Skip to content

Commit

Permalink
Minor tweaks and refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
70sh1 committed Nov 29, 2023
1 parent d7cb24f commit fc9e661
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 50 deletions.
86 changes: 39 additions & 47 deletions core.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,12 @@ func NewProcessor(sourcePath string, password string, mode string) (*processor,
}
debug.FreeOSMemory() // Free memory held after scrypt call

blakeKey := make([]byte, 64)
c, err := chacha20.NewUnauthenticatedCipher(key, nonce)
if err != nil {
return nil, fmt.Errorf("error initializing cipher; %v", err)
}

blakeKey := make([]byte, 64)
c.XORKeyStream(blakeKey, blakeKey)
blake, err := blake2b.New512(blakeKey)
if err != nil {
Expand All @@ -113,40 +114,40 @@ func NewProcessor(sourcePath string, password string, mode string) (*processor,
return &processor{c, blake, file, nonce, salt, fileSize}, nil
}

// Read bytes from encryptor's source (file) into buffer b, truncate it if n < len(b),
// Read len(b) bytes from encryptor's source (file) into buffer b, truncate it if n < len(b),
// XOR it, update the encryptor's HMAC with the resulting slice,
// return number of bytes read and error.
func (ab *encryptor) Read(b []byte) (int, error) {
n, err := ab.source.Read(b)
func (e *encryptor) Read(b []byte) (int, error) {
n, err := e.source.Read(b)
if n > 0 {
b = b[:n]
ab.c.XORKeyStream(b, b)
if err := ab.updateHmac(b); err != nil {
e.c.XORKeyStream(b, b)
if err := e.updateHmac(b); err != nil {
return n, err
}
return n, err
}
return 0, io.EOF
}

// Read bytes from decryptor's source (file) into buffer b, truncate it if n < len(b),
// Read len(b) bytes from decryptor's source (file) into buffer b, truncate it if n < len(b),
// update HMAC with slice, XOR the slice,
// return number of bytes read and error.
func (ab *decryptor) Read(b []byte) (int, error) {
n, err := ab.source.Read(b)
func (d *decryptor) Read(b []byte) (int, error) {
n, err := d.source.Read(b)
if n > 0 {
b = b[:n]
if err := ab.updateHmac(b); err != nil {
if err := d.updateHmac(b); err != nil {
return n, err
}
ab.c.XORKeyStream(b, b)
d.c.XORKeyStream(b, b)
return n, err
}
return 0, io.EOF
}

func (ab *processor) updateHmac(data []byte) error {
n, err := ab.hmac.Write(data)
func (p *processor) updateHmac(data []byte) error {
n, err := p.hmac.Write(data)
if err != nil {
return err
}
Expand Down Expand Up @@ -188,7 +189,7 @@ func closeAndRemove(f *os.File) {
os.Remove(f.Name())
}

func limitStringLength(s string, n int) string {
func filenameOverflow(s string, n int) string {
if len(s) > n {
return s[:n] + "..."
}
Expand Down Expand Up @@ -242,20 +243,21 @@ func cleanAndCheckPaths(paths []string, outputDir string) ([]string, string, err
return paths, outputDir, nil
}

// Create new progress bar pool.
func newBarPool(paths []string) (pool *pb.Pool, bars []*pb.ProgressBar) {
barTmpl := `{{ string . "status" }} {{ string . "filename" }} {{ string . "filesize" }} {{ bar . "[" "-" ">" " " "]" }} {{ string . "error" }}`
for _, path := range paths {
bar := pb.New64(1).SetTemplateString(barTmpl).SetWidth(80)
bar.Set("status", " ")
bar.Set("filename", limitStringLength(filepath.Base(path), 25))
bar.Set("filename", filenameOverflow(filepath.Base(path), 25))
bars = append(bars, bar)
}
return pb.NewPool(bars...), bars
}

func generatePassphrase(length int) (string, error) {
if length < 6 {
return "", errors.New("length < 6 is not secure")
return "", errors.New("length less than 6 is not secure")
}
wordlist, err := embedded.ReadFile("wordlist.txt")
if err != nil {
Expand All @@ -281,22 +283,25 @@ func deriveKey(password string, salt []byte) ([]byte, error) {
return scrypt.Key([]byte(password), salt, 65536, 8, 1, 32) // 65536 == 2^16
}

func barFail(bar *pb.ProgressBar, err error) {
bar.Set("status", "❌")
bar.Set("error", err)
}

func encryptFile(pathIn, pathOut, password string, bar *pb.ProgressBar) error {
processor, err := NewProcessor(pathIn, password, "enc")
if err != nil {
// Moving these repetitive lines to the function call would be nice and much cleaner,
// Moving this repetitive line to this function's call would be nice and much cleaner,
// but then the bar doesn't update properly for some reason.
bar.Set("status", "❌")
bar.Set("error", err)
barFail(bar, err)
return err
}
encryptor := &encryptor{processor}
defer encryptor.source.Close()

tmpFile, err := os.CreateTemp(filepath.Dir(pathOut), "*.tmp")
if err != nil {
bar.Set("status", "❌")
bar.Set("error", err)
barFail(bar, err)
return err
}
defer closeAndRemove(tmpFile)
Expand All @@ -308,8 +313,7 @@ func encryptFile(pathIn, pathOut, password string, bar *pb.ProgressBar) error {
header = append(header, tagPlaceholder...)

if _, err := tmpFile.Write(header); err != nil {
bar.Set("status", "❌")
bar.Set("error", err)
barFail(bar, err)
return err
}

Expand All @@ -319,28 +323,24 @@ func encryptFile(pathIn, pathOut, password string, bar *pb.ProgressBar) error {
defer w.Close()

if _, err := io.Copy(w, encryptor); err != nil {
bar.Set("status", "❌")
bar.Set("error", err)
barFail(bar, err)
return err
}

tag := encryptor.hmac.Sum(nil)
if _, err := tmpFile.Seek(int64(len(encryptor.nonce)+len(encryptor.hmacSalt)), 0); err != nil {
bar.Set("status", "❌")
bar.Set("error", err)
barFail(bar, err)
return err
}
if _, err := tmpFile.Write(tag); err != nil {
bar.Set("status", "❌")
bar.Set("error", err)
barFail(bar, err)
return err
}

tmpFile.Close()
encryptor.source.Close()
if err := os.Rename(tmpFile.Name(), pathOut); err != nil {
bar.Set("status", "❌")
bar.Set("error", err)
barFail(bar, err)
return err
}

Expand All @@ -352,8 +352,7 @@ func encryptFile(pathIn, pathOut, password string, bar *pb.ProgressBar) error {
func decryptFile(pathIn, pathOut, password string, bar *pb.ProgressBar) error {
processor, err := NewProcessor(pathIn, password, "dec")
if err != nil {
bar.Set("status", "❌")
bar.Set("error", errors.Unwrap(err))
barFail(bar, err)
return err
}
decryptor := &decryptor{processor}
Expand All @@ -362,15 +361,13 @@ func decryptFile(pathIn, pathOut, password string, bar *pb.ProgressBar) error {
expectedTag := make([]byte, 64)
n, err := decryptor.source.Read(expectedTag)
if n != 64 || err != nil {
bar.Set("status", "❌")
bar.Set("error", err)
barFail(bar, err)
return err
}

tmpFile, err := os.CreateTemp(filepath.Dir(pathOut), "*.tmp")
if err != nil {
bar.Set("status", "❌")
bar.Set("error", err)
barFail(bar, err)
return err
}
defer closeAndRemove(tmpFile)
Expand All @@ -381,8 +378,7 @@ func decryptFile(pathIn, pathOut, password string, bar *pb.ProgressBar) error {
defer w.Close()

if _, err := io.Copy(w, decryptor); err != nil {
bar.Set("status", "❌")
bar.Set("error", err)
barFail(bar, err)
return err
}

Expand All @@ -392,14 +388,12 @@ func decryptFile(pathIn, pathOut, password string, bar *pb.ProgressBar) error {
actualTag := decryptor.hmac.Sum(nil)
if !bytes.Equal(actualTag, expectedTag) {
err = errors.New("incorrect password or corrupt/forged data")
bar.Set("status", "❌")
bar.Set("error", err)
barFail(bar, err)
return err
}

if err := os.Rename(tmpFile.Name(), pathOut); err != nil {
bar.Set("status", "❌")
bar.Set("error", err)
barFail(bar, err)
return err
}

Expand Down Expand Up @@ -428,8 +422,7 @@ func encryptFiles(paths []string, outputDir, password string, overwrite bool) (i
fileOut = filepath.Join(outputDir, filepath.Base(fileOut))
}
if _, err := os.Stat(fileOut); !errors.Is(err, os.ErrNotExist) && !overwrite {
bar.Set("status", "❌")
bar.Set("error", "output already exists")
barFail(bar, errors.New("output already exists"))
return
}
if err := encryptFile(fileIn, fileOut, password, bar); err != nil {
Expand Down Expand Up @@ -463,8 +456,7 @@ func decryptFiles(paths []string, outputDir, password string, overwrite bool) er
fileOut = filepath.Join(outputDir, filepath.Base(fileOut))
}
if _, err := os.Stat(fileOut); !errors.Is(err, os.ErrNotExist) && !overwrite {
bar.Set("status", "❌")
bar.Set("error", "output already exists")
barFail(bar, errors.New("output already exists"))
return
}
decryptFile(fileIn, fileOut, password, bar)
Expand Down
5 changes: 2 additions & 3 deletions eddy.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ func main() {
Aliases: []string{"enc", "e"},
Usage: "encrypt provided `FILES`",
Action: func(cCtx *cli.Context) error {
fmt.Println()
var noPasswordProvided bool
var numProcessed int64
var err error
Expand Down Expand Up @@ -88,7 +87,6 @@ func main() {
}

startTime := time.Now()

if numProcessed, err = encryptFiles(paths, outputDir, password, overwrite); err != nil {
log.Fatal(err)
}
Expand All @@ -106,7 +104,6 @@ func main() {
Aliases: []string{"dec", "d"},
Usage: "decrypt provided `FILES`",
Action: func(cCtx *cli.Context) error {
fmt.Println()
var err error
paths := append(cCtx.Args().Tail(), cCtx.Args().First())
if paths, outputDir, err = cleanAndCheckPaths(paths, outputDir); err != nil {
Expand All @@ -119,6 +116,7 @@ func main() {
log.Fatal(err)
}
}

startTime := time.Now()
if err := decryptFiles(paths, outputDir, password, overwrite); err != nil {
log.Fatal(err)
Expand All @@ -131,6 +129,7 @@ func main() {
},
},
}
fmt.Println()
log.SetFlags(0) // Remove date/time prefix from logger
log.SetPrefix("❗ ERROR: ") // Only logging errors with log.Fatal so this prefix is set
if err := app.Run(os.Args); err != nil {
Expand Down

0 comments on commit fc9e661

Please sign in to comment.