Skip to content

Commit

Permalink
feat: user impersonate (#1811)
Browse files Browse the repository at this point in the history
  • Loading branch information
shahargl authored Sep 3, 2024
1 parent 2df7f58 commit c4f5b4e
Show file tree
Hide file tree
Showing 7 changed files with 277 additions and 42 deletions.
43 changes: 33 additions & 10 deletions keep/api/core/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -1243,6 +1243,18 @@ def delete_user(username):
session.commit()


def user_exists(tenant_id, username):
from keep.api.models.db.user import User

with Session(engine) as session:
user = session.exec(
select(User)
.where(User.tenant_id == tenant_id)
.where(User.username == username)
).first()
return user is not None


def create_user(tenant_id, username, password, role):
from keep.api.models.db.user import User

Expand Down Expand Up @@ -2589,7 +2601,9 @@ def get_pmi_values(
return pmi_values


def update_incident_summary(tenant_id: str, incident_id: UUID, summary: str) -> Incident:
def update_incident_summary(
tenant_id: str, incident_id: UUID, summary: str
) -> Incident:
with Session(engine) as session:
incident = session.exec(
select(Incident)
Expand All @@ -2598,29 +2612,38 @@ def update_incident_summary(tenant_id: str, incident_id: UUID, summary: str) ->
).first()

if not incident:
logger.error(f"Incident not found for tenant {tenant_id} and incident {incident_id}", extra={"tenant_id": tenant_id})
return
logger.error(
f"Incident not found for tenant {tenant_id} and incident {incident_id}",
extra={"tenant_id": tenant_id},
)
return

incident.generated_summary = summary
session.commit()
session.refresh(incident)

return


return


def update_incident_name(tenant_id: str, incident_id: UUID, name: str) -> Incident:
with Session(engine) as session:
incident = session.exec(
select(Incident).where(Incident.tenant_id == tenant_id).where(Incident.id == incident_id)
select(Incident)
.where(Incident.tenant_id == tenant_id)
.where(Incident.id == incident_id)
).first()

if not incident:
logger.error(f"Incident not found for tenant {tenant_id} and incident {incident_id}", extra={"tenant_id": tenant_id})
logger.error(
f"Incident not found for tenant {tenant_id} and incident {incident_id}",
extra={"tenant_id": tenant_id},
)
return

incident.ai_generated_name = name
session.commit()
session.refresh(incident)

return incident


Expand Down
92 changes: 80 additions & 12 deletions keep/identitymanager/authverifierbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
OAuth2PasswordBearer,
)

from keep.api.core.config import config
from keep.api.core.db import get_api_key, update_key_last_used
from keep.identitymanager.authenticatedentity import AuthenticatedEntity
from keep.identitymanager.rbac import Admin as AdminRole
from keep.identitymanager.rbac import get_role_by_role_name

auth_header = APIKeyHeader(name="X-API-KEY", scheme_name="API Key", auto_error=False)
Expand Down Expand Up @@ -62,6 +64,18 @@ def __init__(self, scopes: list[str] = []) -> None:
ALL_RESOURCES.update([scope.split(":")[1] for scope in scopes])
self.scopes = scopes
self.logger = logging.getLogger(__name__)
self.impersonation_enabled = (
config("KEEP_IMPERSONATION_ENABLED", default="false") == "true"
)
self.impersonation_user_header = config(
"KEEP_IMPERSONATION_USER_HEADER", default="X-KEEP-USER"
)
self.impersonation_role_header = config(
"KEEP_IMPERSONATION_ROLE_HEADER", default="X-KEEP-ROLE"
)
self.impersonation_auto_provision = (
config("KEEP_IMPERSONATION_AUTO_PROVISION", default="false") == "true"
)

def __call__(
self,
Expand Down Expand Up @@ -265,24 +279,78 @@ def _verify_api_key(
if not tenant_api_key:
self.logger.warning("Invalid API Key")
raise HTTPException(status_code=401, detail="Invalid API Key")
else:

try:
self.logger.debug("Updating API Key last used")
try:
update_key_last_used(
tenant_api_key.tenant_id, reference_id=tenant_api_key.reference_id
)
except Exception:
self.logger.exception("Failed to update API Key last used")
update_key_last_used(
tenant_api_key.tenant_id, reference_id=tenant_api_key.reference_id
)
self.logger.debug("Successfully updated API Key last used")
except Exception:
self.logger.exception("Failed to update API Key last used")

request.state.tenant_id = tenant_api_key.tenant_id

self.logger.debug(f"API key verified for tenant: {tenant_api_key.tenant_id}")
# check if impersonation is enabled, if not, return the api key's authenticated entity
if not self.impersonation_enabled:
return AuthenticatedEntity(
tenant_api_key.tenant_id,
tenant_api_key.created_by,
tenant_api_key.reference_id,
tenant_api_key.role,
)
# check if impersonation headers are present
user_name = request.headers.get(self.impersonation_user_header)
role = request.headers.get(self.impersonation_role_header)
# if not, return the apikey's authenticated entity
if not user_name or not role:
return AuthenticatedEntity(
tenant_api_key.tenant_id,
tenant_api_key.created_by,
tenant_api_key.reference_id,
tenant_api_key.role,
)

self.logger.info("Impersonating user")
user_name = request.headers.get(self.impersonation_user_header)
role = request.headers.get(self.impersonation_role_header)
if not user_name or not role:
raise HTTPException(status_code=401, detail="Impersonation headers missing")

# TODO - validate authorization meaning api key X has access to impersonate user Y
# for now, only admin users can impersonate
if tenant_api_key.role != AdminRole.get_name():
self.logger.error("Impersonation not allowed for non-admin users")
raise HTTPException(
status_code=401, detail="Impersonation not allowed for non-admin users"
)

# auto provision user
if self.impersonation_auto_provision:
self.logger.info(f"Auto provisioning user: {user_name}")
self._provision_user(tenant_api_key.tenant_id, user_name, role)
self.logger.info(f"User {user_name} provisioned successfully")

self.logger.info("User impersonated successfully")
return AuthenticatedEntity(
tenant_api_key.tenant_id,
tenant_api_key.created_by,
tenant_api_key.reference_id,
tenant_api_key.role,
tenant_id=tenant_api_key.tenant_id,
email=user_name,
api_key_name=None,
role=role,
)

def _provision_user(self, tenant_api_key, user_name, role):
"""
Create a user for impersonation.
Args:
tenant_api_key: The API key used for impersonation.
user_name: The name of the user to create.
role: The role of the user to create.
"""
raise NotImplementedError(
"User provisioning not implemented"
" for {}".format(self.__class__.__name__)
)

def _verify_bearer_token(self, token: str) -> AuthenticatedEntity:
Expand Down
6 changes: 6 additions & 0 deletions keep/identitymanager/identity_managers/db/db_authverifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import jwt
from fastapi import HTTPException

from keep.api.core.db import create_user, user_exists
from keep.identitymanager.authenticatedentity import AuthenticatedEntity
from keep.identitymanager.authverifierbase import AuthVerifierBase
from keep.identitymanager.rbac import Admin as AdminRole
Expand Down Expand Up @@ -39,3 +40,8 @@ def _verify_bearer_token(self, token: str) -> AuthenticatedEntity:
detail="You don't have the required permissions to access this resource",
)
return AuthenticatedEntity(tenant_id, email, None, role_name)

# create user for auto-provisioning
def _provision_user(self, tenant_id, user_name, role):
if not user_exists(tenant_id, user_name):
create_user(tenant_id=tenant_id, username=user_name, role=role, password="")
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ def on_start(self, app) -> None:

@app.post("/signin")
def signin(body: dict):
# block empty passwords (e.g. user provisioned)
if not body.get("password"):
return JSONResponse(
status_code=401,
content={"message": "Empty password"},
)

# validate the user/password
user = get_user(body.get("username"), body.get("password"))
if not user:
Expand Down
33 changes: 14 additions & 19 deletions tests/fixtures/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,35 +11,30 @@

@pytest.fixture
def test_app(monkeypatch, request):
auth_type = request.param
elastic_enabled = None
if isinstance(auth_type, tuple):
auth_type = auth_type[0]
elastic_enabled = auth_type[1]
# Check if request.param is a dict or a string
if isinstance(request.param, dict):
# Set environment variables based on the provided dictionary
for key, value in request.param.items():
monkeypatch.setenv(key, str(value))
else:
# Old behavior for string parameters
auth_type = request.param
monkeypatch.setenv("AUTH_TYPE", auth_type)
monkeypatch.setenv("KEEP_JWT_SECRET", "somesecret")

monkeypatch.setenv("AUTH_TYPE", auth_type)
if auth_type == "MULTI_TENANT":
monkeypatch.setenv("AUTH0_DOMAIN", "https//auth0domain.com")
monkeypatch.setenv("KEEP_JWT_SECRET", "somesecret")
if elastic_enabled is not None:
monkeypatch.setenv("ELASTIC_ENABLED", str(elastic_enabled))
# Ok this is bit complex so stay with me:
# We need to reload the app to make sure the AuthVerifier is instantiated with the correct environment variable
# However, we can't just reload the module because the app is instantiated in the get_app() function
# So we need to delete the module from sys.modules and re-import it
if auth_type == "MULTI_TENANT":
monkeypatch.setenv("AUTH0_DOMAIN", "https://auth0domain.com")

# First, delete all the routes modules from sys.modules
# Clear and reload modules to ensure environment changes are reflected
for module in list(sys.modules):
if module.startswith("keep.api.routes"):
del sys.modules[module]
# Second, delete the api module from sys.modules
if "keep.api.api" in sys.modules:
importlib.reload(sys.modules["keep.api.api"])

# Now, import it, and it will re-instantiate the app with the correct environment variable
# Import and return the app instance
from keep.api.api import get_app

# Finally, return the app
app = get_app()
return app

Expand Down
Loading

0 comments on commit c4f5b4e

Please sign in to comment.