Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 39 additions & 9 deletions src/aignostics/platform/_client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
from collections.abc import Callable
from urllib.request import getproxies

from aignx.codegen.api.public_api import PublicApi
from aignx.codegen.api_client import ApiClient
from aignx.codegen.configuration import Configuration
from aignx.codegen.configuration import AuthSettings, Configuration
from aignx.codegen.exceptions import NotFoundException
from aignx.codegen.models import ApplicationReadResponse as Application
from aignx.codegen.models import MeReadResponse as Me
Expand All @@ -18,6 +19,34 @@
logger = get_logger(__name__)


class _OAuth2TokenProviderConfiguration(Configuration):
"""
Overwrites the original Configuration to call a function to obtain a refresh token.

The base class does not support callbacks. This is necessary for integrations where
tokens may expire or need to be refreshed automatically.
"""

def __init__(
self, host: str, ssl_ca_cert: str | None = None, token_provider: Callable[[], str] | None = None
) -> None:
super().__init__(host=host, ssl_ca_cert=ssl_ca_cert)
self.token_provider = token_provider

def auth_settings(self) -> AuthSettings:
token = self.token_provider() if self.token_provider else None
if not token:
return {}
return {
"OAuth2AuthorizationCodeBearer": {
"type": "oauth2",
"in": "header",
"key": "Authorization",
"value": f"Bearer {token}",
}
}


class Client:
"""Main client for interacting with the Aignostics Platform API.

Expand Down Expand Up @@ -92,28 +121,29 @@ def application(self, application_id: str) -> Application:

@staticmethod
def get_api_client(cache_token: bool = True) -> PublicApi:
"""Creates and configures an authenticated API client.
"""Create and configure an authenticated API client.

Args:
cache_token (bool): If True, caches the authentication token.
Defaults to True.

Returns:
ExternalsApi: Configured API client with authentication token.
PublicApi: Configured API client with authentication token.

Raises:
RuntimeError: If authentication fails.
"""
token = get_token(use_cache=cache_token)
config = Configuration(
host=settings().api_root,
ssl_ca_cert=os.getenv("REQUESTS_CA_BUNDLE"), # point to .cer file of proxy if defined

def token_provider() -> str:
return get_token(use_cache=cache_token)

ca_file = os.getenv("REQUESTS_CA_BUNDLE") # point to .cer file of proxy if defined
config = _OAuth2TokenProviderConfiguration(
host=settings().api_root, ssl_ca_cert=ca_file, token_provider=token_provider
)
config.proxy = getproxies().get("https") # use system proxy
client = ApiClient(
config,
header_name="Authorization",
header_value=f"Bearer {token}",
)
client.user_agent = user_agent()
return PublicApi(client)
51 changes: 51 additions & 0 deletions tests/aignostics/platform/client_token_provider_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Tests for the token provider configuration and its integration with the client."""

from unittest.mock import Mock, patch

from aignostics.platform._client import Client, _OAuth2TokenProviderConfiguration


def test_oauth2_token_provider_configuration_uses_token_provider() -> None:
"""Test that token_provider is used when provided."""
token_provider = Mock(return_value="dynamic-token")
config = _OAuth2TokenProviderConfiguration(host="https://dummy", token_provider=token_provider)
auth = config.auth_settings()
assert auth["OAuth2AuthorizationCodeBearer"]["value"] == "Bearer dynamic-token"
token_provider.assert_called_once()


def test_oauth2_token_provider_configuration_no_token() -> None:
"""Test that auth_settings returns empty dict if no token_provider is set."""
config = _OAuth2TokenProviderConfiguration(host="https://dummy")
auth = config.auth_settings()
assert auth == {}


def test_client_passes_token_provider() -> None:
"""Test that the client passes the token provider to the configuration."""
with (
patch("aignostics.platform._client.get_token", return_value="client-token"),
patch("aignostics.platform._client.ApiClient") as api_client_mock,
patch("aignostics.platform._client.PublicApi") as public_api_mock,
):
Client(cache_token=False)
config_used = api_client_mock.call_args[0][0]
assert isinstance(config_used, _OAuth2TokenProviderConfiguration)
assert config_used.token_provider() == "client-token"
public_api_mock.assert_called()


def test_client_me_calls_api() -> None:
"""Test that the client.me() method calls the API and returns the result."""
with (
patch("aignostics.platform._client.get_token", return_value="client-token"),
patch("aignostics.platform._client.ApiClient"),
patch("aignostics.platform._client.PublicApi") as public_api_mock,
):
api_instance = Mock()
api_instance.get_me_v1_me_get.return_value = "me-info"
public_api_mock.return_value = api_instance
client = Client()
result = client.me()
assert result == "me-info"
api_instance.get_me_v1_me_get.assert_called_once()
Loading