Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: fix security issues #1731

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions cmd/tidb-dashboard/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ func NewCLIConfig() *DashboardCLIConfig {
flag.IntVar(&cfg.CoreConfig.NgmTimeout, "ngm-timeout", cfg.CoreConfig.NgmTimeout, "timeout secs for accessing the ngm API")
flag.BoolVar(&cfg.CoreConfig.EnableKeyVisualizer, "keyviz", true, "enable/disable key visualizer(default: true)")
flag.BoolVar(&cfg.CoreConfig.DisableCustomPromAddr, "disable-custom-prom-addr", false, "do not allow custom prometheus address")
flag.StringVar(&cfg.CoreConfig.SigningAlgorithm, "signing-algorithm", cfg.CoreConfig.SigningAlgorithm, "signing algorithm for jwt (HS256, HS384, HS512, RS256, RS384, RS512)")
flag.Float64Var(&cfg.CoreConfig.UnauthedAPIQpsLimit, "unauthed-api-qps-limit", cfg.CoreConfig.UnauthedAPIQpsLimit, "unauthed API qps limit")
flag.IntVar(&cfg.CoreConfig.UnauthedAPIBurstLimit, "unauthed-api-burst-limit", cfg.CoreConfig.UnauthedAPIBurstLimit, "unauthed API burst limit")

showVersion := flag.BoolP("version", "v", false, "print version information and exit")

Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ require (
go.uber.org/zap v1.19.0
golang.org/x/oauth2 v0.11.0
golang.org/x/sync v0.3.0
golang.org/x/time v0.6.0
google.golang.org/grpc v1.59.0
google.golang.org/protobuf v1.33.0
gorm.io/datatypes v1.1.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,8 @@ golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U=
golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
Expand Down
167 changes: 166 additions & 1 deletion pkg/apiserver/logsearch/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/gin-gonic/gin"
"github.com/pingcap/log"
clientv3 "go.etcd.io/etcd/client/v3"
"go.uber.org/fx"
"go.uber.org/zap"

Expand All @@ -19,6 +20,8 @@ import (
"github.com/pingcap/tidb-dashboard/pkg/apiserver/utils"
"github.com/pingcap/tidb-dashboard/pkg/config"
"github.com/pingcap/tidb-dashboard/pkg/dbstore"
"github.com/pingcap/tidb-dashboard/pkg/pd"
"github.com/pingcap/tidb-dashboard/pkg/utils/topology"
"github.com/pingcap/tidb-dashboard/util/rest"
)

Expand All @@ -30,9 +33,11 @@ type Service struct {
logStoreDirectory string
db *dbstore.DB
scheduler *Scheduler
etcdClient *clientv3.Client
pdClient *pd.Client
}

func NewService(lc fx.Lifecycle, config *config.Config, db *dbstore.DB) *Service {
func NewService(lc fx.Lifecycle, config *config.Config, db *dbstore.DB, etcdClient *clientv3.Client, pdClient *pd.Client) *Service {
dir := config.TempDir
if dir == "" {
var err error
Expand All @@ -52,6 +57,8 @@ func NewService(lc fx.Lifecycle, config *config.Config, db *dbstore.DB) *Service
logStoreDirectory: dir,
db: db,
scheduler: nil, // will be filled after scheduler is created
etcdClient: etcdClient,
pdClient: pdClient,
}
scheduler := NewScheduler(service)
service.scheduler = scheduler
Expand Down Expand Up @@ -112,6 +119,10 @@ func (s *Service) CreateTaskGroup(c *gin.Context) {
rest.Error(c, rest.ErrBadRequest.New("Expect at least 1 target"))
return
}
if err := s.verifyTargets(c.Request.Context(), req.Targets); err != nil {
rest.Error(c, err)
return
}
stats := model.NewRequestTargetStatisticsFromArray(&req.Targets)
taskGroup := TaskGroupModel{
SearchRequest: &req.Request,
Expand Down Expand Up @@ -361,3 +372,157 @@ func (s *Service) DownloadLogs(c *gin.Context) {
serveMultipleTaskForDownload(tasks, c)
}
}

func (s *Service) verifyTargets(ctx context.Context, targets []model.RequestTargetNode) error {
kindToTargets := make(map[model.NodeKind][]model.RequestTargetNode)
for _, target := range targets {
kindToTargets[target.Kind] = append(kindToTargets[target.Kind], target)
}
var tikvInfos []topology.StoreInfo
var tiflashInfos []topology.StoreInfo
for kind, targets := range kindToTargets {
switch kind {
case model.NodeKindTiDB:
infos, err := topology.FetchTiDBTopology(ctx, s.etcdClient)
if err != nil {
log.Error("failed to fetch tidb topology", zap.Error(err))
return rest.ErrInternalServerError.NewWithNoMessage()
}
for _, target := range targets {
matched := false
for _, info := range infos {
if info.IP == target.IP && info.Port == uint(target.Port) {
matched = true
break
}
}
if !matched {
return rest.ErrInvalidEndpoint.NewWithNoMessage()
}
}
case model.NodeKindTiKV, model.NodeKindTiFlash:
if len(tikvInfos) == 0 {
var err error
tikvInfos, tiflashInfos, err = topology.FetchStoreTopology(s.pdClient)
if err != nil {
log.Error("failed to fetch tikv/tiflash topology", zap.Error(err))
return rest.ErrInternalServerError.NewWithNoMessage()
}
}
for _, target := range targets {
matched := false
if kind == model.NodeKindTiKV {
for _, info := range tikvInfos {
if info.IP == target.IP && info.Port == uint(target.Port) {
matched = true
break
}
}
} else {
for _, info := range tiflashInfos {
if info.IP == target.IP && info.Port == uint(target.Port) {
matched = true
break
}
}
}
if !matched {
return rest.ErrInvalidEndpoint.NewWithNoMessage()
}
}
case model.NodeKindPD:
infos, err := topology.FetchPDTopology(s.pdClient)
if err != nil {
log.Error("failed to fetch pd topology", zap.Error(err))
return rest.ErrInternalServerError.NewWithNoMessage()
}
for _, target := range targets {
matched := false
for _, info := range infos {
if info.IP == target.IP && info.Port == uint(target.Port) {
matched = true
break
}
}
if !matched {
return rest.ErrInvalidEndpoint.NewWithNoMessage()
}
}
case model.NodeKindTiCDC:
infos, err := topology.FetchTiCDCTopology(ctx, s.etcdClient)
if err != nil {
log.Error("failed to fetch ticdc topology", zap.Error(err))
return rest.ErrInternalServerError.NewWithNoMessage()
}
for _, target := range targets {
matched := false
for _, info := range infos {
if info.IP == target.IP && info.Port == uint(target.Port) {
matched = true
break
}
}
if !matched {
return rest.ErrInvalidEndpoint.NewWithNoMessage()
}
}
case model.NodeKindTiProxy:
infos, err := topology.FetchTiProxyTopology(ctx, s.etcdClient)
if err != nil {
log.Error("failed to fetch tiproxy topology", zap.Error(err))
return rest.ErrInternalServerError.NewWithNoMessage()
}
for _, target := range targets {
matched := false
for _, info := range infos {
if info.IP == target.IP && info.Port == uint(target.Port) {
matched = true
break
}
}
if !matched {
return rest.ErrInvalidEndpoint.NewWithNoMessage()
}
}
case model.NodeKindTSO:
infos, err := topology.FetchTSOTopology(ctx, s.pdClient)
if err != nil {
log.Error("failed to fetch tso topology", zap.Error(err))
return rest.ErrInternalServerError.NewWithNoMessage()
}
for _, target := range targets {
matched := false
for _, info := range infos {
if info.IP == target.IP && info.Port == uint(target.Port) {
matched = true
break
}
}
if !matched {
return rest.ErrInvalidEndpoint.NewWithNoMessage()
}
}
case model.NodeKindScheduling:
infos, err := topology.FetchSchedulingTopology(ctx, s.pdClient)
if err != nil {
log.Error("failed to fetch scheduling topology", zap.Error(err))
return rest.ErrInternalServerError.NewWithNoMessage()
}
for _, target := range targets {
matched := false
for _, info := range infos {
if info.IP == target.IP && info.Port == uint(target.Port) {
matched = true
break
}
}
if !matched {
return rest.ErrInvalidEndpoint.NewWithNoMessage()
}
}
default:
return rest.ErrInvalidEndpoint.NewWithNoMessage()
}
}
return nil
}
3 changes: 1 addition & 2 deletions pkg/apiserver/logsearch/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import (
"strconv"
"sync"
"time"
"unsafe"

"github.com/pingcap/kvproto/pkg/diagnosticspb"
"github.com/pingcap/log"
Expand Down Expand Up @@ -252,7 +251,7 @@ func (t *Task) searchLog(client diagnosticspb.DiagnosticsClient, targetType diag
}
for _, msg := range res.Messages {
line := logMessageToString(msg)
_, err := bufWriter.Write(*(*[]byte)(unsafe.Pointer(&line))) // #nosec
_, err := bufWriter.WriteString(line)
if err != nil {
t.setError(err)
return
Expand Down
2 changes: 1 addition & 1 deletion pkg/apiserver/statement/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ func (s *Service) createPlanBinding(db *gorm.DB, planDigest string) (err error)
return errors.New("invalid planDigest")
}

query := db.Exec(fmt.Sprintf("CREATE GLOBAL BINDING FROM HISTORY USING PLAN DIGEST '%s'", planDigest))
query := db.Exec("CREATE GLOBAL BINDING FROM HISTORY USING PLAN DIGEST ?", planDigest)
return query.Error
}

Expand Down
35 changes: 27 additions & 8 deletions pkg/apiserver/user/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ import (
"github.com/joomcode/errorx"
"github.com/pingcap/log"
"go.uber.org/zap"
"golang.org/x/time/rate"

"github.com/pingcap/tidb-dashboard/pkg/apiserver/utils"
"github.com/pingcap/tidb-dashboard/pkg/config"
"github.com/pingcap/tidb-dashboard/util/featureflag"
"github.com/pingcap/tidb-dashboard/util/rest"
)
Expand Down Expand Up @@ -78,7 +80,7 @@ func (a BaseAuthenticator) SignOutInfo(_ *utils.SessionUser, _ string) (*SignOut
return &SignOutInfo{}, nil
}

func NewAuthService(featureFlags *featureflag.Registry) *AuthService {
func NewAuthService(featureFlags *featureflag.Registry, config *config.Config) *AuthService {
var secret *[32]byte

secretStr := os.Getenv("DASHBOARD_SESSION_SECRET")
Expand Down Expand Up @@ -108,11 +110,12 @@ func NewAuthService(featureFlags *featureflag.Registry) *AuthService {
}

middleware, err := jwt.New(&jwt.GinJWTMiddleware{
IdentityKey: utils.SessionUserKey,
Realm: "dashboard",
Key: secret[:],
Timeout: time.Hour * 24,
MaxRefresh: time.Hour * 24,
IdentityKey: utils.SessionUserKey,
Realm: "dashboard",
Key: secret[:],
Timeout: time.Hour * 24,
MaxRefresh: time.Hour * 24,
SigningAlgorithm: config.SigningAlgorithm,
Authenticator: func(c *gin.Context) (interface{}, error) {
var form AuthenticateForm
if err := c.ShouldBindJSON(&form); err != nil {
Expand Down Expand Up @@ -244,10 +247,14 @@ func (s *AuthService) authForm(f AuthenticateForm) (*utils.SessionUser, error) {
return u, nil
}

func registerRouter(r *gin.RouterGroup, s *AuthService) {
func registerRouter(r *gin.RouterGroup, s *AuthService, cfg *config.Config) {
endpoint := r.Group("/user")
endpoint.GET("/login_info", s.GetLoginInfoHandler)
endpoint.POST("/login", s.LoginHandler)
if cfg.UnauthedAPIQpsLimit > 0 && cfg.UnauthedAPIBurstLimit > 0 {
endpoint.POST("/login", s.MWRateLimited(rate.Limit(cfg.UnauthedAPIQpsLimit), cfg.UnauthedAPIBurstLimit), s.LoginHandler)
} else {
endpoint.POST("/login", s.LoginHandler)
}
endpoint.GET("/sign_out_info", s.MWAuthRequired(), s.getSignOutInfoHandler)
}

Expand Down Expand Up @@ -293,6 +300,18 @@ func (s *AuthService) MWRequireWritePriv() gin.HandlerFunc {
}
}

func (s *AuthService) MWRateLimited(r rate.Limit, b int) gin.HandlerFunc {
limiter := rate.NewLimiter(r, b)
return func(ctx *gin.Context) {
if !limiter.Allow() {
rest.Error(ctx, rest.ErrTooManyRequests.NewWithNoMessage())
ctx.Abort()
return
}
ctx.Next()
}
}

// RegisterAuthenticator registers an authenticator in the authenticate pipeline.
func (s *AuthService) RegisterAuthenticator(typeID utils.AuthType, a Authenticator) {
s.authenticators[typeID] = a
Expand Down
7 changes: 7 additions & 0 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ type Config struct {
FeatureVersion string // assign the target TiDB version when running TiDB Dashboard as standalone mode

NgmTimeout int // in seconds

SigningAlgorithm string
UnauthedAPIQpsLimit float64
UnauthedAPIBurstLimit int
}

func Default() *Config {
Expand All @@ -54,6 +58,9 @@ func Default() *Config {
DisableCustomPromAddr: false,
FeatureVersion: version.PDVersion,
NgmTimeout: 30, // s
SigningAlgorithm: "",
UnauthedAPIQpsLimit: 0,
UnauthedAPIBurstLimit: 0,
}
}

Expand Down
11 changes: 7 additions & 4 deletions util/rest/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@ import (
)

var (
ErrUnauthenticated = errorx.CommonErrors.NewType("unauthenticated")
ErrForbidden = errorx.CommonErrors.NewType("forbidden")
ErrBadRequest = errorx.CommonErrors.NewType("bad_request")
ErrNotFound = errorx.CommonErrors.NewType("not_found")
ErrUnauthenticated = errorx.CommonErrors.NewType("unauthenticated")
ErrForbidden = errorx.CommonErrors.NewType("forbidden")
ErrBadRequest = errorx.CommonErrors.NewType("bad_request")
ErrNotFound = errorx.CommonErrors.NewType("not_found")
ErrTooManyRequests = errorx.CommonErrors.NewType("too_many_requests")
ErrInvalidEndpoint = errorx.CommonErrors.NewType("invalid_endpoint")
ErrInternalServerError = errorx.CommonErrors.NewType("internal_server_error")

errInternal = errorx.CommonErrors.NewType("internal")
propHTTPCode = errorx.RegisterProperty("http_code")
Expand Down
Loading