Skip to content

Commit

Permalink
Merge pull request #3 from spreadshirt/reasonable-improvements
Browse files Browse the repository at this point in the history
Reasonable improvements!
  • Loading branch information
heyLu authored Mar 6, 2023
2 parents d77e1ca + c7c7d4b commit dc05674
Show file tree
Hide file tree
Showing 3 changed files with 312 additions and 24 deletions.
109 changes: 85 additions & 24 deletions prom-revisionist.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package main

import (
"bytes"
"compress/gzip"
"compress/zlib"
"encoding/json"
"flag"
"fmt"
"io"
Expand All @@ -10,7 +13,6 @@ import (
"net/url"
"os"
"regexp"
"strings"

"github.com/prometheus/prometheus/model/labels"
"github.com/prometheus/prometheus/promql/parser"
Expand Down Expand Up @@ -109,18 +111,31 @@ func main() {
u := req.URL
u.Scheme = upstreamURL.Scheme
u.Host = upstreamURL.Host
// TODO: modify query in url.Query/url.RawQuery
if u.Query().Get("query") != "" {
query := u.Query().Get("query")
before := query

query, rev, err = rewrite(revisionists, query)
if err != nil {
log.Printf("could not rewrite: %s", err)
}

if rev != nil {
log.Printf("rewriting!\n%s\n=>\n%s", before, query)

params := u.Query()
params.Set("query", query)
u.RawQuery = params.Encode()
wasRewrite = true
}
}
proxyReq, err := http.NewRequest(req.Method, u.String(), bytes.NewBuffer(body))
if err != nil {
log.Printf("failed to created request: %s", err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
proxyReq.Header = req.Header
if wasRewrite {
// TODO: allow keeping gzip and other encodings (handle them transparently)
proxyReq.Header.Del("Accept-Encoding")
}

resp, err := http.DefaultClient.Do(proxyReq)
if err != nil {
Expand All @@ -130,7 +145,13 @@ func main() {
}
defer resp.Body.Close()

contentEncoding := resp.Header.Get("Content-Encoding")
for name, vals := range resp.Header {
if name == "Content-Length" && contentEncoding != "" {
// recompression changes the length, skip
continue
}

for _, val := range vals {
w.Header().Add(name, val)
}
Expand All @@ -145,27 +166,60 @@ func main() {

var in io.Reader = resp.Body
if wasRewrite {
buf := new(bytes.Buffer)
_, err = io.Copy(buf, resp.Body)
if err != nil {
log.Printf("could not write body: %s", err)
return
if contentEncoding != "" {
switch contentEncoding {
case "gzip":
in, err = gzip.NewReader(in)
if err != nil {
log.Printf("could not create gzip reader: %s", err)
return
}

zw := gzip.NewWriter(out)
defer zw.Close()

out = zw
case "deflate":
in, err = zlib.NewReader(in)
if err != nil {
log.Printf("could not create deflate reader: %s", err)
return
}

zw := zlib.NewWriter(out)
defer zw.Close()

out = zw
default:
log.Printf("unhandled compression %q", contentEncoding)
return
}
}

res := buf.String()
for from, to := range rev.config.RenameLabels {
// TODO: rewrite by using streaming in some way
res = strings.Replace(res, `"`+to+`"`, `"`+from+`"`, -1)
}
dec := json.NewDecoder(in)
tw := NewTokenWriter(out, dec)
token, err := dec.Token()
for err == nil {
switch tok := token.(type) {
case string:
replace, ok := rev.config.RenameLabelsReverse[tok]
if ok {
token = replace
}
}
err = tw.Write(token)
if err != nil {
break
}

buf.Reset()
_, err = buf.WriteString(res)
if err != nil {
log.Printf("could not rewrite: %s", err)
token, err = dec.Token()
}
if err != nil && err != io.EOF {
log.Printf("could not decode response: %s", err)
return
}

in = buf
return
}

_, err = io.Copy(out, in)
Expand Down Expand Up @@ -257,9 +311,10 @@ type RewriteConfig struct {
MatchRaw string `yaml:"match"`
WithRaw string `yaml:"with"`
} `yaml:"wrap"`
RenameMetrics map[string]string `yaml:"rename-metrics"`
RenameLabels map[string]string `yaml:"rename-labels"`
RewriteMatchers []struct {
RenameMetrics map[string]string `yaml:"rename-metrics"`
RenameLabels map[string]string `yaml:"rename-labels"`
RenameLabelsReverse map[string]string `yaml:"-"`
RewriteMatchers []struct {
From *labels.Matcher `yaml:"-"`
To *labels.Matcher `yaml:"-"`
FromRaw string `yaml:"from"`
Expand Down Expand Up @@ -307,6 +362,12 @@ func (r *RewriteConfig) Parse() error {
}
}

// reverse label renames to easily convert them back in the response
r.RenameLabelsReverse = make(map[string]string, len(r.RenameLabels))
for key, val := range r.RenameLabels {
r.RenameLabelsReverse[val] = key
}

for j, matcher := range r.RewriteMatchers {
matchers, err := parser.ParseMetricSelector(matcher.FromRaw)
if err != nil {
Expand Down
137 changes: 137 additions & 0 deletions token_writer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
package main

import (
"encoding/json"
"fmt"
"io"
)

type tokenWriter struct {
w io.Writer
dec *json.Decoder

writers []TokenWriter
}

func NewTokenWriter(w io.Writer, dec *json.Decoder) *tokenWriter {
return &tokenWriter{
w: w,
dec: dec,
writers: []TokenWriter{simpleWriter{}},
}
}

func (tw *tokenWriter) Write(token json.Token) error {
if len(tw.writers) == 0 {
return fmt.Errorf("no writer for token %#v", token)
}

curTw := tw.writers[len(tw.writers)-1]
if curTw.Done(token) {
tw.writers = tw.writers[:len(tw.writers)-1]
curTw = tw.writers[len(tw.writers)-1]
}

nextTw, err := curTw.Write(tw.w, token, tw.dec.More())
if err != nil {
return fmt.Errorf("could not write: %w", err)
}
if nextTw != nil {
tw.writers = append(tw.writers, nextTw)
}

return nil
}

type TokenWriter interface {
Done(json.Token) bool
Write(io.Writer, json.Token, bool) (TokenWriter, error)
}

type simpleWriter struct{}

func (sw simpleWriter) Done(json.Token) bool { return false }
func (sw simpleWriter) Write(w io.Writer, token json.Token, _ bool) (TokenWriter, error) {
var err error
var writer TokenWriter

switch tok := token.(type) {
case bool:
_, err = fmt.Fprintf(w, "%v", tok)
case string:
_, err = fmt.Fprintf(w, "%q", tok)
case json.Number:
_, err = w.Write([]byte(tok))
case float64:
_, err = fmt.Fprintf(w, "%f", tok)
case nil:
_, err = w.Write([]byte("null"))
case json.Delim:
if tok == '{' {
writer = &objectWriter{}
} else if tok == '[' {
writer = &arrayWriter{}
}
_, err = w.Write([]byte(tok.String()))
default:
err = fmt.Errorf("unhandled value %#v of type %T", token, token)
}

return writer, err
}

type objectWriter struct {
simpleWriter

n int
}

func (ow *objectWriter) Done(token json.Token) bool {
return token == json.Delim('}')
}

func (ow *objectWriter) Write(w io.Writer, token json.Token, more bool) (TokenWriter, error) {
tw, err := ow.simpleWriter.Write(w, token, more)
if err != nil {
return nil, err
}
if tw != nil {
return tw, err
}

if ow.n%2 == 0 && more {
_, err = w.Write([]byte(":"))
} else if more {
_, err = w.Write([]byte(","))
}
ow.n += 1

return tw, err
}

type arrayWriter struct {
simpleWriter

n int
}

func (aw *arrayWriter) Done(token json.Token) bool {
return token == json.Delim(']')
}

func (aw *arrayWriter) Write(w io.Writer, token json.Token, more bool) (TokenWriter, error) {
tw, err := aw.simpleWriter.Write(w, token, more)
if err != nil {
return nil, err
}
if tw != nil {
return tw, nil
}

if more {
_, err = w.Write([]byte(","))
}
aw.n += 1

return tw, err
}
Loading

0 comments on commit dc05674

Please sign in to comment.