Skip to content
This repository has been archived by the owner on Jan 17, 2025. It is now read-only.

Commit

Permalink
Improve cache invalidation in multiuser context.
Browse files Browse the repository at this point in the history
This is a security fix (#35).
  • Loading branch information
kiorky committed Dec 12, 2024
1 parent 2b7f14a commit a48eeb4
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 27 deletions.
116 changes: 108 additions & 8 deletions src/bitwardentools/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from collections import OrderedDict
from copy import deepcopy
from subprocess import run
from time import time
from time import sleep, time

import requests
from jwt import encode as jwt_encode
Expand All @@ -26,12 +26,14 @@
from bitwardentools import crypto as bwcrypto
from bitwardentools.common import L, caseinsentive_key_search

LOGIN_ENDPOINT_RE = re.compile("connect/token")
VAULTIER_FIELD_ID = "vaultiersecretid"
DEFAULT_CACHE = {"id": {}, "name": {}, "sync": False}
SYNC_ALL_ORGAS_ID = "__orga__all__ORGAS__"
SYNC_ORGA_ID = "__orga__{0}"
SECRET_CACHE = {"id": {}, "name": {}, "vaultiersecretid": {}, "sync": []}
DEFAULT_BITWARDEN_CACHE = {
"sync_token": {},
"sync": {},
"templates": {},
"users": deepcopy(DEFAULT_CACHE),
Expand Down Expand Up @@ -103,6 +105,7 @@
API_CHANGES = {
"1.27.0": _version.parse("1.27.0"),
}
MARKER = object()


def uncapitzalize(s):
Expand Down Expand Up @@ -271,6 +274,10 @@ class RunError(BitwardenError):
"""."""


class SecurityError(BitwardenError):
"""."""


class NoOrganizationKeyError(BitwardenError):
"""."""

Expand Down Expand Up @@ -693,6 +700,7 @@ def __init__(
cache=None,
vaultier=False,
authentication_cb=None,
multiuser=False,
):
# goal is to allow shared cache amongst client instances
# but also if we want totally isolated caches
Expand Down Expand Up @@ -728,6 +736,7 @@ def __init__(
self.login()
self._is_vaultwarden = False
self._version = None
self.multiuser = multiuser

@property
def token(self):
Expand Down Expand Up @@ -759,7 +768,17 @@ def adminr(
headers = {}
return getattr(requests, method.lower())(url, headers=headers, *a, **kw)

def r(self, uri, method="post", headers=None, token=None, retry=True, *a, **kw):
def r(
self,
uri,
method="post",
headers=None,
token=None,
retry=True,
multiuser=MARKER,
*a,
**kw,
):
url = uri
if not url.startswith("http"):
url = f"{self.server}{uri}"
Expand All @@ -768,14 +787,28 @@ def r(self, uri, method="post", headers=None, token=None, retry=True, *a, **kw):
if token is not False:
token = self.get_token(token)
headers.update({"Authorization": f"Bearer {token['access_token']}"})
self.invalidate_other_user_cache(url, token, multiuser=multiuser)
# if we try to get a new token, invalidate any local cache for security reason
resp = getattr(requests, method.lower())(url, headers=headers, *a, **kw)
if resp.status_code in [401] and token is not False and retry:
sleep(0.05)
L.debug(
f"Access denied, trying to retry after refreshing token for {token['email']}"
)
token = self.login(token["email"], token["password"])
headers.update({"Authorization": f"Bearer {token['access_token']}"})
resp = getattr(requests, method.lower())(url, headers=headers, *a, **kw)
if resp.status_code > 399 and retry is not False:
sleep(0.5)
L.debug(f"Something went wrong, retrying {url}")
resp = getattr(requests, method.lower())(url, headers=headers, *a, **kw)
return resp

def verify_token(self, token):
resp = self.r(
"/api/accounts/revision-date", token=token, retry=False, method="get"
)
self.assert_bw_response(resp)
return resp

def login(
Expand All @@ -784,19 +817,19 @@ def login(
password=None,
scope="api offline_access",
grant_type="password",
force=None,
):
email = email or self.email
try:
if force:
raise KeyError("force_relog")
token = self.tokens[email]
except KeyError:
pass
else:
# as token is already there, test if token is still usable
resp = self.r(
"/api/accounts/revision-date", token=token, retry=False, method="get"
)
try:
self.assert_bw_response(resp)
self.verify_token(token)
except ResponseError:
self.tokens.pop(email, None)
else:
Expand Down Expand Up @@ -952,10 +985,59 @@ def get_template(self, otype=None, **kw):
tpl.update(kw)
return tpl

def invalidate_other_user_cache(self, url=None, token=None, multiuser=MARKER):
"""
if we detect any new login or token change and the token is different
we bust cache for security unless the user choosed explicitly the contrary
either by setting self.multiuser=True or client(multiuser=True)
which is discouraged unless you know what you do
and takes the whole responsability of it because this can lead to leaks
"""
if multiuser is MARKER:
multiuser = self.multiuser
if not token:
token = {}
if not url:
url = ""
sat = self._cache["sync_token"]
at = sat.get("access_token", "")
is_different_token = at != token.get("access_token", "")
# in case of token check we will check first if token is same
# and in other case, we will check by email in case of token
# regeneration between requests
if not multiuser and is_different_token:
temail = token.get("email", "")
semail = sat.get("email", "")
if temail and semail and semail == temail:
resp = self.r(
"/api/accounts/profile", method="get", token=token, multiuser=True
)
self.assert_bw_response(resp)
profile = resp.json()
pemail = profile.get("Email", profile.get("email", ""))
if semail == pemail:
# tokens belong both to the same user, explicit no bust
return False
else:
self.bust_cache()
raise SecurityError(
f"token tampering/impersonation detected for email: "
f"token: {temail} / sync token: {semail} / remote profile: {pemail}"
)
if (
not self.multiuser
and self._cache["sync"]
and (is_different_token or LOGIN_ENDPOINT_RE.search(url))
):
self.bust_cache()
return True
return False

def api_sync(self, sync=None, cache=None, token=None):
_CACHE = self._cache["sync"]
k = "api_sync"
token = self.get_token(token)
self.invalidate_other_user_cache(token=token)
if sync is None:
sync = False
if cache is None:
Expand All @@ -975,6 +1057,7 @@ def api_sync(self, sync=None, cache=None, token=None):
self.assert_bw_response(resp)
_CACHE.update(resp.json())
_CACHE[k] = True
self._cache["sync_token"] = token
return _CACHE

def cli_sync(self, sync=None):
Expand All @@ -985,6 +1068,7 @@ def sync(self, sync=None, token=None):

def finish_orga(self, orga, cache=None, token=None, complete=None):
token = self.get_token(token)
self.invalidate_other_user_cache(token=token)
if complete and not getattr("orga", "BillingEmail", "") and not orga._complete:
orga = BWFactory.construct(
self.r(f"/api/organizations/{orga.id}", method="get").json(),
Expand All @@ -997,6 +1081,7 @@ def finish_orga(self, orga, cache=None, token=None, complete=None):

def get_organizations(self, sync=None, cache=None, token=None):
token = self.get_token(token)
self.invalidate_other_user_cache(token=token)
_CACHE = self._cache["organizations"]
if sync is None:
sync = False
Expand Down Expand Up @@ -1055,10 +1140,10 @@ def get_organization(self, orga, sync=None, cache=None, token=None, complete=Non
exc.criteria = [orga]
raise exc

def get_token(self, token=None):
def get_token(self, token=None, force=None):
token = token or self.token
if not token:
token = self.login()
token = self.login(force=force)
return token

def decrypt_item(self, val, key, decode=True, charset=None):
Expand Down Expand Up @@ -1380,6 +1465,8 @@ def create_organization(
return obj

def get_organization_key(self, orga, token=None, sync=None):
if token:
self.invalidate_other_user_cache(token=token)
keys = self._cache["organizations"].setdefault("keys", {})
if sync is None:
sync = False
Expand Down Expand Up @@ -1489,6 +1576,7 @@ def get_collections(self, orga=None, sync=None, cache=None, token=None):
orga is either None for all or an orga(or orgaid)
"""
token = self.get_token(token)
self.invalidate_other_user_cache(token=token)
if not orga:
sync_key = SYNC_ALL_ORGAS_ID
else:
Expand Down Expand Up @@ -1547,6 +1635,7 @@ def get_collection(
):
criteria = [item_or_id_or_name, orga]
token = self.get_token(token)
self.invalidate_other_user_cache(token=token)
if orga:
orga = self.get_organization(orga, token=token)
if isinstance(item_or_id_or_name, Collection):
Expand Down Expand Up @@ -1612,6 +1701,7 @@ def decrypt(
self, value, key=None, orga=None, token=None, recursion=None, dictkey=None
):
token = self.get_token(token=token)
self.invalidate_other_user_cache(token=token)
nvalue = value
idv = id(value)
if recursion is None:
Expand Down Expand Up @@ -1678,6 +1768,7 @@ def get_ciphers(
):
vaultier = self.get_vaultier(vaultier)
token = self.get_token(token=token)
self.invalidate_other_user_cache(token=token)
scache = self._cache["ciphers"]
if sync or cache is False:
scache.pop("sync", None)
Expand Down Expand Up @@ -1730,6 +1821,7 @@ def get_attachments(
):
vaultier = self.get_vaultier(vaultier)
token = self.get_token(token)
self.invalidate_other_user_cache(token=token)
sec = self.get_cipher(
item,
collection=collection,
Expand Down Expand Up @@ -1759,6 +1851,7 @@ def delete_attachment(
):
vaultier = self.get_vaultier(vaultier)
token = self.get_token(token)
self.invalidate_other_user_cache(token=token)
sec = self.get_cipher(
item,
collection=collection,
Expand Down Expand Up @@ -1787,6 +1880,7 @@ def delete_attachments(
):
vaultier = self.get_vaultier(vaultier)
token = self.get_token(token)
self.invalidate_other_user_cache(token=token)
ret = []
if not isinstance(attachments, list):
attachments = [attachments]
Expand Down Expand Up @@ -1816,6 +1910,7 @@ def attach(
):
vaultier = self.get_vaultier(vaultier)
token = self.get_token(token)
self.invalidate_other_user_cache(token=token)
fn = os.path.basename(filepath)
try:
attachments = self.get_attachments(
Expand Down Expand Up @@ -1884,6 +1979,7 @@ def get_cipher(
item_or_id_or_name = item_or_id_or_name.id
vaultier = self.get_vaultier(vaultier)
token = self.get_token(token)
self.invalidate_other_user_cache(token=token)
_id = f"{self.item_or_id(item_or_id_or_name)}"
if isinstance(_id, str):
_id = _id.lower()
Expand Down Expand Up @@ -2433,6 +2529,7 @@ def get_accesses(self, objs, sync=None, token=None):
# XXX: maybe we will implement cache at a later time
sync = True
token = self.get_token(token)
self.invalidate_other_user_cache(token=token)
ret, single = OrderedDict(), False
if not isinstance(objs, (list, set, tuple)):
single = True
Expand Down Expand Up @@ -3068,6 +3165,7 @@ def set_collection_access(

def collections_to_payloads(self, collections, orga=None, token=None):
token = self.get_token(token)
self.invalidate_other_user_cache(token=token)
colexc = []
dcollections = {}
if collections:
Expand Down Expand Up @@ -3102,6 +3200,7 @@ def ensure_private_key(self):
def accept_invitation(self, orga, email, id=None, name=None, sync=None, token=None):
self.ensure_private_key()
token = self.get_token(token=token)
self.invalidate_other_user_cache(token=token)
orga = self.get_organization(orga, token=token)
user = self.get_user(email=email, name=name, id=id, sync=sync)
email = user.email
Expand Down Expand Up @@ -3161,6 +3260,7 @@ def confirm_invitation(self, orga, email, name=None, sync=None, token=None):
Just email is necessary to match users
"""
token = self.get_token(token=token)
self.invalidate_other_user_cache(token=token)
orga = self.get_organization(orga, token=token)
orgkey = self.get_organization_key(orga, token=token)
oaccess = self.get_accesses(orga, token=token)
Expand Down
Loading

0 comments on commit a48eeb4

Please sign in to comment.