diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 76efdb027546f1..548f09c0240810 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -471,6 +471,11 @@ class MailConfig(BaseSettings): default=False, ) + EMAIL_SEND_IP_LIMIT_PER_MINUTE: PositiveInt = Field( + description="Maximum number of emails allowed to be sent from the same IP address in a minute", + default=50, + ) + class RagEtlConfig(BaseSettings): """ diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index 005f38b8e5dbcd..60ffd14dcb1383 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -13,7 +13,7 @@ InvalidTokenError, PasswordMismatchError, ) -from controllers.console.error import NotAllowedCreateWorkspace, NotAllowedRegister +from controllers.console.error import EmailSendIpLimitError, NotAllowedCreateWorkspace, NotAllowedRegister from controllers.console.setup import setup_required from events.tenant_event import tenant_was_created from extensions.ext_database import db @@ -31,6 +31,10 @@ def post(self): parser.add_argument("language", type=str, required=False, location="json") args = parser.parse_args() + ip_address = get_remote_ip(request) + if AccountService.is_email_send_ip_limit(ip_address): + raise EmailSendIpLimitError() + if args["language"] is not None and args["language"] == "zh-Hans": language = "zh-Hans" else: diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index ad6c03672235cf..eb389d48ef06f6 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -15,7 +15,7 @@ InvalidEmailError, InvalidTokenError, ) -from controllers.console.error import NotAllowedCreateWorkspace, NotAllowedRegister +from controllers.console.error import EmailSendIpLimitError, NotAllowedCreateWorkspace, NotAllowedRegister from controllers.console.setup import setup_required from events.tenant_event import tenant_was_created from libs.helper import email, get_remote_ip @@ -122,6 +122,10 @@ def post(self): parser.add_argument("language", type=str, required=False, location="json") args = parser.parse_args() + ip_address = get_remote_ip(request) + if AccountService.is_email_send_ip_limit(ip_address): + raise EmailSendIpLimitError() + if args["language"] is not None and args["language"] == "zh-Hans": language = "zh-Hans" else: diff --git a/api/controllers/console/error.py b/api/controllers/console/error.py index c3153be419c6ab..f454eec2ece22a 100644 --- a/api/controllers/console/error.py +++ b/api/controllers/console/error.py @@ -50,3 +50,9 @@ class NotAllowedRegister(BaseHTTPException): error_code = "unauthorized" description = "Account not found." code = 400 + + +class EmailSendIpLimitError(BaseHTTPException): + error_code = "email_send_ip_limit" + description = "Too many emails have been sent from this IP address recently. Please try again later." + code = 429 diff --git a/api/services/account_service.py b/api/services/account_service.py index ae540ad9ca85ef..1d33a637cf41dd 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -369,6 +369,46 @@ def reset_login_error_rate_limit(email: str): key = f"login_error_rate_limit:{email}" redis_client.delete(key) + @staticmethod + def is_email_send_ip_limit(ip_address: str): + minute_key = f"email_send_ip_limit_minute:{ip_address}" + freeze_key = f"email_send_ip_limit_freeze:{ip_address}" + hour_limit_key = f"email_send_ip_limit_hour:{ip_address}" + + # check ip is frozen + if redis_client.get(freeze_key): + return True + + # check current minute count + current_minute_count = redis_client.get(minute_key) + if current_minute_count is None: + current_minute_count = 0 + current_minute_count = int(current_minute_count) + + # check current hour count + if current_minute_count > dify_config.EMAIL_SEND_IP_LIMIT_PER_MINUTE: + hour_limit_count = redis_client.get(hour_limit_key) + if hour_limit_count is None: + hour_limit_count = 0 + hour_limit_count = int(hour_limit_count) + + if hour_limit_count >= 1: + redis_client.setex(freeze_key, 60 * 60, 1) + return True + else: + redis_client.setex(hour_limit_key, 60 * 10, hour_limit_count + 1) # first time limit 10 minutes + + # add hour limit count + redis_client.incr(hour_limit_key) + redis_client.expire(hour_limit_key, 60 * 60) + + return True + + redis_client.setex(minute_key, 60, current_minute_count + 1) + redis_client.expire(minute_key, 60) + + return False + def _get_login_cache_key(*, account_id: str, token: str): return f"account_login:{account_id}:{token}"