diff --git a/src/aignostics/platform/_sdk_metadata.py b/src/aignostics/platform/_sdk_metadata.py index f980bc4fd..e5b296712 100644 --- a/src/aignostics/platform/_sdk_metadata.py +++ b/src/aignostics/platform/_sdk_metadata.py @@ -86,6 +86,7 @@ class GPUConfig(BaseModel): le=60 * 60, description="Maximum run duration in minutes when using FLEX_START provisioning mode (1-3600). " "Required when provisioning_mode is FLEX_START, must be None otherwise.", + exclude_if=lambda v: v is None, # Exclude from serialization if None ) @model_validator(mode="after") diff --git a/tests/aignostics/platform/sdk_metadata_test.py b/tests/aignostics/platform/sdk_metadata_test.py index 54b514135..a9637ba19 100644 --- a/tests/aignostics/platform/sdk_metadata_test.py +++ b/tests/aignostics/platform/sdk_metadata_test.py @@ -7,10 +7,14 @@ import pytest from pydantic import ValidationError +from aignostics.platform import DEFAULT_GPU_PROVISIONING_MODE, DEFAULT_GPU_TYPE, DEFAULT_MAX_GPUS_PER_SLIDE from aignostics.platform._sdk_metadata import ( ITEM_SDK_METADATA_SCHEMA_VERSION, SDK_METADATA_SCHEMA_VERSION, VALIDATION_CASE_TAG_PREFIX, + GPUConfig, + GPUType, + ProvisioningMode, ValidationCase, build_item_sdk_metadata, build_run_sdk_metadata, @@ -996,8 +1000,6 @@ def test_pipeline_config_defaults() -> None: """Test that pipeline configuration uses correct defaults.""" from aignostics.platform import ( DEFAULT_CPU_PROVISIONING_MODE, - DEFAULT_GPU_PROVISIONING_MODE, - DEFAULT_GPU_TYPE, DEFAULT_MAX_GPUS_PER_SLIDE, PipelineConfig, ) @@ -1255,3 +1257,30 @@ def test_metadata_with_invalid_validation_case_tag() -> None: with pytest.raises(ValidationError) as exc: validate_run_sdk_metadata(metadata) assert "validation_case" in str(exc.value) + + +class TestGPUConfig: + """Test cases for GPUConfig model.""" + + @pytest.mark.unit + @staticmethod + def test_model_dump_should_include_flex_start_max_duration_if_provided() -> None: + """Test that flex_start_max_run_duration_minutes is included in model dump if provided.""" + config = GPUConfig( + gpu_type=GPUType.L4, + provisioning_mode=ProvisioningMode.FLEX_START, + max_gpus_per_slide=DEFAULT_MAX_GPUS_PER_SLIDE, + flex_start_max_run_duration_minutes=1, + ) + assert "flex_start_max_run_duration_minutes" in config.model_dump() + + @pytest.mark.unit + @staticmethod + def test_model_dump_should_exclude_flex_start_max_duration_if_not_provided() -> None: + """Test that flex_start_max_run_duration_minutes is excluded in model dump if not provided.""" + config = GPUConfig( + gpu_type=GPUType.L4, + provisioning_mode=ProvisioningMode.SPOT, + max_gpus_per_slide=DEFAULT_MAX_GPUS_PER_SLIDE, + ) + assert "flex_start_max_run_duration_minutes" not in config.model_dump()