From 7d9945d764ef4076829992d0e6ba5224a7c45dd0 Mon Sep 17 00:00:00 2001 From: Omar Date: Tue, 28 Nov 2023 17:33:12 +0000 Subject: [PATCH] Use enclave server port as the key for tunnel tracking; hook up removal of port forwards --- .../port_forward_manager.go | 18 ++++++++++++++++- .../tunnel_session_tracker.go | 20 +++++++++---------- 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/portal/daemon/port_forward_manager/port_forward_manager.go b/portal/daemon/port_forward_manager/port_forward_manager.go index 0aed6134e0..3a7de452fb 100644 --- a/portal/daemon/port_forward_manager/port_forward_manager.go +++ b/portal/daemon/port_forward_manager/port_forward_manager.go @@ -78,8 +78,16 @@ func (manager *PortForwardManager) CreateUserServicePortForward(ctx context.Cont return allBoundPorts, nil } +// RemoveUserServicePortForward +// here we only stop a single session at a time, so require all of enclaveId, serviceId, portId, to be specified func (manager *PortForwardManager) RemoveUserServicePortForward(ctx context.Context, enclaveServicePort EnclaveServicePort) error { - panic("implement me") + if err := validateRemoveUserServicePortForwardArgs(enclaveServicePort); err != nil { + return stacktrace.Propagate(err, "Validation failed for arguments") + } + + manager.tunnelSessionTracker.StopForwardingPort(enclaveServicePort) + + return nil } func (manager *PortForwardManager) createAndOpenEphemeralPortForwardsToUserServices(serviceInterfaceDetails []*ServiceInterfaceDetail) (map[EnclaveServicePort]uint16, error) { @@ -124,6 +132,14 @@ func validateCreateUserServicePortForwardArgs(enclaveServicePort EnclaveServiceP return nil } +// Removal only works for specific service/ports, so make sure all fields are populated +func validateRemoveUserServicePortForwardArgs(enclaveServicePort EnclaveServicePort) error { + if enclaveServicePort.EnclaveId() == "" || enclaveServicePort.ServiceId() == "" || enclaveServicePort.PortId() == "" { + return stacktrace.NewError("All of enclaveId, serviceId, and portId, must be specified for removal of a port forward: %v", enclaveServicePort) + } + return nil +} + func getLocalChiselServerUri(localPortToChiselServer uint16) string { return "localhost:" + strconv.Itoa(int(localPortToChiselServer)) } diff --git a/portal/daemon/port_forward_manager/tunnel_session_tracker.go b/portal/daemon/port_forward_manager/tunnel_session_tracker.go index 13887d25b8..fc14a7da83 100644 --- a/portal/daemon/port_forward_manager/tunnel_session_tracker.go +++ b/portal/daemon/port_forward_manager/tunnel_session_tracker.go @@ -8,13 +8,12 @@ import ( // TODO(omar): there will be some complexity in cases where ephemeral port binds are upgraded to static type TunnelSessionTracker struct { - // TODO(omar): hash key here probably needs sorting / verifying, due to the pointer it carries - activePortForwards map[*ServiceInterfaceDetail]*PortForwardTunnel + activePortForwards map[EnclaveServicePort]*PortForwardTunnel } func NewTunnelSessionTracker() *TunnelSessionTracker { return &TunnelSessionTracker{ - map[*ServiceInterfaceDetail]*PortForwardTunnel{}, + map[EnclaveServicePort]*PortForwardTunnel{}, } } @@ -34,19 +33,20 @@ func (tracker *TunnelSessionTracker) CreateAndOpenPortForward(serviceInterfaceDe } // TODO(omar): do we need to wait until port is fully open? - tracker.addPortForward(serviceInterfaceDetail, portForward) + tracker.addPortForward(serviceInterfaceDetail.enclaveServicePort, portForward) return portForward.localPortNumber, nil } -func (tracker *TunnelSessionTracker) StopForwardingPort(serviceInterfaceDetail *ServiceInterfaceDetail) { +func (tracker *TunnelSessionTracker) StopForwardingPort(enclaveServicePort EnclaveServicePort) { // TODO(omar): i don't think we care about stopping sessions that have been removed right now // this depends on where we go wrt to monitoring and cleaning up dead sessions, so I'll see how that // evolves prior to doing anything here - portForward, _ := tracker.activePortForwards[serviceInterfaceDetail] - - portForward.Close() + portForward, found := tracker.activePortForwards[enclaveServicePort] + if found { + portForward.Close() + } } -func (tracker *TunnelSessionTracker) addPortForward(serviceInterfaceDetail *ServiceInterfaceDetail, portForward *PortForwardTunnel) { - tracker.activePortForwards[serviceInterfaceDetail] = portForward +func (tracker *TunnelSessionTracker) addPortForward(enclaveServicePort EnclaveServicePort, portForward *PortForwardTunnel) { + tracker.activePortForwards[enclaveServicePort] = portForward }