diff --git a/src/aignostics/platform/_client.py b/src/aignostics/platform/_client.py index 038a62cb6..9c63dc72f 100644 --- a/src/aignostics/platform/_client.py +++ b/src/aignostics/platform/_client.py @@ -1,9 +1,10 @@ import os +from collections.abc import Callable from urllib.request import getproxies from aignx.codegen.api.public_api import PublicApi from aignx.codegen.api_client import ApiClient -from aignx.codegen.configuration import Configuration +from aignx.codegen.configuration import AuthSettings, Configuration from aignx.codegen.exceptions import NotFoundException from aignx.codegen.models import ApplicationReadResponse as Application from aignx.codegen.models import MeReadResponse as Me @@ -18,6 +19,34 @@ logger = get_logger(__name__) +class _OAuth2TokenProviderConfiguration(Configuration): + """ + Overwrites the original Configuration to call a function to obtain a refresh token. + + The base class does not support callbacks. This is necessary for integrations where + tokens may expire or need to be refreshed automatically. + """ + + def __init__( + self, host: str, ssl_ca_cert: str | None = None, token_provider: Callable[[], str] | None = None + ) -> None: + super().__init__(host=host, ssl_ca_cert=ssl_ca_cert) + self.token_provider = token_provider + + def auth_settings(self) -> AuthSettings: + token = self.token_provider() if self.token_provider else None + if not token: + return {} + return { + "OAuth2AuthorizationCodeBearer": { + "type": "oauth2", + "in": "header", + "key": "Authorization", + "value": f"Bearer {token}", + } + } + + class Client: """Main client for interacting with the Aignostics Platform API. @@ -92,28 +121,29 @@ def application(self, application_id: str) -> Application: @staticmethod def get_api_client(cache_token: bool = True) -> PublicApi: - """Creates and configures an authenticated API client. + """Create and configure an authenticated API client. Args: cache_token (bool): If True, caches the authentication token. Defaults to True. Returns: - ExternalsApi: Configured API client with authentication token. + PublicApi: Configured API client with authentication token. Raises: RuntimeError: If authentication fails. """ - token = get_token(use_cache=cache_token) - config = Configuration( - host=settings().api_root, - ssl_ca_cert=os.getenv("REQUESTS_CA_BUNDLE"), # point to .cer file of proxy if defined + + def token_provider() -> str: + return get_token(use_cache=cache_token) + + ca_file = os.getenv("REQUESTS_CA_BUNDLE") # point to .cer file of proxy if defined + config = _OAuth2TokenProviderConfiguration( + host=settings().api_root, ssl_ca_cert=ca_file, token_provider=token_provider ) config.proxy = getproxies().get("https") # use system proxy client = ApiClient( config, - header_name="Authorization", - header_value=f"Bearer {token}", ) client.user_agent = user_agent() return PublicApi(client) diff --git a/tests/aignostics/platform/client_token_provider_test.py b/tests/aignostics/platform/client_token_provider_test.py new file mode 100644 index 000000000..c9382988f --- /dev/null +++ b/tests/aignostics/platform/client_token_provider_test.py @@ -0,0 +1,51 @@ +"""Tests for the token provider configuration and its integration with the client.""" + +from unittest.mock import Mock, patch + +from aignostics.platform._client import Client, _OAuth2TokenProviderConfiguration + + +def test_oauth2_token_provider_configuration_uses_token_provider() -> None: + """Test that token_provider is used when provided.""" + token_provider = Mock(return_value="dynamic-token") + config = _OAuth2TokenProviderConfiguration(host="https://dummy", token_provider=token_provider) + auth = config.auth_settings() + assert auth["OAuth2AuthorizationCodeBearer"]["value"] == "Bearer dynamic-token" + token_provider.assert_called_once() + + +def test_oauth2_token_provider_configuration_no_token() -> None: + """Test that auth_settings returns empty dict if no token_provider is set.""" + config = _OAuth2TokenProviderConfiguration(host="https://dummy") + auth = config.auth_settings() + assert auth == {} + + +def test_client_passes_token_provider() -> None: + """Test that the client passes the token provider to the configuration.""" + with ( + patch("aignostics.platform._client.get_token", return_value="client-token"), + patch("aignostics.platform._client.ApiClient") as api_client_mock, + patch("aignostics.platform._client.PublicApi") as public_api_mock, + ): + Client(cache_token=False) + config_used = api_client_mock.call_args[0][0] + assert isinstance(config_used, _OAuth2TokenProviderConfiguration) + assert config_used.token_provider() == "client-token" + public_api_mock.assert_called() + + +def test_client_me_calls_api() -> None: + """Test that the client.me() method calls the API and returns the result.""" + with ( + patch("aignostics.platform._client.get_token", return_value="client-token"), + patch("aignostics.platform._client.ApiClient"), + patch("aignostics.platform._client.PublicApi") as public_api_mock, + ): + api_instance = Mock() + api_instance.get_me_v1_me_get.return_value = "me-info" + public_api_mock.return_value = api_instance + client = Client() + result = client.me() + assert result == "me-info" + api_instance.get_me_v1_me_get.assert_called_once()