diff --git a/web-agent/app/worker.py b/web-agent/app/worker.py index 7fef462..4eb87e8 100644 --- a/web-agent/app/worker.py +++ b/web-agent/app/worker.py @@ -22,9 +22,10 @@ from collections import deque from logging.handlers import TimedRotatingFileHandler from pathlib import Path -from typing import Optional, Tuple, Any, Dict, Union, List +from typing import Optional, Tuple, Any, Dict, Union, List, Callable from urllib.parse import urlparse, urlunparse +import random import requests from gevent.pool import Pool @@ -149,7 +150,12 @@ def process() -> None: elif get_task_response.status_code == 204: _log_get_task_metric(get_task_duration_ms, get_task_server_url, 204) logger.info("No task available. Waiting...") - elif get_task_response.status_code > 500: + elif get_task_response.status_code == 429: + delay = get_retry_delay(get_task_response) + logger.warning("Rate limit on get-task, retrying in %.2fs", delay) + _log_get_task_metric(get_task_duration_ms, get_task_server_url, 429) + gevent.sleep(delay) + elif get_task_response.status_code >= 500: _log_get_task_metric(get_task_duration_ms, get_task_server_url, get_task_response.status_code) logger.error("Getting 5XX error %d, increasing backoff time", get_task_response.status_code) gevent.sleep(thread_backoff_time) @@ -218,43 +224,31 @@ def _log_update_metrics( logger.debug(f"Failed to log update metrics: {e}") -def update_task(task: Optional[Dict[str, Any]], count: int = 0) -> None: +def update_task(task: Optional[Dict[str, Any]]) -> None: if task is None: return - # Update the task status - if count > max_retry: - logger.error("Retry count exceeds for task %s", task['taskId']) - return - try: + + def _do_update() -> requests.Response: rate_limiter.throttle() update_start_time = time.time() - update_task_response: requests.Response = requests.post( + response = requests.post( f"{config_dict.get('server_url')}/api/http-teleport/put-result", headers=_get_headers(), json=task, timeout=30, 
verify=config_dict.get('verify_cert'), proxies=config_dict['outgoing_proxy'] ) update_duration_ms = (time.time() - update_start_time) * 1000 + _log_update_metrics(task, response, update_duration_ms) + return response - # Log metrics for update operation - _log_update_metrics(task, update_task_response, update_duration_ms) - - if update_task_response.status_code == 200: - logger.info("Task %s updated successfully. Response: %s", task['taskId'], - update_task_response.text) - elif update_task_response.status_code == 429 or update_task_response.status_code == 504: - gevent.sleep(2) - logger.warning("Rate limit hit while updating the task output, retrying again for task %s", task['taskId']) - count = count + 1 - update_task(task, count) - else: - logger.warning("Failed to update task %s: %s", task['taskId'], update_task_response.text) - - + try: + response = retry_request(_do_update, operation_name=f"put-result[{task['taskId']}]") + if response and response.status_code == 200: + logger.info("Task %s updated successfully. Response: %s", task['taskId'], response.text) + elif response: + logger.warning("Failed to update task %s: %s", task['taskId'], response.text) except requests.exceptions.RequestException as e: - logger.error("Network error processing task %s: %s", task['taskId'], e) - count = count + 1 - update_task(task, count) + logger.error("Network error updating task %s: %s", task['taskId'], e) def _get_headers() -> Dict[str, str]: @@ -265,6 +259,100 @@ def _get_headers() -> Dict[str, str]: return headers +def is_concurrent_limit_error(response: requests.Response) -> bool: + """Returns True if 429 is due to concurrent request limit (not standard rate limit).""" + if response.status_code == 429: + try: + return "Too many concurrent requests" in response.text + except Exception: + return False + return False + + +def get_retry_delay(response: requests.Response, default_delay: int = 2) -> float: + """ + Extract retry delay from a 429 response. 
+ Priority: concurrent error → random 0-10s, header value, default. + """ + if is_concurrent_limit_error(response): + delay = random.uniform(0, 10) + logger.info("Concurrent limit error, using random delay: %.2fs", delay) + return delay + + retry_after = response.headers.get('X-Rate-Limit-Retry-After-Seconds') + if retry_after: + try: + delay = int(retry_after) + if delay < 0: + logger.warning("Negative retry delay %ds in header, using default %ds", delay, default_delay) + return default_delay + if delay > 300: + logger.warning("Excessive retry delay %ds in header, capping at 300s", delay) + return 300 + logger.info("Using retry-after header delay: %ds", delay) + return delay + except ValueError: + logger.warning("Invalid retry delay '%s' in header, using default %ds", retry_after, default_delay) + + return default_delay + + +def retry_request( + func: Callable[[], requests.Response], + max_retries: int = 5, + max_server_retries: int = 3, + operation_name: str = "request" +) -> Optional[requests.Response]: + """ + Unified retry wrapper for ArmorCode API calls. + - 429: respects X-Rate-Limit-Retry-After-Seconds header, up to max_retries + - 5XX / network error: exponential backoff (5s base, 120s cap), up to max_server_retries + Uses gevent.sleep() to yield to other greenlets during waits. 
+ """ + rate_limit_attempts = 0 + server_error_attempts = 0 + + while True: + try: + response = func() + + if response.status_code == 429: + if rate_limit_attempts >= max_retries: + logger.error("%s rate limit exceeded after %d retries", operation_name, max_retries) + return response + delay = get_retry_delay(response) + logger.warning("%s rate limit hit (attempt %d/%d), retrying in %.2fs", + operation_name, rate_limit_attempts + 1, max_retries, delay) + gevent.sleep(delay) + rate_limit_attempts += 1 + continue + + if response.status_code >= 500: + if server_error_attempts >= max_server_retries: + logger.error("%s 5XX error %d after %d retries", + operation_name, response.status_code, max_server_retries) + return response + delay = min(5 * (2 ** server_error_attempts), 120) + random.uniform(0, 2) + logger.warning("%s got %d (attempt %d/%d), retrying in %.2fs", + operation_name, response.status_code, + server_error_attempts + 1, max_server_retries, delay) + gevent.sleep(delay) + server_error_attempts += 1 + continue + + return response + + except requests.exceptions.RequestException as e: + if server_error_attempts >= max_server_retries: + logger.error("%s network error after %d retries: %s", operation_name, max_server_retries, e) + raise + delay = min(5 * (2 ** server_error_attempts), 120) + random.uniform(0, 2) + logger.warning("%s network error (attempt %d/%d), retrying in %.2fs: %s", + operation_name, server_error_attempts + 1, max_server_retries, delay, e) + gevent.sleep(delay) + server_error_attempts += 1 + + def check_for_logs_fetch(url, task, temp_output_file_zip): if 'agent/fetch-logs' in url and 'fetchLogs' in task.get('taskId'): try: @@ -284,20 +372,26 @@ def check_for_logs_fetch(url, task, temp_output_file_zip): "file": (temp_output_file_zip.name, open(temp_output_file_zip.name, "rb"), f"{'application/zip'}"), "task": (None, task_json, "application/json") } - rate_limiter.throttle() upload_logs_url = 
f"{config_dict.get('server_url')}/api/http-teleport/upload-logs" if len(config_dict.get('env_name', '')) > 0: upload_logs_url = f"{config_dict.get('server_url')}/api/http-teleport/upload-logs?envName={config_dict.get('env_name')}" - upload_result: requests.Response = requests.post( - upload_logs_url, - headers=headers, - timeout=300, verify=config_dict.get('verify_cert', False), proxies=config_dict['outgoing_proxy'], - files=files - ) - if upload_result.status_code == 200: + + def _do_upload_logs() -> requests.Response: + rate_limiter.throttle() + return requests.post( + upload_logs_url, + headers=headers, + timeout=300, verify=config_dict.get('verify_cert', False), proxies=config_dict['outgoing_proxy'], + files=files + ) + + upload_result = retry_request(_do_upload_logs, operation_name="upload-logs") + if upload_result and upload_result.status_code == 200: return True else: - logger.error("Response code while uploading is not 200 , response code {} and error {} ", upload_result.status_code, upload_result.content) + logger.error("Response code while uploading is not 200 , response code %s and error %s ", + upload_result.status_code if upload_result else 'None', + upload_result.content if upload_result else 'None') return True except Exception as e: logger.error(f"Error zipping logs: {str(e)}") @@ -485,12 +579,16 @@ def upload_response(temp_file, temp_file_zip, taskId: str, task: Dict[str, Any]) } rate_limiter.throttle() upload_start_time = time.time() - upload_result: requests.Response = requests.post( - f"{config_dict.get('server_url')}/api/http-teleport/upload-result", - headers=headers, - timeout=300, verify=config_dict.get('verify_cert', False), proxies=config_dict['outgoing_proxy'], - files=files - ) + + def _do_upload() -> requests.Response: + return requests.post( + f"{config_dict.get('server_url')}/api/http-teleport/upload-result", + headers=headers, + timeout=300, verify=config_dict.get('verify_cert', False), proxies=config_dict['outgoing_proxy'], + 
files=files + ) + + upload_result = retry_request(_do_upload, operation_name=f"upload-result[{taskId}]") upload_duration_ms = (time.time() - upload_start_time) * 1000 # Track file upload metrics @@ -499,8 +597,8 @@ def upload_response(temp_file, temp_file_zip, taskId: str, task: Dict[str, Any]) operation="upload_file", url=f"{config_dict.get('server_url')}/api/http-teleport/upload-result", method="POST", - status_code=str(upload_result.status_code), - success=str(upload_result.status_code < 400).lower() + status_code=str(upload_result.status_code if upload_result else 0), + success=str(upload_result.status_code < 400 if upload_result else False).lower() ) _safe_log_metric("http.request.duration_ms", upload_duration_ms, tags) @@ -741,12 +839,17 @@ def upload_s3(temp_file, preSignedUrl: str, headers: Dict[str, Any]) -> bool: try: with open(temp_file, 'rb') as file: - response: requests.Response = requests.put(preSignedUrl, headers=headersForS3, data=file, - verify=config_dict.get('verify_cert', False), - proxies=config_dict['outgoing_proxy'], timeout=120) - response.raise_for_status() - logger.info('File uploaded successfully to S3') - return True + file_data = file.read() + + def _do_s3_upload() -> requests.Response: + return requests.put(preSignedUrl, headers=headersForS3, data=file_data, + verify=config_dict.get('verify_cert', False), + proxies=config_dict['outgoing_proxy'], timeout=120) + + response = retry_request(_do_s3_upload, max_retries=0, max_server_retries=3, operation_name="s3-upload") + response.raise_for_status() + logger.info('File uploaded successfully to S3') + return True except requests.exceptions.RequestException as e: logger.error("Network error uploading to S3: %s", e) raise @@ -768,17 +871,22 @@ def _createFolder(folder_path: str) -> None: def get_s3_upload_url(taskId: str) -> Tuple[Optional[str], Optional[str]]: params: Dict[str, str] = {'fileName': f"{taskId}{uuid.uuid4().hex}"} - try: + + def _do_get_url() -> requests.Response: 
rate_limiter.throttle() - get_s3_url: requests.Response = requests.get( + return requests.get( f"{config_dict.get('server_url')}/api/http-teleport/upload-url", params=params, headers=_get_headers(), timeout=25, verify=config_dict.get('verify_cert', False), proxies=config_dict['outgoing_proxy'] ) - get_s3_url.raise_for_status() - data: Optional[Dict[str, str]] = get_s3_url.json().get('data', None) + try: + response = retry_request(_do_get_url, operation_name="get-s3-upload-url") + if response is None: + return None, None + response.raise_for_status() + data: Optional[Dict[str, str]] = response.json().get('data', None) if data is not None: return data.get('putUrl'), data.get('getUrl') logger.warning("No data returned when requesting S3 upload URL")