diff --git a/pwlib/ssha.go b/pwlib/ssha.go index bc0d4a0..82cba93 100644 --- a/pwlib/ssha.go +++ b/pwlib/ssha.go @@ -10,25 +10,32 @@ import ( "strings" ) +// SSHASaltLen is the length of the salt for SSHA +const SSHASaltLen = 4 + +// SSHAPrefix is the prefix for SSHA +const SSHAPrefix = "{SSHA}" + +// SSHEncoder is the encoder object for SSHA type SSHAEncoder struct { } // Encode encodes the []byte of raw password -func (enc SSHAEncoder) Encode(rawPassPhrase []byte) ([]byte, error) { +func (enc SSHAEncoder) Encode(rawPassPhrase []byte, prefix string) ([]byte, error) { salt, err := makeSSHASalt() if err != nil { return []byte{}, err } hash := makeSSHAHash(rawPassPhrase, salt) b64 := base64.StdEncoding.EncodeToString(hash) - return []byte(fmt.Sprintf("{SSHA}%s", b64)), nil + return []byte(fmt.Sprintf("%s%s", prefix, b64)), nil } // Matches matches the encoded password and the raw password func (enc SSHAEncoder) Matches(encodedPassPhrase, rawPassPhrase []byte) bool { // strip the {SSHA} eppS := string(encodedPassPhrase) - if strings.HasPrefix(string(encodedPassPhrase), "{SSHA}") { + if strings.HasPrefix(string(encodedPassPhrase), SSHAPrefix) { eppS = string(encodedPassPhrase)[6:] } hash, err := base64.StdEncoding.DecodeString(eppS) @@ -43,12 +50,12 @@ func (enc SSHAEncoder) Matches(encodedPassPhrase, rawPassPhrase []byte) bool { sum := sha.Sum(nil) // compare without the last 4 bytes of the hash with the calculated hash - return bytes.Equal(sum, hash[:len(hash)-4]) + return bytes.Equal(sum, hash[:len(hash)-SSHASaltLen]) } // makeSSHASalt make 4Byte salt for SSHA hashing func makeSSHASalt() (salt []byte, err error) { - salt, err = makeSalt(4) + salt, err = makeSalt(SSHASaltLen) return } diff --git a/pwlib/ssha_test.go b/pwlib/ssha_test.go index 3a555a9..ff2dd6e 100644 --- a/pwlib/ssha_test.go +++ b/pwlib/ssha_test.go @@ -12,16 +12,16 @@ var sshaPlain = []byte("password") func TestSSHA(t *testing.T) { t.Run("TestSSHAEncoder_Encode", func(t *testing.T) { enc := SSHAEncoder{} - encoded, err := enc.Encode(sshaPlain) + encoded, err := enc.Encode(sshaPlain, SSHAPrefix) require.NoError(t, err) - assert.Contains(t, string(encoded), "{SSHA}") + assert.Contains(t, string(encoded), SSHAPrefix) assert.Greater(t, len(encoded), 6) t.Log(string(encoded)) }) t.Run("TestSSHAEncoder_Matches", func(t *testing.T) { enc := SSHAEncoder{} - encoded, err := enc.Encode(sshaPlain) + encoded, err := enc.Encode(sshaPlain, SSHAPrefix) require.NoError(t, err) assert.True(t, enc.Matches(encoded, sshaPlain)) })