diff --git a/authentik/outposts/consumer.py b/authentik/outposts/consumer.py index 80b64999d5c4..3120a1d50ddc 100644 --- a/authentik/outposts/consumer.py +++ b/authentik/outposts/consumer.py @@ -128,6 +128,12 @@ def receive_json(self, content: Data, **kwargs): state.args.update(msg.args) elif msg.instruction == WebsocketMessageInstruction.ACK: return + elif msg.instruction == WebsocketMessageInstruction.PROVIDER_SPECIFIC: + if "response_channel" not in msg.args: + return + self.logger.debug("Posted response to channel", msg=msg) + async_to_sync(self.channel_layer.send)(msg.args.get("response_channel"), content) + return GAUGE_OUTPOSTS_LAST_UPDATE.labels( tenant=connection.schema_name, outpost=self.outpost.name, diff --git a/authentik/outposts/http.py b/authentik/outposts/http.py new file mode 100644 index 000000000000..3e684d72f3e1 --- /dev/null +++ b/authentik/outposts/http.py @@ -0,0 +1,86 @@ +from base64 import b64decode +from dataclasses import asdict, dataclass +from random import choice +from typing import Any +from uuid import uuid4 + +from asgiref.sync import async_to_sync +from channels.layers import get_channel_layer +from channels_redis.pubsub import RedisPubSubChannelLayer +from requests.adapters import BaseAdapter +from requests.models import PreparedRequest, Response +from requests.utils import CaseInsensitiveDict +from structlog.stdlib import get_logger + +from authentik.outposts.models import Outpost + + +@dataclass +class OutpostPreparedRequest: + uid: str + method: str + url: str + headers: dict[str, str] + body: Any + ssl_verify: bool + timeout: int + + @staticmethod + def from_requests(req: PreparedRequest) -> "OutpostPreparedRequest": + return OutpostPreparedRequest( + uid=str(uuid4()), + method=req.method, + url=req.url, + headers=req.headers._store, + body=req.body, + ssl_verify=True, + timeout=0, + ) + + @property + def response_channel(self) -> str: + return f"authentik_outpost_http_response_{self.uid}" + + +class OutpostHTTPAdapter(BaseAdapter): + """Requests Adapter that sends HTTP requests via a specified Outpost""" + + def __init__(self, outpost: Outpost, default_timeout=10): + super().__init__() + self.__outpost = outpost + self.__logger = get_logger().bind() + self.__layer: RedisPubSubChannelLayer = get_channel_layer() + self.default_timeout = default_timeout + + def parse_response(self, raw_response: dict, req: PreparedRequest) -> Response: + res = Response() + res.request = req + res.status_code = raw_response.get("status") + res.url = raw_response.get("final_url") + res.headers = CaseInsensitiveDict(raw_response.get("headers")) + res._content = b64decode(raw_response.get("body")) + return res + + def send(self, request, stream=False, timeout=None, verify=True, cert=None, proxies=None): + # Convert request so we can send it to the outpost + converted = OutpostPreparedRequest.from_requests(request) + converted.ssl_verify = verify + converted.timeout = timeout if timeout else self.default_timeout + # Pick one of the outpost instances + state = choice(self.__outpost.state) # nosec + self.__logger.debug("sending HTTP request to outpost", uid=converted.uid) + async_to_sync(self.__layer.send)( + state.uid, + { + "type": "event.provider.specific", + "sub_type": "http_request", + "response_channel": converted.response_channel, + "request": asdict(converted), + }, + ) + self.__logger.debug("receiving HTTP response from outpost", uid=converted.uid) + raw_response = async_to_sync(self.__layer.receive)( + converted.response_channel, + ) + self.__logger.debug("received HTTP response from outpost", uid=converted.uid) + return self.parse_response(raw_response.get("args", {}).get("response", {}), request) diff --git a/authentik/outposts/models.py b/authentik/outposts/models.py index 4032892fe870..69cbe0321bbb 100644 --- a/authentik/outposts/models.py +++ b/authentik/outposts/models.py @@ -98,6 +98,7 @@ class OutpostType(models.TextChoices): LDAP = "ldap" RADIUS = "radius" RAC = "rac" + SCIM = "scim" def default_outpost_config(host: str | None = None): diff --git a/authentik/outposts/tasks.py b/authentik/outposts/tasks.py index 7a80ce9be439..3d9b63f3316a 100644 --- a/authentik/outposts/tasks.py +++ b/authentik/outposts/tasks.py @@ -43,13 +43,15 @@ from authentik.providers.proxy.controllers.kubernetes import ProxyKubernetesController from authentik.providers.radius.controllers.docker import RadiusDockerController from authentik.providers.radius.controllers.kubernetes import RadiusKubernetesController +from authentik.providers.scim.controllers.docker import SCIMDockerController +from authentik.providers.scim.controllers.kubernetes import SCIMKubernetesController from authentik.root.celery import CELERY_APP LOGGER = get_logger() CACHE_KEY_OUTPOST_DOWN = "goauthentik.io/outposts/teardown/%s" -def controller_for_outpost(outpost: Outpost) -> type[BaseController] | None: +def controller_for_outpost(outpost: Outpost) -> type[BaseController] | None: # noqa: PLR0911 """Get a controller for the outpost, when a service connection is defined""" if not outpost.service_connection: return None @@ -74,6 +76,11 @@ def controller_for_outpost(outpost: Outpost) -> type[BaseController] | None: return RACDockerController if isinstance(service_connection, KubernetesServiceConnection): return RACKubernetesController + if outpost.type == OutpostType.SCIM: + if isinstance(service_connection, DockerServiceConnection): + return SCIMDockerController + if isinstance(service_connection, KubernetesServiceConnection): + return SCIMKubernetesController return None diff --git a/authentik/providers/scim/clients/base.py b/authentik/providers/scim/clients/base.py index 246520114c83..d9a78e1d668d 100644 --- a/authentik/providers/scim/clients/base.py +++ b/authentik/providers/scim/clients/base.py @@ -19,6 +19,7 @@ TransientSyncException, ) from authentik.lib.utils.http import get_http_session +from authentik.outposts.http import OutpostHTTPAdapter from authentik.providers.scim.clients.exceptions import SCIMRequestException from authentik.providers.scim.clients.schema import ServiceProviderConfiguration from authentik.providers.scim.models import SCIMProvider @@ -41,8 +42,7 @@ class SCIMClient[TModel: "Model", TConnection: "Model", TSchema: "BaseModel"]( def __init__(self, provider: SCIMProvider): super().__init__(provider) - self._session = get_http_session() - self._session.verify = provider.verify_certificates + self._session = self.get_session(provider) self.provider = provider # Remove trailing slashes as we assume the URL doesn't have any base_url = provider.url @@ -52,6 +52,15 @@ def __init__(self, provider: SCIMProvider): self.token = provider.token self._config = self.get_service_provider_config() + def get_session(self, provider: SCIMProvider): + session = get_http_session() + if self.provider.outpost_set.exists(): + adapter = OutpostHTTPAdapter() + session.mount("https://", adapter) + session.mount("http://", adapter) + session.verify = provider.verify_certificates + return session + def _request(self, method: str, path: str, **kwargs) -> dict: """Wrapper to send a request to the full URL""" try: diff --git a/authentik/providers/scim/controllers/__init__.py b/authentik/providers/scim/controllers/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/authentik/providers/scim/controllers/docker.py b/authentik/providers/scim/controllers/docker.py new file mode 100644 index 000000000000..ac8af400f96f --- /dev/null +++ b/authentik/providers/scim/controllers/docker.py @@ -0,0 +1,12 @@ +"""SCIM Provider Docker Controller""" + +from authentik.outposts.controllers.docker import DockerController +from authentik.outposts.models import DockerServiceConnection, Outpost + + +class SCIMDockerController(DockerController): + """SCIM Provider Docker Controller""" + + def __init__(self, outpost: Outpost, connection: DockerServiceConnection): + super().__init__(outpost, connection) + self.deployment_ports = [] diff --git a/authentik/providers/scim/controllers/kubernetes.py b/authentik/providers/scim/controllers/kubernetes.py new file mode 100644 index 000000000000..998f3222020e --- /dev/null +++ b/authentik/providers/scim/controllers/kubernetes.py @@ -0,0 +1,14 @@ +"""SCIM Provider Kubernetes Controller""" + +from authentik.outposts.controllers.k8s.service import ServiceReconciler +from authentik.outposts.controllers.kubernetes import KubernetesController +from authentik.outposts.models import KubernetesServiceConnection, Outpost + + +class SCIMKubernetesController(KubernetesController): + """SCIM Provider Kubernetes Controller""" + + def __init__(self, outpost: Outpost, connection: KubernetesServiceConnection): + super().__init__(outpost, connection) + self.deployment_ports = [] + del self.reconcilers[ServiceReconciler.reconciler_name()] diff --git a/blueprints/schema.json b/blueprints/schema.json index 9b3b91eb7419..7dc89930b21d 100644 --- a/blueprints/schema.json +++ b/blueprints/schema.json @@ -4264,7 +4264,8 @@ "proxy", "ldap", "radius", - "rac" + "rac", + "scim" ], "title": "Type" }, @@ -6974,7 +6975,7 @@ "spnego_server_name": { "type": "string", "title": "Spnego server name", - "description": "Force the use of a specific server name for SPNEGO" + "description": "Force the use of a specific server name for SPNEGO. Must be in the form HTTP@hostname" }, "spnego_keytab": { "type": "string", diff --git a/cmd/scim/main.go b/cmd/scim/main.go new file mode 100644 index 000000000000..b3f96a5501f4 --- /dev/null +++ b/cmd/scim/main.go @@ -0,0 +1,178 @@ +package main + +import ( + "context" + "crypto/tls" + "encoding/base64" + "fmt" + "io" + "net/http" + "net/url" + "os" + + "github.com/mitchellh/mapstructure" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + + "goauthentik.io/internal/common" + "goauthentik.io/internal/debug" + "goauthentik.io/internal/outpost/ak" + "goauthentik.io/internal/outpost/ak/healthcheck" +) + +const helpMessage = `authentik SCIM + +Required environment variables: +- AUTHENTIK_HOST: URL to connect to (format "http://authentik.company") +- AUTHENTIK_TOKEN: Token to authenticate with +- AUTHENTIK_INSECURE: Skip SSL Certificate verification` + +var rootCmd = &cobra.Command{ + Long: helpMessage, + PersistentPreRun: func(cmd *cobra.Command, args []string) { + log.SetLevel(log.DebugLevel) + log.SetFormatter(&log.JSONFormatter{ + FieldMap: log.FieldMap{ + log.FieldKeyMsg: "event", + log.FieldKeyTime: "timestamp", + }, + DisableHTMLEscape: true, + }) + }, + Run: func(cmd *cobra.Command, args []string) { + debug.EnableDebugServer() + akURL, found := os.LookupEnv("AUTHENTIK_HOST") + if !found { + fmt.Println("env AUTHENTIK_HOST not set!") + fmt.Println(helpMessage) + os.Exit(1) + } + akToken, found := os.LookupEnv("AUTHENTIK_TOKEN") + if !found { + fmt.Println("env AUTHENTIK_TOKEN not set!") + fmt.Println(helpMessage) + os.Exit(1) + } + + akURLActual, err := url.Parse(akURL) + if err != nil { + fmt.Println(err) + fmt.Println(helpMessage) + os.Exit(1) + } + + ex := common.Init() + defer common.Defer() + go func() { + for { + <-ex + os.Exit(0) + } + }() + + ac := ak.NewAPIController(*akURLActual, akToken) + if ac == nil { + os.Exit(1) + } + defer ac.Shutdown() + + ac.Server = &SCIMOutpost{ + ac: ac, + log: log.WithField("logger", "authentik.outpost.scim"), + } + + err = ac.Start() + if err != nil { + log.WithError(err).Panic("Failed to run server") + } + + for { + <-ex + } + }, +} + +type HTTPRequest struct { + Uid string `mapstructure:"uid"` + Method string `mapstructure:"method"` + URL string `mapstructure:"url"` + Headers map[string][]string `mapstructure:"headers"` + Body interface{} `mapstructure:"body"` + SSLVerify bool `mapstructure:"ssl_verify"` + Timeout int `mapstructure:"timeout"` +} + +type RequestArgs struct { + Request HTTPRequest `mapstructure:"request"` + ResponseChannel string `mapstructure:"response_channel"` +} + +type SCIMOutpost struct { + ac *ak.APIController + log *log.Entry +} + +func (s *SCIMOutpost) Type() string { return "SCIM" } +func (s *SCIMOutpost) Stop() error { return nil } +func (s *SCIMOutpost) Refresh() error { return nil } +func (s *SCIMOutpost) TimerFlowCacheExpiry(context.Context) {} + +func (s *SCIMOutpost) Start() error { + s.ac.AddWSHandler(func(ctx context.Context, args map[string]interface{}) { + rd := RequestArgs{} + err := mapstructure.Decode(args, &rd) + if err != nil { + s.log.WithError(err).Warning("failed to parse http request") + return + } + s.log.WithField("rd", rd).WithField("raw", args).Debug("request data") + req, err := http.NewRequest(rd.Request.Method, rd.Request.URL, nil) + if err != nil { + s.log.WithError(err).Warning("failed to create request") + return + } + + tr := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: !rd.Request.SSLVerify}, + // todo: timeout + } + c := &http.Client{ + Transport: tr, + } + s.log.WithField("url", req.URL.Host).Debug("sending HTTP request") + res, err := c.Do(req) + if err != nil { + s.log.WithError(err).Warning("failed to send request") + return + } + body, err := io.ReadAll(res.Body) + if err != nil { + s.log.WithError(err).Warning("failed to read body") + return + } + s.log.WithField("res", res.StatusCode).Debug("sending HTTP response") + err = s.ac.SendWS(ak.WebsocketInstructionProviderSpecific, map[string]interface{}{ + "sub_type": "http_response", + "response_channel": rd.ResponseChannel, + "response": map[string]interface{}{ + "status": res.StatusCode, + "final_url": res.Request.URL.String(), + "headers": res.Header, + "body": base64.StdEncoding.EncodeToString(body), + }, + }) + if err != nil { + s.log.WithError(err).Warning("failed to send http response") + return + } + }) + return nil +} + +func main() { + rootCmd.AddCommand(healthcheck.Command) + err := rootCmd.Execute() + if err != nil { + os.Exit(1) + } +} diff --git a/internal/outpost/ak/api.go b/internal/outpost/ak/api.go index 2def0642731b..9440b78f0cd1 100644 --- a/internal/outpost/ak/api.go +++ b/internal/outpost/ak/api.go @@ -90,7 +90,7 @@ func NewAPIController(akURL url.URL, token string) *APIController { time.Sleep(time.Second * 3) } if len(outposts.Results) < 1 { - panic("No outposts found with given token, ensure the given token corresponds to an authenitk Outpost") + panic("No outposts found with given token, ensure the given token corresponds to an authentik Outpost") } outpost := outposts.Results[0] diff --git a/internal/outpost/ak/api_ws.go b/internal/outpost/ak/api_ws.go index cda7bd03d2b0..9a573099a554 100644 --- a/internal/outpost/ak/api_ws.go +++ b/internal/outpost/ak/api_ws.go @@ -218,15 +218,19 @@ func (a *APIController) AddWSHandler(handler WSHandler) { a.wsHandlers = append(a.wsHandlers, handler) } +func (a *APIController) SendWS(inst WebsocketInstruction, args map[string]interface{}) error { + msg := websocketMessage{ + Instruction: inst, + Args: args, + } + err := a.wsConn.WriteJSON(msg) + return err +} + func (a *APIController) SendWSHello(args map[string]interface{}) error { allArgs := a.getWebsocketPingArgs() for key, value := range args { allArgs[key] = value } - aliveMsg := websocketMessage{ - Instruction: WebsocketInstructionHello, - Args: allArgs, - } - err := a.wsConn.WriteJSON(aliveMsg) - return err + return a.SendWS(WebsocketInstructionHello, args) } diff --git a/internal/outpost/ak/api_ws_msg.go b/internal/outpost/ak/api_ws_msg.go index cedecb93d5d5..1d11860b3f87 100644 --- a/internal/outpost/ak/api_ws_msg.go +++ b/internal/outpost/ak/api_ws_msg.go @@ -1,19 +1,19 @@ package ak -type websocketInstruction int +type WebsocketInstruction int const ( // WebsocketInstructionAck Code used to acknowledge a previous message - WebsocketInstructionAck websocketInstruction = 0 + WebsocketInstructionAck WebsocketInstruction = 0 // WebsocketInstructionHello Code used to send a healthcheck keepalive - WebsocketInstructionHello websocketInstruction = 1 + WebsocketInstructionHello WebsocketInstruction = 1 // WebsocketInstructionTriggerUpdate Code received to trigger a config update - WebsocketInstructionTriggerUpdate websocketInstruction = 2 + WebsocketInstructionTriggerUpdate WebsocketInstruction = 2 // WebsocketInstructionProviderSpecific Code received to trigger some provider specific function - WebsocketInstructionProviderSpecific websocketInstruction = 3 + WebsocketInstructionProviderSpecific WebsocketInstruction = 3 ) type websocketMessage struct { - Instruction websocketInstruction `json:"instruction"` + Instruction WebsocketInstruction `json:"instruction"` Args map[string]interface{} `json:"args"` } diff --git a/schema.yml b/schema.yml index d4f3eb78ac7b..04af924923b9 100644 --- a/schema.yml +++ b/schema.yml @@ -42943,7 +42943,8 @@ components: readOnly: true spnego_server_name: type: string - description: Force the use of a specific server name for SPNEGO + description: Force the use of a specific server name for SPNEGO. Must be + in the form HTTP@hostname spnego_ccache: type: string description: Credential cache to use for SPNEGO in form type:residual @@ -43112,7 +43113,8 @@ components: be in the form TYPE:residual spnego_server_name: type: string - description: Force the use of a specific server name for SPNEGO + description: Force the use of a specific server name for SPNEGO. Must be + in the form HTTP@hostname spnego_keytab: type: string writeOnly: true @@ -45445,6 +45447,7 @@ components: - ldap - radius - rac + - scim type: string PaginatedApplicationList: type: object @@ -48410,7 +48413,8 @@ components: be in the form TYPE:residual spnego_server_name: type: string - description: Force the use of a specific server name for SPNEGO + description: Force the use of a specific server name for SPNEGO. Must be + in the form HTTP@hostname spnego_keytab: type: string writeOnly: true diff --git a/web/src/admin/outposts/OutpostForm.ts b/web/src/admin/outposts/OutpostForm.ts index 3c276caaf7ae..f5f2e5a06608 100644 --- a/web/src/admin/outposts/OutpostForm.ts +++ b/web/src/admin/outposts/OutpostForm.ts @@ -73,6 +73,9 @@ const radiusListFetch = async (page: number, search = "") => const racListProvider = async (page: number, search = "") => provisionMaker(await api().providersRacList(providerListArgs(page, search))); +const scimListProvider = async (page: number, search = "") => + provisionMaker(await api().providersScimList(providerListArgs(page, search))); + function providerProvider(type: OutpostTypeEnum): DataProvider { switch (type) { case OutpostTypeEnum.Proxy: @@ -83,6 +86,8 @@ function providerProvider(type: OutpostTypeEnum): DataProvider { return radiusListFetch; case OutpostTypeEnum.Rac: return racListProvider; + case OutpostTypeEnum.Scim: + return scimListProvider; default: throw new Error(`Unrecognized OutputType: ${type}`); } @@ -142,6 +147,7 @@ export class OutpostForm extends ModelForm { [OutpostTypeEnum.Ldap, msg("LDAP")], [OutpostTypeEnum.Radius, msg("Radius")], [OutpostTypeEnum.Rac, msg("RAC")], + [OutpostTypeEnum.Scim, msg("SCIM")], ]; return html`