diff --git a/go/mysql/auth_server_static_linux_test.go b/go/mysql/auth_server_static_linux_test.go new file mode 100755 index 00000000000..d7d8d3c8bf5 --- /dev/null +++ b/go/mysql/auth_server_static_linux_test.go @@ -0,0 +1,61 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build linux + +package mysql + +func TestStaticConfigHUP(t *testing.T) { + tmpFile, err := ioutil.TempFile("", "mysql_auth_server_static_file.json") + if err != nil { + t.Fatalf("couldn't create temp file: %v", err) + } + defer os.Remove(tmpFile.Name()) + oldStr := "str5" + jsonConfig := fmt.Sprintf("{\"%s\":[{\"Password\":\"%s\"}]}", oldStr, oldStr) + if err := ioutil.WriteFile(tmpFile.Name(), []byte(jsonConfig), 0600); err != nil { + t.Fatalf("couldn't write temp file: %v", err) + } + + aStatic := NewAuthServerStatic(tmpFile.Name(), "", 0) + defer aStatic.close() + + if aStatic.getEntries()[oldStr][0].Password != oldStr { + t.Fatalf("%s's Password should still be '%s'", oldStr, oldStr) + } + + hupTest(t, aStatic, tmpFile, oldStr, "str2") + hupTest(t, aStatic, tmpFile, "str2", "str3") // still handling the signal +} + +func hupTest(t *testing.T, aStatic *AuthServerStatic, tmpFile *os.File, oldStr, newStr string) { + jsonConfig := fmt.Sprintf("{\"%s\":[{\"Password\":\"%s\"}]}", newStr, newStr) + if err := ioutil.WriteFile(tmpFile.Name(), []byte(jsonConfig), 0600); err != nil { + t.Fatalf("couldn't overwrite temp file: %v", err) + } + + if aStatic.getEntries()[oldStr][0].Password != oldStr { + t.Fatalf("%s's Password should still be '%s'", oldStr, oldStr) + } + + syscall.Kill(syscall.Getpid(), syscall.SIGHUP) + time.Sleep(100 * time.Millisecond) // wait for signal handler + + if aStatic.getEntries()[oldStr] != nil { + t.Fatalf("Should not have old %s after config reload", oldStr) + } + if aStatic.getEntries()[newStr][0].Password != newStr { + t.Fatalf("%s's Password should be '%s'", newStr, newStr) + } +} \ No newline at end of file diff --git a/go/mysql/auth_server_static_test.go b/go/mysql/auth_server_static_test.go index 16b263b1e73..60630d31941 100644 --- a/go/mysql/auth_server_static_test.go +++ b/go/mysql/auth_server_static_test.go @@ -21,7 +21,6 @@ import ( "io/ioutil" "net" "os" - "syscall" "testing" "time" ) @@ -126,29 +125,6 @@ func TestHostMatcher(t *testing.T) { } } -func TestStaticConfigHUP(t *testing.T) { - tmpFile, err := ioutil.TempFile("", "mysql_auth_server_static_file.json") - if err != nil { - t.Fatalf("couldn't create temp file: %v", err) - } - defer os.Remove(tmpFile.Name()) - oldStr := "str5" - jsonConfig := fmt.Sprintf("{\"%s\":[{\"Password\":\"%s\"}]}", oldStr, oldStr) - if err := ioutil.WriteFile(tmpFile.Name(), []byte(jsonConfig), 0600); err != nil { - t.Fatalf("couldn't write temp file: %v", err) - } - - aStatic := NewAuthServerStatic(tmpFile.Name(), "", 0) - defer aStatic.close() - - if aStatic.getEntries()[oldStr][0].Password != oldStr { - t.Fatalf("%s's Password should still be '%s'", oldStr, oldStr) - } - - hupTest(t, aStatic, tmpFile, oldStr, "str2") - hupTest(t, aStatic, tmpFile, "str2", "str3") // still handling the signal -} - func TestStaticConfigHUPWithRotation(t *testing.T) { tmpFile, err := ioutil.TempFile("", "mysql_auth_server_static_file.json") if err != nil { @@ -180,27 +156,6 @@ func TestStaticConfigHUPWithRotation(t *testing.T) { aStatic.close() } -func hupTest(t *testing.T, aStatic *AuthServerStatic, tmpFile *os.File, oldStr, newStr string) { - jsonConfig := fmt.Sprintf("{\"%s\":[{\"Password\":\"%s\"}]}", newStr, newStr) - if err := ioutil.WriteFile(tmpFile.Name(), []byte(jsonConfig), 0600); err != nil { - t.Fatalf("couldn't overwrite temp file: %v", err) - } - - if aStatic.getEntries()[oldStr][0].Password != oldStr { - t.Fatalf("%s's Password should still be '%s'", oldStr, oldStr) - } - - syscall.Kill(syscall.Getpid(), syscall.SIGHUP) - time.Sleep(100 * time.Millisecond) // wait for signal handler - - if aStatic.getEntries()[oldStr] != nil { - t.Fatalf("Should not have old %s after config reload", oldStr) - } - if aStatic.getEntries()[newStr][0].Password != newStr { - t.Fatalf("%s's Password should be '%s'", newStr, newStr) - } -} - func hupTestWithRotation(t *testing.T, aStatic *AuthServerStatic, tmpFile *os.File, oldStr, newStr string) { jsonConfig := fmt.Sprintf("{\"%s\":[{\"Password\":\"%s\"}]}", newStr, newStr) if err := ioutil.WriteFile(tmpFile.Name(), []byte(jsonConfig), 0600); err != nil { diff --git a/go/mysql/server_test.go b/go/mysql/server_test.go index d7a7182f61b..cd93fa438d2 100644 --- a/go/mysql/server_test.go +++ b/go/mysql/server_test.go @@ -122,12 +122,12 @@ func (th *testHandler) ComInitDB(c *Conn, schemaName string) error { return nil } -func (th *testHandler) ComMultiQuery(c *Conn, query string, callback func(res *sqltypes.Result, more bool) error) (string, error) { +func (th *testHandler) ComMultiQuery(c *Conn, query string, callback ResultSpoolFn) (string, error) { err := th.ComQuery(c, query, callback) return "", err } -func (th *testHandler) ComQuery(c *Conn, query string, callback func(res *sqltypes.Result, more bool) error) error { +func (th *testHandler) ComQuery(c *Conn, query string, callback ResultSpoolFn) error { if result := th.Result(); result != nil { callback(th.result, false) return nil