diff --git a/internal/corazawaf/transaction.go b/internal/corazawaf/transaction.go index 720c78272..b1bc1ecbd 100644 --- a/internal/corazawaf/transaction.go +++ b/internal/corazawaf/transaction.go @@ -1356,22 +1356,25 @@ func (tx *Transaction) AuditLog() *auditlog.Log { HostIP_: tx.variables.serverAddr.Get(), HostPort_: hostPort, ServerID_: tx.variables.serverName.Get(), // TODO check + Request_: &auditlog.TransactionRequest{ + Method_: tx.variables.requestMethod.Get(), + URI_: tx.variables.requestURI.Get(), + Protocol_: tx.variables.requestProtocol.Get(), + }, } for _, part := range tx.AuditLogParts { switch part { case types.AuditLogPartRequestHeaders: - if al.Transaction_.Request_ == nil { - al.Transaction_.Request_ = &auditlog.TransactionRequest{} - } al.Transaction_.Request_.Headers_ = tx.variables.requestHeaders.Data() case types.AuditLogPartRequestBody: - if al.Transaction_.Request_ == nil { - al.Transaction_.Request_ = &auditlog.TransactionRequest{} + reader, err := tx.requestBodyBuffer.Reader() + if err == nil { + content, err := io.ReadAll(reader) + if err == nil { + al.Transaction_.Request_.Body_ = string(content) + } } - // TODO maybe change to: - // al.Transaction.Request.Body = tx.RequestBodyBuffer.String() - al.Transaction_.Request_.Body_ = tx.variables.requestBody.Get() /* * TODO: diff --git a/testing/auditlog_test.go b/testing/auditlog_test.go index 324ef1f7e..fdbfb2c5c 100644 --- a/testing/auditlog_test.go +++ b/testing/auditlog_test.go @@ -13,6 +13,7 @@ import ( "fmt" "os" "path/filepath" + "strings" "testing" "github.com/corazawaf/coraza/v3/internal/auditlog" @@ -224,3 +225,111 @@ func TestAuditLogOnNoLog(t *testing.T) { t.Error(err) } } + +func TestAuditLogRequestMethodURIProtocol(t *testing.T) { + waf := corazawaf.NewWAF() + parser := seclang.NewParser(waf) + if err := parser.FromString(` + SecRuleEngine DetectionOnly + SecAuditEngine On + SecAuditLogFormat json + SecAuditLogType serial + `); err != nil { + t.Fatal(err) + } + // generate a random tmp file + file, err := os.Create(filepath.Join(t.TempDir(), "tmp.log")) + if err != nil { + t.Fatal(err) + } + defer os.Remove(file.Name()) + if err := parser.FromString(fmt.Sprintf("SecAuditLog %s", file.Name())); err != nil { + t.Fatal(err) + } + tx := waf.NewTransaction() + + uri := "/some-url" + method := "POST" + proto := "HTTP/1.1" + + tx.ProcessURI(uri, method, proto) + // now we read file + if _, err := file.Seek(0, 0); err != nil { + t.Error(err) + } + tx.ProcessLogging() + var al2 auditlog.Log + if err := json.NewDecoder(file).Decode(&al2); err != nil { + t.Error(err) + } + trans := al2.Transaction() + if trans == nil { + t.Fatalf("Expected 1 transaction, got nil") + } + req := trans.Request() + if req == nil { + t.Fatalf("Expected 1 request, got nil") + } + if req.URI() != uri { + t.Fatalf("Expected %s uri, got %s", uri, req.URI()) + } + if req.Method() != method { + t.Fatalf("Expected %s method, got %s", method, req.Method()) + } + if req.Protocol() != proto { + t.Fatalf("Expected %s protocol, got %s", proto, req.Protocol()) + } +} + +func TestAuditLogRequestBody(t *testing.T) { + waf := corazawaf.NewWAF() + parser := seclang.NewParser(waf) + if err := parser.FromString(` + SecRuleEngine DetectionOnly + SecAuditEngine On + SecAuditLogFormat json + SecAuditLogType serial + SecRequestBodyAccess On + `); err != nil { + t.Fatal(err) + } + // generate a random tmp file + file, err := os.Create(filepath.Join(t.TempDir(), "tmp.log")) + if err != nil { + t.Fatal(err) + } + defer os.Remove(file.Name()) + if err := parser.FromString(fmt.Sprintf("SecAuditLog %s", file.Name())); err != nil { + t.Fatal(err) + } + tx := waf.NewTransaction() + params := "somepost=data" + _, _, err = tx.ReadRequestBodyFrom(strings.NewReader(params)) + if err != nil { + t.Error(err) + } + _, err = tx.ProcessRequestBody() + if err != nil { + t.Error(err) + } + // now we read file + if _, err := file.Seek(0, 0); err != nil { + t.Error(err) + } + tx.ProcessLogging() + var al2 auditlog.Log + if err := json.NewDecoder(file).Decode(&al2); err != nil { + t.Error(err) + } + trans := al2.Transaction() + if trans == nil { + t.Fatalf("Expected 1 transaction, got nil") + } + req := trans.Request() + if req == nil { + t.Fatalf("Expected 1 request, got nil") + } + if req.Body() != params { + t.Fatalf("Expected %s uri, got %s", params, req.Body()) + } +}