diff --git a/cmd/tidb-dashboard/main.go b/cmd/tidb-dashboard/main.go index b687c9c86..a342b34b3 100755 --- a/cmd/tidb-dashboard/main.go +++ b/cmd/tidb-dashboard/main.go @@ -71,6 +71,9 @@ func NewCLIConfig() *DashboardCLIConfig { flag.IntVar(&cfg.CoreConfig.NgmTimeout, "ngm-timeout", cfg.CoreConfig.NgmTimeout, "timeout secs for accessing the ngm API") flag.BoolVar(&cfg.CoreConfig.EnableKeyVisualizer, "keyviz", true, "enable/disable key visualizer(default: true)") flag.BoolVar(&cfg.CoreConfig.DisableCustomPromAddr, "disable-custom-prom-addr", false, "do not allow custom prometheus address") + flag.StringVar(&cfg.CoreConfig.SigningAlgorithm, "signing-algorithm", cfg.CoreConfig.SigningAlgorithm, "signing algorithm for jwt (HS256, HS384, HS512, RS256, RS384, RS512)") + flag.Float64Var(&cfg.CoreConfig.UnauthedAPIQpsLimit, "unauthed-api-qps-limit", cfg.CoreConfig.UnauthedAPIQpsLimit, "unauthed API qps limit") + flag.IntVar(&cfg.CoreConfig.UnauthedAPIBurstLimit, "unauthed-api-burst-limit", cfg.CoreConfig.UnauthedAPIBurstLimit, "unauthed API burst limit") showVersion := flag.BoolP("version", "v", false, "print version information and exit") diff --git a/go.mod b/go.mod index 859619d31..8d5f382d3 100644 --- a/go.mod +++ b/go.mod @@ -53,6 +53,7 @@ require ( go.uber.org/zap v1.19.0 golang.org/x/oauth2 v0.11.0 golang.org/x/sync v0.3.0 + golang.org/x/time v0.6.0 google.golang.org/grpc v1.59.0 google.golang.org/protobuf v1.33.0 gorm.io/datatypes v1.1.0 diff --git a/go.sum b/go.sum index 2f022d241..a1e151fef 100644 --- a/go.sum +++ b/go.sum @@ -519,6 +519,8 @@ golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U= +golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/pkg/apiserver/logsearch/service.go b/pkg/apiserver/logsearch/service.go index 90ff16b4d..96553b250 100644 --- a/pkg/apiserver/logsearch/service.go +++ b/pkg/apiserver/logsearch/service.go @@ -11,6 +11,7 @@ import ( "github.com/gin-gonic/gin" "github.com/pingcap/log" + clientv3 "go.etcd.io/etcd/client/v3" "go.uber.org/fx" "go.uber.org/zap" @@ -19,6 +20,8 @@ import ( "github.com/pingcap/tidb-dashboard/pkg/apiserver/utils" "github.com/pingcap/tidb-dashboard/pkg/config" "github.com/pingcap/tidb-dashboard/pkg/dbstore" + "github.com/pingcap/tidb-dashboard/pkg/pd" + "github.com/pingcap/tidb-dashboard/pkg/utils/topology" "github.com/pingcap/tidb-dashboard/util/rest" ) @@ -30,9 +33,11 @@ type Service struct { logStoreDirectory string db *dbstore.DB scheduler *Scheduler + etcdClient *clientv3.Client + pdClient *pd.Client } -func NewService(lc fx.Lifecycle, config *config.Config, db *dbstore.DB) *Service { +func NewService(lc fx.Lifecycle, config *config.Config, db *dbstore.DB, etcdClient *clientv3.Client, pdClient *pd.Client) *Service { dir := config.TempDir if dir == "" { var err error @@ -52,6 +57,8 @@ func NewService(lc fx.Lifecycle, config *config.Config, db *dbstore.DB) *Service logStoreDirectory: dir, db: db, scheduler: nil, // will be filled after scheduler is created + etcdClient: etcdClient, + pdClient: pdClient, } scheduler := NewScheduler(service) service.scheduler = scheduler @@ -112,6 +119,10 @@ func (s *Service) CreateTaskGroup(c *gin.Context) { rest.Error(c, rest.ErrBadRequest.New("Expect at least 1 target")) return } + if err := s.verifyTargets(c.Request.Context(), req.Targets); err != nil { + rest.Error(c, err) + return + } stats := model.NewRequestTargetStatisticsFromArray(&req.Targets) taskGroup := TaskGroupModel{ SearchRequest: &req.Request, @@ -361,3 +372,157 @@ func (s *Service) DownloadLogs(c *gin.Context) { serveMultipleTaskForDownload(tasks, c) } } + +func (s *Service) verifyTargets(ctx context.Context, targets []model.RequestTargetNode) error { + kindToTargets := make(map[model.NodeKind][]model.RequestTargetNode) + for _, target := range targets { + kindToTargets[target.Kind] = append(kindToTargets[target.Kind], target) + } + var tikvInfos []topology.StoreInfo + var tiflashInfos []topology.StoreInfo + for kind, targets := range kindToTargets { + switch kind { + case model.NodeKindTiDB: + infos, err := topology.FetchTiDBTopology(ctx, s.etcdClient) + if err != nil { + log.Error("failed to fetch tidb topology", zap.Error(err)) + return rest.ErrInternalServerError.NewWithNoMessage() + } + for _, target := range targets { + matched := false + for _, info := range infos { + if info.IP == target.IP && info.Port == uint(target.Port) { + matched = true + break + } + } + if !matched { + return rest.ErrInvalidEndpoint.NewWithNoMessage() + } + } + case model.NodeKindTiKV, model.NodeKindTiFlash: + if len(tikvInfos) == 0 { + var err error + tikvInfos, tiflashInfos, err = topology.FetchStoreTopology(s.pdClient) + if err != nil { + log.Error("failed to fetch tikv/tiflash topology", zap.Error(err)) + return rest.ErrInternalServerError.NewWithNoMessage() + } + } + for _, target := range targets { + matched := false + if kind == model.NodeKindTiKV { + for _, info := range tikvInfos { + if info.IP == target.IP && info.Port == uint(target.Port) { + matched = true + break + } + } + } else { + for _, info := range tiflashInfos { + if info.IP == target.IP && info.Port == uint(target.Port) { + matched = true + break + } + } + } + if !matched { + return rest.ErrInvalidEndpoint.NewWithNoMessage() + } + } + case model.NodeKindPD: + infos, err := topology.FetchPDTopology(s.pdClient) + if err != nil { + log.Error("failed to fetch pd topology", zap.Error(err)) + return rest.ErrInternalServerError.NewWithNoMessage() + } + for _, target := range targets { + matched := false + for _, info := range infos { + if info.IP == target.IP && info.Port == uint(target.Port) { + matched = true + break + } + } + if !matched { + return rest.ErrInvalidEndpoint.NewWithNoMessage() + } + } + case model.NodeKindTiCDC: + infos, err := topology.FetchTiCDCTopology(ctx, s.etcdClient) + if err != nil { + log.Error("failed to fetch ticdc topology", zap.Error(err)) + return rest.ErrInternalServerError.NewWithNoMessage() + } + for _, target := range targets { + matched := false + for _, info := range infos { + if info.IP == target.IP && info.Port == uint(target.Port) { + matched = true + break + } + } + if !matched { + return rest.ErrInvalidEndpoint.NewWithNoMessage() + } + } + case model.NodeKindTiProxy: + infos, err := topology.FetchTiProxyTopology(ctx, s.etcdClient) + if err != nil { + log.Error("failed to fetch tiproxy topology", zap.Error(err)) + return rest.ErrInternalServerError.NewWithNoMessage() + } + for _, target := range targets { + matched := false + for _, info := range infos { + if info.IP == target.IP && info.Port == uint(target.Port) { + matched = true + break + } + } + if !matched { + return rest.ErrInvalidEndpoint.NewWithNoMessage() + } + } + case model.NodeKindTSO: + infos, err := topology.FetchTSOTopology(ctx, s.pdClient) + if err != nil { + log.Error("failed to fetch tso topology", zap.Error(err)) + return rest.ErrInternalServerError.NewWithNoMessage() + } + for _, target := range targets { + matched := false + for _, info := range infos { + if info.IP == target.IP && info.Port == uint(target.Port) { + matched = true + break + } + } + if !matched { + return rest.ErrInvalidEndpoint.NewWithNoMessage() + } + } + case model.NodeKindScheduling: + infos, err := topology.FetchSchedulingTopology(ctx, s.pdClient) + if err != nil { + log.Error("failed to fetch scheduling topology", zap.Error(err)) + return rest.ErrInternalServerError.NewWithNoMessage() + } + for _, target := range targets { + matched := false + for _, info := range infos { + if info.IP == target.IP && info.Port == uint(target.Port) { + matched = true + break + } + } + if !matched { + return rest.ErrInvalidEndpoint.NewWithNoMessage() + } + } + default: + return rest.ErrInvalidEndpoint.NewWithNoMessage() + } + } + return nil +} diff --git a/pkg/apiserver/logsearch/task.go b/pkg/apiserver/logsearch/task.go index 09f04bb07..ead03b7f0 100644 --- a/pkg/apiserver/logsearch/task.go +++ b/pkg/apiserver/logsearch/task.go @@ -16,7 +16,6 @@ import ( "strconv" "sync" "time" - "unsafe" "github.com/pingcap/kvproto/pkg/diagnosticspb" "github.com/pingcap/log" @@ -252,7 +251,7 @@ func (t *Task) searchLog(client diagnosticspb.DiagnosticsClient, targetType diag } for _, msg := range res.Messages { line := logMessageToString(msg) - _, err := bufWriter.Write(*(*[]byte)(unsafe.Pointer(&line))) // #nosec + _, err := bufWriter.WriteString(line) if err != nil { t.setError(err) return diff --git a/pkg/apiserver/statement/queries.go b/pkg/apiserver/statement/queries.go index 16deb6a0a..1fa0cdcf4 100644 --- a/pkg/apiserver/statement/queries.go +++ b/pkg/apiserver/statement/queries.go @@ -218,7 +218,7 @@ func (s *Service) createPlanBinding(db *gorm.DB, planDigest string) (err error) return errors.New("invalid planDigest") } - query := db.Exec(fmt.Sprintf("CREATE GLOBAL BINDING FROM HISTORY USING PLAN DIGEST '%s'", planDigest)) + query := db.Exec("CREATE GLOBAL BINDING FROM HISTORY USING PLAN DIGEST ?", planDigest) return query.Error } diff --git a/pkg/apiserver/user/auth.go b/pkg/apiserver/user/auth.go index 31250b575..62231ca50 100644 --- a/pkg/apiserver/user/auth.go +++ b/pkg/apiserver/user/auth.go @@ -18,8 +18,10 @@ import ( "github.com/joomcode/errorx" "github.com/pingcap/log" "go.uber.org/zap" + "golang.org/x/time/rate" "github.com/pingcap/tidb-dashboard/pkg/apiserver/utils" + "github.com/pingcap/tidb-dashboard/pkg/config" "github.com/pingcap/tidb-dashboard/util/featureflag" "github.com/pingcap/tidb-dashboard/util/rest" ) @@ -78,7 +80,7 @@ func (a BaseAuthenticator) SignOutInfo(_ *utils.SessionUser, _ string) (*SignOut return &SignOutInfo{}, nil } -func NewAuthService(featureFlags *featureflag.Registry) *AuthService { +func NewAuthService(featureFlags *featureflag.Registry, config *config.Config) *AuthService { var secret *[32]byte secretStr := os.Getenv("DASHBOARD_SESSION_SECRET") @@ -108,11 +110,12 @@ func NewAuthService(featureFlags *featureflag.Registry) *AuthService { } middleware, err := jwt.New(&jwt.GinJWTMiddleware{ - IdentityKey: utils.SessionUserKey, - Realm: "dashboard", - Key: secret[:], - Timeout: time.Hour * 24, - MaxRefresh: time.Hour * 24, + IdentityKey: utils.SessionUserKey, + Realm: "dashboard", + Key: secret[:], + Timeout: time.Hour * 24, + MaxRefresh: time.Hour * 24, + SigningAlgorithm: config.SigningAlgorithm, Authenticator: func(c *gin.Context) (interface{}, error) { var form AuthenticateForm if err := c.ShouldBindJSON(&form); err != nil { @@ -244,10 +247,14 @@ func (s *AuthService) authForm(f AuthenticateForm) (*utils.SessionUser, error) { return u, nil } -func registerRouter(r *gin.RouterGroup, s *AuthService) { +func registerRouter(r *gin.RouterGroup, s *AuthService, cfg *config.Config) { endpoint := r.Group("/user") endpoint.GET("/login_info", s.GetLoginInfoHandler) - endpoint.POST("/login", s.LoginHandler) + if cfg.UnauthedAPIQpsLimit > 0 && cfg.UnauthedAPIBurstLimit > 0 { + endpoint.POST("/login", s.MWRateLimited(rate.Limit(cfg.UnauthedAPIQpsLimit), cfg.UnauthedAPIBurstLimit), s.LoginHandler) + } else { + endpoint.POST("/login", s.LoginHandler) + } endpoint.GET("/sign_out_info", s.MWAuthRequired(), s.getSignOutInfoHandler) } @@ -293,6 +300,18 @@ func (s *AuthService) MWRequireWritePriv() gin.HandlerFunc { } } +func (s *AuthService) MWRateLimited(r rate.Limit, b int) gin.HandlerFunc { + limiter := rate.NewLimiter(r, b) + return func(ctx *gin.Context) { + if !limiter.Allow() { + rest.Error(ctx, rest.ErrTooManyRequests.NewWithNoMessage()) + ctx.Abort() + return + } + ctx.Next() + } +} + // RegisterAuthenticator registers an authenticator in the authenticate pipeline. func (s *AuthService) RegisterAuthenticator(typeID utils.AuthType, a Authenticator) { s.authenticators[typeID] = a diff --git a/pkg/config/config.go b/pkg/config/config.go index de311dafe..e748a5f8a 100755 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -37,6 +37,10 @@ type Config struct { FeatureVersion string // assign the target TiDB version when running TiDB Dashboard as standalone mode NgmTimeout int // in seconds + + SigningAlgorithm string + UnauthedAPIQpsLimit float64 + UnauthedAPIBurstLimit int } func Default() *Config { @@ -54,6 +58,9 @@ func Default() *Config { DisableCustomPromAddr: false, FeatureVersion: version.PDVersion, NgmTimeout: 30, // s + SigningAlgorithm: "", + UnauthedAPIQpsLimit: 0, + UnauthedAPIBurstLimit: 0, } } diff --git a/util/rest/error.go b/util/rest/error.go index 82be647c9..97bd9a967 100644 --- a/util/rest/error.go +++ b/util/rest/error.go @@ -12,10 +12,13 @@ import ( ) var ( - ErrUnauthenticated = errorx.CommonErrors.NewType("unauthenticated") - ErrForbidden = errorx.CommonErrors.NewType("forbidden") - ErrBadRequest = errorx.CommonErrors.NewType("bad_request") - ErrNotFound = errorx.CommonErrors.NewType("not_found") + ErrUnauthenticated = errorx.CommonErrors.NewType("unauthenticated") + ErrForbidden = errorx.CommonErrors.NewType("forbidden") + ErrBadRequest = errorx.CommonErrors.NewType("bad_request") + ErrNotFound = errorx.CommonErrors.NewType("not_found") + ErrTooManyRequests = errorx.CommonErrors.NewType("too_many_requests") + ErrInvalidEndpoint = errorx.CommonErrors.NewType("invalid_endpoint") + ErrInternalServerError = errorx.CommonErrors.NewType("internal_server_error") errInternal = errorx.CommonErrors.NewType("internal") propHTTPCode = errorx.RegisterProperty("http_code")