Skip to content

Commit

Permalink
refactor providers code
Browse files Browse the repository at this point in the history
Signed-off-by: Ivan Milchev <[email protected]>
  • Loading branch information
imilchev committed Feb 8, 2024
1 parent 4d9e59e commit bb9d905
Show file tree
Hide file tree
Showing 15 changed files with 458 additions and 220 deletions.
4 changes: 4 additions & 0 deletions providers-sdk/v1/plugin/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ func (m *GRPCClient) Connect(req *ConnectReq, callback ProviderCallback) (*Conne
return m.client.Connect(context.Background(), req)
}

func (m *GRPCClient) Disconnect(req *DisconnectReq) (*DisconnectRes, error) {
return m.client.Disconnect(context.Background(), req)
}

func (m *GRPCClient) MockConnect(req *ConnectReq, callback ProviderCallback) (*ConnectRes, error) {
m.connect(req, callback)
return m.client.MockConnect(context.Background(), req)
Expand Down
5 changes: 5 additions & 0 deletions providers-sdk/v1/plugin/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ var PluginMap = map[string]plugin.Plugin{
"provider": &ProviderPluginImpl{},
}

type Closer interface {
Close()
}

type ProviderCallback interface {
Collect(req *DataRes) error
GetRecording(req *DataReq) (*ResourceData, error)
Expand All @@ -33,6 +37,7 @@ type ProviderPlugin interface {
Heartbeat(req *HeartbeatReq) (*HeartbeatRes, error)
ParseCLI(req *ParseCLIReq) (*ParseCLIRes, error)
Connect(req *ConnectReq, callback ProviderCallback) (*ConnectRes, error)
Disconnect(req *DisconnectReq) (*DisconnectRes, error)
MockConnect(req *ConnectReq, callback ProviderCallback) (*ConnectRes, error)
Shutdown(req *ShutdownReq) (*ShutdownRes, error)
GetData(req *DataReq) (*DataRes, error)
Expand Down
302 changes: 212 additions & 90 deletions providers-sdk/v1/plugin/plugin.pb.go

Large diffs are not rendered by default.

9 changes: 9 additions & 0 deletions providers-sdk/v1/plugin/plugin.proto
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,19 @@ message HeartbeatRes {

}

message DisconnectReq {
uint32 connection = 1;
}

message DisconnectRes {

}

service ProviderPlugin {
rpc Heartbeat(HeartbeatReq) returns (HeartbeatRes);
rpc ParseCLI(ParseCLIReq) returns (ParseCLIRes);
rpc Connect(ConnectReq) returns (ConnectRes);
rpc Disconnect(DisconnectReq) returns (DisconnectRes);
rpc MockConnect(ConnectReq) returns (ConnectRes);
rpc Shutdown(ShutdownReq) returns (ShutdownRes);
rpc GetData(DataReq) returns (DataRes);
Expand Down
37 changes: 37 additions & 0 deletions providers-sdk/v1/plugin/plugin_grpc.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 13 additions & 2 deletions providers-sdk/v1/plugin/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,23 @@ type Runtime struct {
Callback ProviderCallback
HasRecording bool
CreateResource CreateNamedResource
NewResource NewResource
GetData GetData
SetData SetData
Upstream *upstream.UpstreamClient
}

type Connection interface{}
type Connection interface {
SetID(id uint32)
ID() uint32
}

type CreateNamedResource func(runtime *Runtime, name string, args map[string]*llx.RawData) (Resource, error)
type (
CreateNamedResource func(runtime *Runtime, name string, args map[string]*llx.RawData) (Resource, error)
NewResource func(runtime *Runtime, name string, args map[string]*llx.RawData) (Resource, error)
GetData func(resource Resource, field string, args map[string]*llx.RawData) *DataRes
SetData func(resource Resource, field string, val *llx.RawData) error
)

type Resource interface {
MqlID() string
Expand Down
143 changes: 138 additions & 5 deletions providers-sdk/v1/plugin/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,33 +6,166 @@ package plugin
import (
"errors"
"os"
"strconv"
"strings"
sync "sync"
"time"

llx "go.mondoo.com/cnquery/v10/llx"
)

type Service struct {
runtimes map[uint32]*Runtime
lastConnectionID uint32
runtimesLock sync.Mutex

lastHeartbeat int64
lock sync.Mutex
heartbeatLock sync.Mutex
}

func NewService() *Service {
return &Service{
runtimes: make(map[uint32]*Runtime),
}
}

var heartbeatRes HeartbeatRes

func (s *Service) AddRuntime(runtime *Runtime) uint32 {
s.runtimesLock.Lock()
defer s.runtimesLock.Unlock()
s.lastConnectionID++
runtime.Connection.SetID(s.lastConnectionID)
s.runtimes[s.lastConnectionID] = runtime
return s.lastConnectionID
}

func (s *Service) GetRuntime(id uint32) (*Runtime, error) {
s.runtimesLock.Lock()
defer s.runtimesLock.Unlock()
if runtime, ok := s.runtimes[id]; ok {
return runtime, nil
}
return nil, errors.New("connection " + strconv.FormatUint(uint64(id), 10) + " not found")
}

func (s *Service) Disconnect(req *DisconnectReq) (*DisconnectRes, error) {
s.runtimesLock.Lock()
defer s.runtimesLock.Unlock()
if runtime, ok := s.runtimes[req.Connection]; ok {
// If the runtime implements the Closer interface, we need to call the
// Close function
if closer, ok := runtime.Connection.(Closer); ok {
closer.Close()
}
delete(s.runtimes, req.Connection)
}
return &DisconnectRes{}, nil
}

func (s *Service) GetData(req *DataReq) (*DataRes, error) {
runtime, ok := s.runtimes[req.Connection]
if !ok {
return nil, errors.New("connection " + strconv.FormatUint(uint64(req.Connection), 10) + " not found")
}

args := PrimitiveArgsToRawDataArgs(req.Args, runtime)

if req.ResourceId == "" && req.Field == "" {
res, err := runtime.NewResource(runtime, req.Resource, args)
if err != nil {
return nil, err
}

rd := llx.ResourceData(res, res.MqlName()).Result()
return &DataRes{
Data: rd.Data,
}, nil
}

resource, ok := runtime.Resources.Get(req.Resource + "\x00" + req.ResourceId)
if !ok {
// Note: Since resources are internally always created, there are only very
// few cases where we arrive here:
// 1. The caller is wrong. Possibly a mixup with IDs
// 2. The resource was loaded from a recording, but the field is not
// in the recording. Thus the resource was never created inside the
// plugin. We will attempt to create the resource and see if the field
// can be computed.
if !runtime.HasRecording {
return nil, errors.New("resource '" + req.Resource + "' (id: " + req.ResourceId + ") doesn't exist")
}

args, err := runtime.ResourceFromRecording(req.Resource, req.ResourceId)
if err != nil {
return nil, errors.New("attempted to load resource '" + req.Resource + "' (id: " + req.ResourceId + ") from recording failed: " + err.Error())
}

resource, err = runtime.CreateResource(runtime, req.Resource, args)
if err != nil {
return nil, errors.New("attempted to create resource '" + req.Resource + "' (id: " + req.ResourceId + ") from recording failed: " + err.Error())
}
}

return runtime.GetData(resource, req.Field, args), nil
}

func (s *Service) StoreData(req *StoreReq) (*StoreRes, error) {
runtime, ok := s.runtimes[req.Connection]
if !ok {
return nil, errors.New("connection " + strconv.FormatUint(uint64(req.Connection), 10) + " not found")
}

var errs []string
for i := range req.Resources {
info := req.Resources[i]

args, err := ProtoArgsToRawDataArgs(info.Fields)
if err != nil {
errs = append(errs, "failed to add cached "+info.Name+" (id: "+info.Id+"), failed to parse arguments")
continue
}

resource, ok := runtime.Resources.Get(info.Name + "\x00" + info.Id)
if !ok {
resource, err = runtime.CreateResource(runtime, info.Name, args)
if err != nil {
errs = append(errs, "failed to add cached "+info.Name+" (id: "+info.Id+"), creation failed: "+err.Error())
continue
}

runtime.Resources.Set(info.Name+"\x00"+info.Id, resource)
}

for k, v := range args {
if err := runtime.SetData(resource, k, v); err != nil {
errs = append(errs, "failed to add cached "+info.Name+" (id: "+info.Id+"), field error: "+err.Error())
}
}
}

if len(errs) != 0 {
return nil, errors.New(strings.Join(errs, ", "))
}
return &StoreRes{}, nil
}

func (s *Service) Heartbeat(req *HeartbeatReq) (*HeartbeatRes, error) {
if req.Interval == 0 {
return nil, errors.New("heartbeat failed, requested interval is 0")
}

now := time.Now().UnixNano()
s.lock.Lock()
s.heartbeatLock.Lock()
s.lastHeartbeat = now
s.lock.Unlock()
s.heartbeatLock.Unlock()

go func() {
time.Sleep(time.Duration(req.Interval))

s.lock.Lock()
s.heartbeatLock.Lock()
isDead := s.lastHeartbeat == now
s.lock.Unlock()
s.heartbeatLock.Unlock()

if isDead {
os.Exit(1)
Expand Down
6 changes: 5 additions & 1 deletion providers/k8s/connection/admission/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type Connection struct {
}

// func newManifestProvider(selectedResourceID string, objectKind string, opts ...Option) (KubernetesProvider, error) {
func NewConnection(id uint32, asset *inventory.Asset, data string) (shared.Connection, error) {
func NewConnection(asset *inventory.Asset, data string) (shared.Connection, error) {
c := &Connection{
asset: asset,
namespace: asset.Connections[0].Options[shared.OPTION_NAMESPACE],
Expand Down Expand Up @@ -66,6 +66,10 @@ func (c *Connection) SupportedResourceTypes() (*resources.ApiResourceIndex, erro
return c.ManifestParser.SupportedResourceTypes()
}

func (c *Connection) SetID(id uint32) {
c.id = id
}

func (c *Connection) ID() uint32 {
return c.id
}
Expand Down
7 changes: 5 additions & 2 deletions providers/k8s/connection/api/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ type Connection struct {
currentClusterName string
}

func NewConnection(id uint32, asset *inventory.Asset, discoveryCache *resources.DiscoveryCache) (shared.Connection, error) {
func NewConnection(asset *inventory.Asset, discoveryCache *resources.DiscoveryCache) (shared.Connection, error) {
// check if the user .kube/config file exists
// NOTE: BuildConfigFromFlags falls back to cluster loading when .kube/config string is empty
// therefore we want to only change the kubeconfig string when the file really exists
Expand Down Expand Up @@ -99,7 +99,6 @@ func NewConnection(id uint32, asset *inventory.Asset, discoveryCache *resources.
}

res := Connection{
id: id,
asset: asset,
d: d,
config: config,
Expand All @@ -125,6 +124,10 @@ func buildConfigFromFlags(masterUrl, kubeconfigPath string, context string) (*re
&clientcmd.ConfigOverrides{ClusterInfo: clientcmdapi.Cluster{Server: masterUrl}, CurrentContext: context}).ClientConfig()
}

func (c *Connection) SetID(id uint32) {
c.id = id
}

func (c *Connection) ID() uint32 {
return c.id
}
Expand Down
Loading

0 comments on commit bb9d905

Please sign in to comment.