Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
216 changes: 162 additions & 54 deletions web-agent/app/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@
from collections import deque
from logging.handlers import TimedRotatingFileHandler
from pathlib import Path
from typing import Optional, Tuple, Any, Dict, Union, List
from typing import Optional, Tuple, Any, Dict, Union, List, Callable
from urllib.parse import urlparse, urlunparse

import random
import requests
from gevent.pool import Pool

Expand Down Expand Up @@ -149,7 +150,12 @@ def process() -> None:
elif get_task_response.status_code == 204:
_log_get_task_metric(get_task_duration_ms, get_task_server_url, 204)
logger.info("No task available. Waiting...")
elif get_task_response.status_code > 500:
elif get_task_response.status_code == 429:
delay = get_retry_delay(get_task_response)
logger.warning("Rate limit on get-task, retrying in %.2fs", delay)
_log_get_task_metric(get_task_duration_ms, get_task_server_url, 429)
gevent.sleep(delay)
elif get_task_response.status_code >= 500:
_log_get_task_metric(get_task_duration_ms, get_task_server_url, get_task_response.status_code)
logger.error("Getting 5XX error %d, increasing backoff time", get_task_response.status_code)
gevent.sleep(thread_backoff_time)
Expand Down Expand Up @@ -218,43 +224,31 @@ def _log_update_metrics(
logger.debug(f"Failed to log update metrics: {e}")


def update_task(task: Optional[Dict[str, Any]], count: int = 0) -> None:
def update_task(task: Optional[Dict[str, Any]]) -> None:
if task is None:
return
# Update the task status
if count > max_retry:
logger.error("Retry count exceeds for task %s", task['taskId'])
return
try:

def _do_update() -> requests.Response:
rate_limiter.throttle()
update_start_time = time.time()
update_task_response: requests.Response = requests.post(
response = requests.post(
f"{config_dict.get('server_url')}/api/http-teleport/put-result",
headers=_get_headers(),
json=task,
timeout=30, verify=config_dict.get('verify_cert'), proxies=config_dict['outgoing_proxy']
)
update_duration_ms = (time.time() - update_start_time) * 1000
_log_update_metrics(task, response, update_duration_ms)
return response

# Log metrics for update operation
_log_update_metrics(task, update_task_response, update_duration_ms)

if update_task_response.status_code == 200:
logger.info("Task %s updated successfully. Response: %s", task['taskId'],
update_task_response.text)
elif update_task_response.status_code == 429 or update_task_response.status_code == 504:
gevent.sleep(2)
logger.warning("Rate limit hit while updating the task output, retrying again for task %s", task['taskId'])
count = count + 1
update_task(task, count)
else:
logger.warning("Failed to update task %s: %s", task['taskId'], update_task_response.text)


try:
response = retry_request(_do_update, operation_name=f"put-result[{task['taskId']}]")
if response and response.status_code == 200:
logger.info("Task %s updated successfully. Response: %s", task['taskId'], response.text)
elif response:
logger.warning("Failed to update task %s: %s", task['taskId'], response.text)
except requests.exceptions.RequestException as e:
logger.error("Network error processing task %s: %s", task['taskId'], e)
count = count + 1
update_task(task, count)
logger.error("Network error updating task %s: %s", task['taskId'], e)


def _get_headers() -> Dict[str, str]:
Expand All @@ -265,6 +259,100 @@ def _get_headers() -> Dict[str, str]:
return headers


def is_concurrent_limit_error(response: requests.Response) -> bool:
    """Check whether a 429 response was caused by the concurrent-request cap
    rather than the standard rate limit."""
    if response.status_code != 429:
        return False
    try:
        body = response.text
    except Exception:
        # Reading the body can fail (e.g. a consumed/closed stream); treat
        # that as a plain rate-limit 429.
        return False
    return "Too many concurrent requests" in body


def get_retry_delay(response: requests.Response, default_delay: int = 2) -> float:
    """
    Extract the retry delay (in seconds) from a 429 response.

    Priority:
      1. Concurrent-limit 429s get a random 0-10s delay to de-synchronize workers.
      2. The X-Rate-Limit-Retry-After-Seconds header, clamped to [0, 300].
      3. default_delay.

    :param response: the 429 response to inspect.
    :param default_delay: fallback delay when no usable header is present.
    :return: delay in seconds (float).
    """
    if is_concurrent_limit_error(response):
        delay = random.uniform(0, 10)
        logger.info("Concurrent limit error, using random delay: %.2fs", delay)
        return delay

    retry_after = response.headers.get('X-Rate-Limit-Retry-After-Seconds')
    if retry_after:
        try:
            # float() also accepts integral strings, so fractional header
            # values (e.g. "1.5") no longer fall through to the default
            # the way int() made them.
            delay = float(retry_after)
            if delay < 0:
                logger.warning("Negative retry delay %ss in header, using default %ds", retry_after, default_delay)
                return default_delay
            if delay > 300:
                logger.warning("Excessive retry delay %ss in header, capping at 300s", retry_after)
                return 300
            logger.info("Using retry-after header delay: %ss", retry_after)
            return delay
        except ValueError:
            logger.warning("Invalid retry delay '%s' in header, using default %ds", retry_after, default_delay)

    return default_delay


def retry_request(
    func: Callable[[], requests.Response],
    max_retries: int = 5,
    max_server_retries: int = 3,
    operation_name: str = "request"
) -> Optional[requests.Response]:
    """
    Unified retry wrapper for ArmorCode API calls.
    - 429: respects X-Rate-Limit-Retry-After-Seconds header, up to max_retries
    - 5XX / network error: exponential backoff (5s base, 120s cap), up to max_server_retries
    Uses gevent.sleep() to yield to other greenlets during waits.
    """
    attempts_429 = 0
    attempts_5xx = 0

    def _backoff(attempt: int) -> float:
        # Exponential backoff with jitter: 5s base, doubled per attempt, capped at 120s.
        return min(5 * (2 ** attempt), 120) + random.uniform(0, 2)

    while True:
        try:
            response = func()
        except requests.exceptions.RequestException as e:
            # Network-level failures share the 5XX retry budget.
            if attempts_5xx >= max_server_retries:
                logger.error("%s network error after %d retries: %s", operation_name, max_server_retries, e)
                raise
            delay = _backoff(attempts_5xx)
            logger.warning("%s network error (attempt %d/%d), retrying in %.2fs: %s",
                           operation_name, attempts_5xx + 1, max_server_retries, delay, e)
            gevent.sleep(delay)
            attempts_5xx += 1
            continue

        status = response.status_code

        if status == 429:
            if attempts_429 >= max_retries:
                logger.error("%s rate limit exceeded after %d retries", operation_name, max_retries)
                return response
            delay = get_retry_delay(response)
            logger.warning("%s rate limit hit (attempt %d/%d), retrying in %.2fs",
                           operation_name, attempts_429 + 1, max_retries, delay)
            gevent.sleep(delay)
            attempts_429 += 1
            continue

        if status >= 500:
            if attempts_5xx >= max_server_retries:
                logger.error("%s 5XX error %d after %d retries",
                             operation_name, status, max_server_retries)
                return response
            delay = _backoff(attempts_5xx)
            logger.warning("%s got %d (attempt %d/%d), retrying in %.2fs",
                           operation_name, status,
                           attempts_5xx + 1, max_server_retries, delay)
            gevent.sleep(delay)
            attempts_5xx += 1
            continue

        return response


def check_for_logs_fetch(url, task, temp_output_file_zip):
if 'agent/fetch-logs' in url and 'fetchLogs' in task.get('taskId'):
try:
Expand All @@ -284,20 +372,26 @@ def check_for_logs_fetch(url, task, temp_output_file_zip):
"file": (temp_output_file_zip.name, open(temp_output_file_zip.name, "rb"), f"{'application/zip'}"),
"task": (None, task_json, "application/json")
}
rate_limiter.throttle()
upload_logs_url = f"{config_dict.get('server_url')}/api/http-teleport/upload-logs"
if len(config_dict.get('env_name', '')) > 0:
upload_logs_url = f"{config_dict.get('server_url')}/api/http-teleport/upload-logs?envName={config_dict.get('env_name')}"
upload_result: requests.Response = requests.post(
upload_logs_url,
headers=headers,
timeout=300, verify=config_dict.get('verify_cert', False), proxies=config_dict['outgoing_proxy'],
files=files
)
if upload_result.status_code == 200:

def _do_upload_logs() -> requests.Response:
rate_limiter.throttle()
return requests.post(
upload_logs_url,
headers=headers,
timeout=300, verify=config_dict.get('verify_cert', False), proxies=config_dict['outgoing_proxy'],
files=files
)

upload_result = retry_request(_do_upload_logs, operation_name="upload-logs")
if upload_result and upload_result.status_code == 200:
return True
else:
logger.error("Response code while uploading is not 200 , response code {} and error {} ", upload_result.status_code, upload_result.content)
logger.error("Response code while uploading is not 200 , response code {} and error {} ",
upload_result.status_code if upload_result else 'None',
upload_result.content if upload_result else 'None')
return True
except Exception as e:
logger.error(f"Error zipping logs: {str(e)}")
Expand Down Expand Up @@ -485,12 +579,16 @@ def upload_response(temp_file, temp_file_zip, taskId: str, task: Dict[str, Any])
}
rate_limiter.throttle()
upload_start_time = time.time()
upload_result: requests.Response = requests.post(
f"{config_dict.get('server_url')}/api/http-teleport/upload-result",
headers=headers,
timeout=300, verify=config_dict.get('verify_cert', False), proxies=config_dict['outgoing_proxy'],
files=files
)

def _do_upload() -> requests.Response:
return requests.post(
f"{config_dict.get('server_url')}/api/http-teleport/upload-result",
headers=headers,
timeout=300, verify=config_dict.get('verify_cert', False), proxies=config_dict['outgoing_proxy'],
files=files
)

upload_result = retry_request(_do_upload, operation_name=f"upload-result[{taskId}]")
upload_duration_ms = (time.time() - upload_start_time) * 1000

# Track file upload metrics
Expand All @@ -499,8 +597,8 @@ def upload_response(temp_file, temp_file_zip, taskId: str, task: Dict[str, Any])
operation="upload_file",
url=f"{config_dict.get('server_url')}/api/http-teleport/upload-result",
method="POST",
status_code=str(upload_result.status_code),
success=str(upload_result.status_code < 400).lower()
status_code=str(upload_result.status_code if upload_result else 0),
success=str(upload_result.status_code < 400 if upload_result else False).lower()
)
_safe_log_metric("http.request.duration_ms", upload_duration_ms, tags)

Expand Down Expand Up @@ -741,12 +839,17 @@ def upload_s3(temp_file, preSignedUrl: str, headers: Dict[str, Any]) -> bool:

try:
with open(temp_file, 'rb') as file:
response: requests.Response = requests.put(preSignedUrl, headers=headersForS3, data=file,
verify=config_dict.get('verify_cert', False),
proxies=config_dict['outgoing_proxy'], timeout=120)
response.raise_for_status()
logger.info('File uploaded successfully to S3')
return True
file_data = file.read()

def _do_s3_upload() -> requests.Response:
return requests.put(preSignedUrl, headers=headersForS3, data=file_data,
verify=config_dict.get('verify_cert', False),
proxies=config_dict['outgoing_proxy'], timeout=120)

response = retry_request(_do_s3_upload, max_retries=0, max_server_retries=3, operation_name="s3-upload")
response.raise_for_status()
logger.info('File uploaded successfully to S3')
return True
except requests.exceptions.RequestException as e:
logger.error("Network error uploading to S3: %s", e)
raise
Expand All @@ -768,17 +871,22 @@ def _createFolder(folder_path: str) -> None:

def get_s3_upload_url(taskId: str) -> Tuple[Optional[str], Optional[str]]:
params: Dict[str, str] = {'fileName': f"{taskId}{uuid.uuid4().hex}"}
try:

def _do_get_url() -> requests.Response:
rate_limiter.throttle()
get_s3_url: requests.Response = requests.get(
return requests.get(
f"{config_dict.get('server_url')}/api/http-teleport/upload-url",
params=params,
headers=_get_headers(),
timeout=25, verify=config_dict.get('verify_cert', False), proxies=config_dict['outgoing_proxy']
)
get_s3_url.raise_for_status()

data: Optional[Dict[str, str]] = get_s3_url.json().get('data', None)
try:
response = retry_request(_do_get_url, operation_name="get-s3-upload-url")
if response is None:
return None, None
response.raise_for_status()
data: Optional[Dict[str, str]] = response.json().get('data', None)
if data is not None:
return data.get('putUrl'), data.get('getUrl')
logger.warning("No data returned when requesting S3 upload URL")
Expand Down
Loading