Skip to content

Commit

Permalink
Fixed compile error, moved linux-only test to its own test file
Browse files Browse the repository at this point in the history
  • Loading branch information
zachmu committed Jan 10, 2024
1 parent e19aad2 commit 132bf6a
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 47 deletions.
61 changes: 61 additions & 0 deletions go/mysql/auth_server_static_linux_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// Copyright 2023 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// +build linux

package mysql

func TestStaticConfigHUP(t *testing.T) {
tmpFile, err := ioutil.TempFile("", "mysql_auth_server_static_file.json")
if err != nil {
t.Fatalf("couldn't create temp file: %v", err)
}
defer os.Remove(tmpFile.Name())
oldStr := "str5"
jsonConfig := fmt.Sprintf("{\"%s\":[{\"Password\":\"%s\"}]}", oldStr, oldStr)
if err := ioutil.WriteFile(tmpFile.Name(), []byte(jsonConfig), 0600); err != nil {
t.Fatalf("couldn't write temp file: %v", err)
}

aStatic := NewAuthServerStatic(tmpFile.Name(), "", 0)
defer aStatic.close()

if aStatic.getEntries()[oldStr][0].Password != oldStr {
t.Fatalf("%s's Password should still be '%s'", oldStr, oldStr)
}

hupTest(t, aStatic, tmpFile, oldStr, "str2")
hupTest(t, aStatic, tmpFile, "str2", "str3") // still handling the signal
}

func hupTest(t *testing.T, aStatic *AuthServerStatic, tmpFile *os.File, oldStr, newStr string) {
jsonConfig := fmt.Sprintf("{\"%s\":[{\"Password\":\"%s\"}]}", newStr, newStr)
if err := ioutil.WriteFile(tmpFile.Name(), []byte(jsonConfig), 0600); err != nil {
t.Fatalf("couldn't overwrite temp file: %v", err)
}

if aStatic.getEntries()[oldStr][0].Password != oldStr {
t.Fatalf("%s's Password should still be '%s'", oldStr, oldStr)
}

syscall.Kill(syscall.Getpid(), syscall.SIGHUP)
time.Sleep(100 * time.Millisecond) // wait for signal handler

if aStatic.getEntries()[oldStr] != nil {
t.Fatalf("Should not have old %s after config reload", oldStr)
}
if aStatic.getEntries()[newStr][0].Password != newStr {
t.Fatalf("%s's Password should be '%s'", newStr, newStr)
}
}
45 changes: 0 additions & 45 deletions go/mysql/auth_server_static_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"io/ioutil"
"net"
"os"
"syscall"
"testing"
"time"
)
Expand Down Expand Up @@ -126,29 +125,6 @@ func TestHostMatcher(t *testing.T) {
}
}

func TestStaticConfigHUP(t *testing.T) {
tmpFile, err := ioutil.TempFile("", "mysql_auth_server_static_file.json")
if err != nil {
t.Fatalf("couldn't create temp file: %v", err)
}
defer os.Remove(tmpFile.Name())
oldStr := "str5"
jsonConfig := fmt.Sprintf("{\"%s\":[{\"Password\":\"%s\"}]}", oldStr, oldStr)
if err := ioutil.WriteFile(tmpFile.Name(), []byte(jsonConfig), 0600); err != nil {
t.Fatalf("couldn't write temp file: %v", err)
}

aStatic := NewAuthServerStatic(tmpFile.Name(), "", 0)
defer aStatic.close()

if aStatic.getEntries()[oldStr][0].Password != oldStr {
t.Fatalf("%s's Password should still be '%s'", oldStr, oldStr)
}

hupTest(t, aStatic, tmpFile, oldStr, "str2")
hupTest(t, aStatic, tmpFile, "str2", "str3") // still handling the signal
}

func TestStaticConfigHUPWithRotation(t *testing.T) {
tmpFile, err := ioutil.TempFile("", "mysql_auth_server_static_file.json")
if err != nil {
Expand Down Expand Up @@ -180,27 +156,6 @@ func TestStaticConfigHUPWithRotation(t *testing.T) {
aStatic.close()
}

func hupTest(t *testing.T, aStatic *AuthServerStatic, tmpFile *os.File, oldStr, newStr string) {
jsonConfig := fmt.Sprintf("{\"%s\":[{\"Password\":\"%s\"}]}", newStr, newStr)
if err := ioutil.WriteFile(tmpFile.Name(), []byte(jsonConfig), 0600); err != nil {
t.Fatalf("couldn't overwrite temp file: %v", err)
}

if aStatic.getEntries()[oldStr][0].Password != oldStr {
t.Fatalf("%s's Password should still be '%s'", oldStr, oldStr)
}

syscall.Kill(syscall.Getpid(), syscall.SIGHUP)
time.Sleep(100 * time.Millisecond) // wait for signal handler

if aStatic.getEntries()[oldStr] != nil {
t.Fatalf("Should not have old %s after config reload", oldStr)
}
if aStatic.getEntries()[newStr][0].Password != newStr {
t.Fatalf("%s's Password should be '%s'", newStr, newStr)
}
}

func hupTestWithRotation(t *testing.T, aStatic *AuthServerStatic, tmpFile *os.File, oldStr, newStr string) {
jsonConfig := fmt.Sprintf("{\"%s\":[{\"Password\":\"%s\"}]}", newStr, newStr)
if err := ioutil.WriteFile(tmpFile.Name(), []byte(jsonConfig), 0600); err != nil {
Expand Down
4 changes: 2 additions & 2 deletions go/mysql/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,12 @@ func (th *testHandler) ComInitDB(c *Conn, schemaName string) error {
return nil
}

func (th *testHandler) ComMultiQuery(c *Conn, query string, callback func(res *sqltypes.Result, more bool) error) (string, error) {
func (th *testHandler) ComMultiQuery(c *Conn, query string, callback ResultSpoolFn) (string, error) {
err := th.ComQuery(c, query, callback)
return "", err
}

func (th *testHandler) ComQuery(c *Conn, query string, callback func(res *sqltypes.Result, more bool) error) error {
func (th *testHandler) ComQuery(c *Conn, query string, callback ResultSpoolFn) error {
if result := th.Result(); result != nil {
callback(th.result, false)
return nil
Expand Down

0 comments on commit 132bf6a

Please sign in to comment.