diff --git a/frontend/auth_providers/zju.py b/frontend/auth_providers/zju.py index 5d67cef..2c72e98 100644 --- a/frontend/auth_providers/zju.py +++ b/frontend/auth_providers/zju.py @@ -1,7 +1,6 @@ import secrets import string import json -import hmac from typing import Any from Crypto.Cipher import AES from base64 import b64decode @@ -37,74 +36,69 @@ def __init__(self, **kwargs: Any) -> None: self.cipher_key = b64decode(provider["cipher_key"]) self.provider_url = provider["provider_url"] - # XXX: POST 请求到外部,不宜外传 csrf token,且反 CSRF 的功能已被 state+proof 实现 + # XXX: POST 请求到外部,不宜外传 csrf token,且反 CSRF 的功能已被 nonce 实现 @method_decorator(csrf_exempt) def dispatch(self, *args, **kwargs): return super(LoginView, self).dispatch(*args, **kwargs) - def calc_hmac(self, state: str) -> str: - return hmac.new(self.state_key, state.encode(), "sha256").hexdigest() - - def get(self, request): - template_context = self.template_context.copy() - state = "".join( - secrets.choice(string.ascii_letters + string.digits) for _ in range(32) - ) - proof = self.calc_hmac(state) - - # 在登录前显示提示信息(而非直接跳转) - redirect_uri = request.build_absolute_uri( - "/accounts/zju/login/?" + urlencode({"proof": proof}) - ) - template_context["url"] = ( - self.provider_url - + "?" - + urlencode({"redirect_uri": redirect_uri, "state": state}) - ) - return TemplateResponse(request, self.template_name, template_context) + def normalize_identity(self): + return self.identity.casefold() def check_cipher(self) -> bool: try: - ciphertext = self.request.POST.get("cipher") - proof = self.request.GET.get("proof") - assert isinstance(ciphertext, str) and isinstance(proof, str) + ciphertext = self.cipher + nonce = self.request.session.get("auth_nonce_zju") + assert isinstance(ciphertext, str) and isinstance(nonce, str) + self.request.session.pop("auth_nonce_zju") # ciphertext: # 16 bytes: nonce # 16 bytes till end - 16 bytes: payload # 16 bytes from end: MAC ciphertext = b64decode(ciphertext) cipher = AES.new(self.cipher_key, AES.MODE_GCM, ciphertext[:16]) + cipher.update(nonce.encode()) payload = json.loads( cipher.decrypt_and_verify(ciphertext[16:-16], ciphertext[-16:]) ) - student_id = payload["student_id"] - state = payload["state"] - assert isinstance(student_id, str) and isinstance(state, str) - assert self.calc_hmac(state) == proof - student_id = student_id.strip() + sno = payload["sno"] + assert isinstance(sno, str) + sno = sno.strip() # XXX: 实际上的长度是 8,但作为可信内容放宽限制来避免一些特殊情况 if not ( - all(char in string.digits for char in student_id) - and 5 <= len(student_id) <= 26 + all(char in string.digits for char in sno) + and 5 <= len(sno) <= 26 ): messages.error(self.request, "学号非法") return False - self.sno = student_id + self.sno = sno self.name = payload["name"] # XXX: 能够以此学号登录应当与拥有此邮箱等价 - self.identity = student_id + "@zju.edu.cn" + self.identity = sno + "@zju.edu.cn" return True except Exception: messages.error(self.request, "登录失败") return False - def normalize_identity(self): - return self.identity.casefold() + def get(self, request): + self.cipher = request.GET.get("cipher") + if self.cipher: + if self.check_cipher(): + self.login(email=self.identity, sno=self.sno, name=self.name) + return redirect("hub") + template_context = self.template_context.copy() + nonce = "".join( + secrets.choice(string.ascii_letters + string.digits) for _ in range(32) + ) + request.session['auth_nonce_zju'] = nonce - def post(self, request): - if self.check_cipher(): - self.login(email=self.identity, sno=self.sno, name=self.name) - return redirect("hub") + # 在登录前显示提示信息(而非直接跳转) + redirect_uri = request.build_absolute_uri("/accounts/zju/login/") + template_context["url"] = ( + self.provider_url + + "?" + + urlencode({"redirect_uri": redirect_uri, "nonce": nonce}) + ) + return TemplateResponse(request, self.template_name, template_context) urlpatterns = [