diff --git a/README.md b/README.md index b9d4ca86..6a980d3f 100644 --- a/README.md +++ b/README.md @@ -703,18 +703,65 @@ options = { Embed context values into a bearer token during generation so you can reference those values in your policies. This enables more flexible access controls, such as tracking end-user identity when making API calls using service accounts, and facilitates using signed data tokens during detokenization. -Generate bearer tokens containing context information using a service account with the context_id identifier. Context information is represented as a JWT claim in a Skyflow-generated bearer token. Tokens generated from such service accounts include a context_identifier claim, are valid for 60 minutes, and can be used to make API calls to the Data and Management APIs, depending on the service account's permissions. +Generate bearer tokens containing context information using a service account with the `context_id` identifier. Context information is represented as a JWT claim in a Skyflow-generated bearer token. Tokens generated from such service accounts include a `context_identifier` claim, are valid for 60 minutes, and can be used to make API calls to the Data and Management APIs, depending on the service account's permissions. + +The `ctx` parameter accepts either a **string** or a **dict**: + +**String context** — use when your policy references a single context value: + +```python +options = {'ctx': 'user_12345'} +token, _ = generate_bearer_token(filepath, options) +``` + +**Dict context** — use when your policy needs multiple context values for conditional data access. Each key in the dict maps to a Skyflow CEL policy variable under `request.context.*`: + +```python +options = { + 'ctx': { + 'role': 'admin', + 'department': 'finance', + 'user_id': 'user_12345', + } +} +token, _ = generate_bearer_token(filepath, options) +``` + +With the dict above, your Skyflow policies can reference `request.context.role`, `request.context.department`, and `request.context.user_id` to make conditional access decisions. + +Dict keys must contain only alphanumeric characters and underscores (`[a-zA-Z0-9_]`). Invalid keys will raise a `SkyflowError`. > [!TIP] -> See the full example in the samples directory: [token_generation_with_context_example.py](samples/service_account/token_generation_with_context_example.py) -> See [docs.skyflow.com](https://docs.skyflow.com) for more details on authentication, access control, and governance for Skyflow. +> See the full example in the samples directory: [token_generation_with_context_example.py](samples/service_account/token_generation_with_context_example.py) +> See Skyflow's [context-aware authorization](https://docs.skyflow.com) and [conditional data access](https://docs.skyflow.com) docs for policy variable syntax like `request.context.*`. #### Generate signed data tokens: `generate_signed_data_tokens(filepath, options)` Digitally sign data tokens with a service account's private key to add an extra layer of protection. Skyflow generates data tokens when sensitive data is inserted into the vault. Detokenize signed tokens only by providing the signed data token along with a bearer token generated from the service account's credentials. The service account must have the necessary permissions and context to successfully detokenize the signed data tokens. +The `ctx` parameter on signed data tokens also accepts either a **string** or a **dict**, using the same format as bearer tokens: + +```python +# String context +options = { + 'ctx': 'user_12345', + 'data_tokens': ['dataToken1', 'dataToken2'], + 'time_to_live': 90, +} + +# Dict context +options = { + 'ctx': { + 'role': 'analyst', + 'department': 'research', + }, + 'data_tokens': ['dataToken1', 'dataToken2'], + 'time_to_live': 90, +} +``` + > [!TIP] -> See the full example in the samples directory: [signed_token_generation_example.py](samples/service_account/signed_token_generation_example.py) +> See the full example in the samples directory: [signed_token_generation_example.py](samples/service_account/signed_token_generation_example.py) > See [docs.skyflow.com](https://docs.skyflow.com) for more details on authentication, access control, and governance for Skyflow. ## Logging diff --git a/samples/detect_api/deidentify_file_async.py b/samples/detect_api/deidentify_file_async.py new file mode 100644 index 00000000..579dab2e --- /dev/null +++ b/samples/detect_api/deidentify_file_async.py @@ -0,0 +1,124 @@ +from skyflow.error import SkyflowError +from skyflow import Env, Skyflow, LogLevel +from skyflow.utils.enums import DetectEntities, MaskingMethod, DetectOutputTranscriptions +from skyflow.vault.detect import DeidentifyFileRequest, TokenFormat, Transformations, DateTransformation, Bleep, FileInput +from concurrent.futures import ThreadPoolExecutor + +""" + * Skyflow Deidentify File Example + * + * This sample demonstrates how to use all available options for deidentifying files + * using an asynchronous approach. + * Supported file types: images (jpg, png, etc.), pdf, audio (mp3, wav), documents, + * spreadsheets, presentations, structured text. +""" + +def perform_file_deidentification_async(): + try: + # Step 1: Configure Credentials + credentials = { + 'path': '/path/to/credentials.json' # Path to credentials file + } + + # Step 2: Configure Vault + vault_config = { + 'vault_id': '', # Replace with your vault ID + 'cluster_id': '', # Replace with your cluster ID + 'env': Env.PROD, # Deployment environment + 'credentials': credentials + } + + # Step 3: Configure & Initialize Skyflow Client + skyflow_client = ( + Skyflow.builder() + .add_vault_config(vault_config) + .set_log_level(LogLevel.INFO) # Use LogLevel.ERROR in production + .build() + ) + + # Step 4: Create File Object + file_path = '' # Replace with your file path + + deidentify_request = DeidentifyFileRequest( + file=FileInput(file_path=file_path), # File to de-identify + # entities=[DetectEntities.SSN, DetectEntities.CREDIT_CARD], # Entities to detect + allow_regex_list=[''], # Optional: Patterns to allow + restrict_regex_list=[''], # Optional: Patterns to restrict + + # Token format configuration + token_format=TokenFormat( + vault_token=[DetectEntities.SSN], # Use vault tokens for these entities + ), + + # Optional: Custom transformations + # transformations=Transformations( + # shift_dates=DateTransformation( + # max_days=30, + # min_days=10, + # entities=[DetectEntities.DOB] + # ) + # ), + + # Output configuration + output_directory='', # Where to save processed file + wait_time=15, # Max wait time in seconds (max 64) + + # Image-specific options + output_processed_image=True, # Include processed image in output + output_ocr_text=True, # Include OCR text in response + masking_method=MaskingMethod.BLACKBOX, # Masking method for images + + # PDF-specific options + pixel_density=15, # Pixel density for PDF processing + max_resolution=2000, # Max resolution for PDF + + # Audio-specific options + output_processed_audio=True, # Include processed audio + output_transcription=DetectOutputTranscriptions.PLAINTEXT_TRANSCRIPTION, # Transcription type + + # Audio bleep configuration + + # bleep=Bleep( + # gain=5, # Loudness in dB + # frequency=1000, # Pitch in Hz + # start_padding=0.1, # Padding at start (seconds) + # stop_padding=0.2 # Padding at end (seconds) + # ) + ) + + # Create a thread pool executor + executor = ThreadPoolExecutor(max_workers=1) + + future = executor.submit( + lambda: skyflow_client.detect().deidentify_file(deidentify_request) + ) + + def handle_response(future): + exception = future.exception() + if exception is not None: + if isinstance(exception, SkyflowError): + # Handle Skyflow-specific errors + print('\nSkyflow Error:', { + 'http_code': exception.http_code, + 'grpc_code': exception.grpc_code, + 'http_status': exception.http_status, + 'message': exception.message, + 'details': exception.details + }) + else: + # Handle unexpected errors + print('Unexpected Error:', exception) + return + + # Handle Successful Response + result = future.result() + print("\nDeidentify File Response:", result) + + future.add_done_callback(handle_response) + + executor.shutdown(wait=True) + + except Exception as error: + # Handle unexpected errors + print('Unexpected Error:', error) + diff --git a/samples/service_account/signed_token_generation_example.py b/samples/service_account/signed_token_generation_example.py index 32140ada..6ede1746 100644 --- a/samples/service_account/signed_token_generation_example.py +++ b/samples/service_account/signed_token_generation_example.py @@ -18,42 +18,54 @@ credentials_string = json.dumps(skyflow_credentials) -options = { - 'ctx': 'CONTEXT_ID', - 'data_tokens': ['DATA_TOKEN1', 'DATA_TOKEN2'], - 'time_to_live': 90, # in seconds -} +# Approach 1: Signed data tokens with string context +def get_signed_tokens_with_string_context(): + options = { + 'ctx': 'user_12345', + 'data_tokens': ['DATA_TOKEN1', 'DATA_TOKEN2'], + 'time_to_live': 90, # in seconds + } + try: + data_token, signed_data_token = generate_signed_data_tokens(file_path, options) + return data_token, signed_data_token + except Exception as e: + print(f'Error: {str(e)}') -def get_signed_bearer_token_from_file_path(): - # Generate signed bearer token from credentials file path. - global bearer_token +# Approach 2: Signed data tokens with JSON object context (dict) +# Each key maps to a Skyflow CEL policy variable under request.context.* +# For example: request.context.role == "analyst" and request.context.department == "research" +def get_signed_tokens_with_object_context(): + options = { + 'ctx': { + 'role': 'analyst', + 'department': 'research', + 'user_id': 'user_67890', + }, + 'data_tokens': ['DATA_TOKEN1', 'DATA_TOKEN2'], + 'time_to_live': 90, + } try: - if not is_expired(bearer_token): - return bearer_token - else: - data_token, signed_data_token = generate_signed_data_tokens(file_path, options) - return data_token, signed_data_token - + data_token, signed_data_token = generate_signed_data_tokens(file_path, options) + return data_token, signed_data_token except Exception as e: - print(f'Error generating token from file path: {str(e)}') + print(f'Error: {str(e)}') -def get_signed_bearer_token_from_credentials_string(): - # Generate signed bearer token from credentials string. - global bearer_token - +# Approach 3: Signed data tokens from credentials string +def get_signed_tokens_from_credentials_string(): + options = { + 'ctx': 'user_12345', + 'data_tokens': ['DATA_TOKEN1', 'DATA_TOKEN2'], + 'time_to_live': 90, + } try: - if not is_expired(bearer_token): - return bearer_token - else: - data_token, signed_data_token = generate_signed_data_tokens_from_creds(credentials_string, options) - return data_token, signed_data_token - + data_token, signed_data_token = generate_signed_data_tokens_from_creds(credentials_string, options) + return data_token, signed_data_token except Exception as e: - print(f'Error generating token from credentials string: {str(e)}') - + print(f'Error: {str(e)}') -print(get_signed_bearer_token_from_file_path()) -print(get_signed_bearer_token_from_credentials_string()) +print("String context:", get_signed_tokens_with_string_context()) +print("Object context:", get_signed_tokens_with_object_context()) +print("Creds string:", get_signed_tokens_from_credentials_string()) diff --git a/samples/service_account/token_generation_with_context_example.py b/samples/service_account/token_generation_with_context_example.py index a43a072a..03aa9f06 100644 --- a/samples/service_account/token_generation_with_context_example.py +++ b/samples/service_account/token_generation_with_context_example.py @@ -18,11 +18,13 @@ } credentials_string = json.dumps(skyflow_credentials) -options = {'ctx': ''} -def get_bearer_token_with_context_from_file_path(): - # Generate bearer token with context from credentials file path. +# Approach 1: Bearer token with string context +# Use a simple string identifier when your policy references a single context value. +# In your Skyflow policy, reference this as: request.context +def get_bearer_token_with_string_context(): global bearer_token + options = {'ctx': 'user_12345'} try: if not is_expired(bearer_token): @@ -31,14 +33,40 @@ def get_bearer_token_with_context_from_file_path(): token, _ = generate_bearer_token(file_path, options) bearer_token = token return bearer_token + except Exception as e: + print(f'Error generating token: {str(e)}') + + +# Approach 2: Bearer token with JSON object context (dict) +# Use a dict when your policy needs multiple context values for conditional data access. +# Each key maps to a Skyflow CEL policy variable under request.context.* +# For example: request.context.role == "admin" and request.context.department == "finance" +def get_bearer_token_with_object_context(): + global bearer_token + options = { + 'ctx': { + 'role': 'admin', + 'department': 'finance', + 'user_id': 'user_12345', + } + } + try: + if not is_expired(bearer_token): + return bearer_token + else: + token, _ = generate_bearer_token(file_path, options) + bearer_token = token + return bearer_token except Exception as e: - print(f'Error generating token from file path: {str(e)}') + print(f'Error generating token: {str(e)}') +# Approach 3: Bearer token with string context from credentials string def get_bearer_token_with_context_from_credentials_string(): - # Generate bearer token with context from credentials string. global bearer_token + options = {'ctx': 'user_12345'} + try: if not is_expired(bearer_token): return bearer_token @@ -47,9 +75,9 @@ def get_bearer_token_with_context_from_credentials_string(): bearer_token = token return bearer_token except Exception as e: - print(f"Error generating token from credentials string: {str(e)}") - + print(f"Error generating token: {str(e)}") -print(get_bearer_token_with_context_from_file_path()) -print(get_bearer_token_with_context_from_credentials_string()) \ No newline at end of file +print("String context:", get_bearer_token_with_string_context()) +print("Object context:", get_bearer_token_with_object_context()) +print("Creds string:", get_bearer_token_with_context_from_credentials_string()) diff --git a/setup.py b/setup.py index aa38463d..8f76225e 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ if sys.version_info < (3, 8): raise RuntimeError("skyflow requires Python 3.8+") -current_version = '2.0.0.dev0+f7d26df' +current_version = '2.0.2' setup( name='skyflow', diff --git a/skyflow/client/skyflow.py b/skyflow/client/skyflow.py index 0bfde34e..7b405f17 100644 --- a/skyflow/client/skyflow.py +++ b/skyflow/client/skyflow.py @@ -62,9 +62,6 @@ def set_log_level(self, log_level): def get_log_level(self): return self.__builder._Builder__log_level - def update_log_level(self, log_level): - self.__builder._Builder__set_log_level(log_level) - def vault(self, vault_id = None) -> Vault: vault_config = self.__builder.get_vault_config(vault_id) return vault_config.get(OptionField.VAULT_CONTROLLER) diff --git a/skyflow/error/_skyflow_error.py b/skyflow/error/_skyflow_error.py index fca43935..80a959dd 100644 --- a/skyflow/error/_skyflow_error.py +++ b/skyflow/error/_skyflow_error.py @@ -1,5 +1,4 @@ from skyflow.utils import SkyflowMessages -from skyflow.utils.logger import log_error class SkyflowError(Exception): def __init__(self, @@ -15,4 +14,4 @@ def __init__(self, self.http_status = http_status if http_status else SkyflowMessages.HttpStatus.BAD_REQUEST.value self.details = details self.request_id = request_id - super().__init__() \ No newline at end of file + super().__init__(message) \ No newline at end of file diff --git a/skyflow/service_account/_utils.py b/skyflow/service_account/_utils.py index f4c98faf..4d3a4574 100644 --- a/skyflow/service_account/_utils.py +++ b/skyflow/service_account/_utils.py @@ -1,5 +1,6 @@ import json import datetime +import re import time import jwt from urllib.parse import urlparse @@ -10,11 +11,56 @@ from skyflow.utils.constants import JWT, CredentialField, JwtField, OptionField, ResponseField from skyflow.generated.rest.errors.unauthorized_error import UnauthorizedError from skyflow.utils import is_valid_url +from skyflow.utils.constants import CTX_KEY_REGEX invalid_input_error_code = SkyflowMessages.ErrorCodes.INVALID_INPUT.value +_CTX_KEY_PATTERN = re.compile(CTX_KEY_REGEX) + +_SNAKE_TO_CAMEL_CRED_MAP = { + 'private_key': CredentialField.PRIVATE_KEY, + 'client_id': CredentialField.CLIENT_ID, + 'key_id': CredentialField.KEY_ID, + 'token_uri': CredentialField.TOKEN_URI, + 'client_name': CredentialField.CLIENT_NAME, +} + + +def _normalize_credentials(credentials): + return {_SNAKE_TO_CAMEL_CRED_MAP.get(k, k): v for k, v in credentials.items()} + + +def _validate_and_resolve_ctx(ctx): + """Validate ctx value and return resolved value for JWT claims. + Returns None if ctx should be omitted, the value if valid, or raises SkyflowError if invalid. + """ + if ctx is None: + return None + if isinstance(ctx, str): + if ctx.strip() == '': + return None + return ctx + if isinstance(ctx, dict): + if len(ctx) == 0: + return None + for key in ctx: + if not isinstance(key, str) or not _CTX_KEY_PATTERN.match(key): + raise SkyflowError( + SkyflowMessages.Error.INVALID_CTX_MAP_KEY.value.format(key), + invalid_input_error_code + ) + return ctx + if isinstance(ctx, (bool, int, float)): + return ctx + raise SkyflowError( + SkyflowMessages.Error.INVALID_CTX_TYPE.value, + invalid_input_error_code + ) + def is_expired(token, logger = None): + if token is None: + return True if len(token) == 0: log_error_log(SkyflowMessages.ErrorLogs.INVALID_BEARER_TOKEN.value) return True @@ -34,20 +80,17 @@ def is_expired(token, logger = None): return True def generate_bearer_token(credentials_file_path, options = None, logger = None): + log_info(SkyflowMessages.Info.GET_BEARER_TOKEN_TRIGGERED.value, logger) try: - log_info(SkyflowMessages.Info.GET_BEARER_TOKEN_TRIGGERED.value, logger) - credentials_file =open(credentials_file_path, 'r') + credentials_file = open(credentials_file_path, 'r') except Exception: raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value, invalid_input_error_code) - - try: - credentials = json.load(credentials_file) - except Exception: - log_error_log(SkyflowMessages.ErrorLogs.INVALID_CREDENTIALS_FILE.value, logger = logger) - raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path), invalid_input_error_code) - - finally: - credentials_file.close() + with credentials_file: + try: + credentials = json.load(credentials_file) + except Exception: + log_error_log(SkyflowMessages.ErrorLogs.INVALID_CREDENTIALS_FILE.value, logger=logger) + raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path), invalid_input_error_code) result = get_service_account_token(credentials, options, logger) return result @@ -62,24 +105,25 @@ def generate_bearer_token_from_creds(credentials, options = None, logger = None) return result def get_service_account_token(credentials, options, logger): + credentials = _normalize_credentials(credentials) try: private_key = credentials[CredentialField.PRIVATE_KEY] - except: - log_error_log(SkyflowMessages.ErrorLogs.PRIVATE_KEY_IS_REQUIRED.value, logger = logger) + except KeyError: + log_error_log(SkyflowMessages.ErrorLogs.PRIVATE_KEY_IS_REQUIRED.value, logger=logger) raise SkyflowError(SkyflowMessages.Error.MISSING_PRIVATE_KEY.value, invalid_input_error_code) try: client_id = credentials[CredentialField.CLIENT_ID] - except: + except KeyError: log_error_log(SkyflowMessages.ErrorLogs.CLIENT_ID_IS_REQUIRED.value, logger=logger) raise SkyflowError(SkyflowMessages.Error.MISSING_CLIENT_ID.value, invalid_input_error_code) try: key_id = credentials[CredentialField.KEY_ID] - except: + except KeyError: log_error_log(SkyflowMessages.ErrorLogs.KEY_ID_IS_REQUIRED.value, logger=logger) raise SkyflowError(SkyflowMessages.Error.MISSING_KEY_ID.value, invalid_input_error_code) try: token_uri = credentials[CredentialField.TOKEN_URI] - except: + except KeyError: log_error_log(SkyflowMessages.ErrorLogs.TOKEN_URI_IS_REQUIRED.value, logger=logger) raise SkyflowError(SkyflowMessages.Error.MISSING_TOKEN_URI.value, invalid_input_error_code) @@ -87,9 +131,12 @@ def get_service_account_token(credentials, options, logger): log_error_log(SkyflowMessages.ErrorLogs.INVALID_TOKEN_URI.value, logger=logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_URI.value, invalid_input_error_code) - if options and "token_uri" in options: - token_uri = options["token_uri"] - + if options and CredentialField.TOKEN_URI_OPTION in options: + token_uri = options[CredentialField.TOKEN_URI_OPTION] + if not isinstance(token_uri, str) or not is_valid_url(token_uri): + log_error_log(SkyflowMessages.ErrorLogs.INVALID_TOKEN_URI.value, logger=logger) + raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_URI.value, invalid_input_error_code) + signed_token = get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger) base_url = get_base_url(token_uri) auth_client = AuthClient(base_url) @@ -101,7 +148,7 @@ def get_service_account_token(credentials, options, logger): try: response = auth_api.authentication_service_get_auth_token(assertion = signed_token, - grant_type="urn:ietf:params:oauth:grant-type:jwt-bearer", + grant_type=JWT.GRANT_TYPE_JWT_BEARER, scope=formatted_scope) log_info(SkyflowMessages.Info.GET_BEARER_TOKEN_SUCCESS.value, logger) except UnauthorizedError: @@ -120,8 +167,10 @@ def get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger): JwtField.SUB: client_id, JwtField.EXP: datetime.datetime.utcnow() + datetime.timedelta(minutes=60) } - if options and JwtField.CTX in options: - payload[JwtField.CTX] = options.get(JwtField.CTX) + if options and OptionField.CTX in options: + resolved_ctx = _validate_and_resolve_ctx(options.get(OptionField.CTX)) + if resolved_ctx is not None: + payload[JwtField.CTX] = resolved_ctx try: return jwt.encode(payload=payload, key=private_key, algorithm=JWT.ALGORITHM_RS256) except Exception: @@ -130,18 +179,20 @@ def get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger): def get_signed_tokens(credentials_obj, options): + credentials_obj = _normalize_credentials(credentials_obj) expiry_time = int(time.time()) + options.get(OptionField.TIME_TO_LIVE, 60) prefix = JWT.SIGNED_TOKEN_PREFIX - token_uri = credentials_obj.get("tokenURI") + token_uri = credentials_obj.get(CredentialField.TOKEN_URI) if not isinstance(token_uri, str) or not is_valid_url(token_uri): log_error_log(SkyflowMessages.ErrorLogs.INVALID_TOKEN_URI.value) raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_URI.value, invalid_input_error_code) - - if options and "token_uri" in options: - token_uri = options["token_uri"] + resolved_ctx = None + if OptionField.CTX in options: + resolved_ctx = _validate_and_resolve_ctx(options[OptionField.CTX]) + results = [] if options and options.get(OptionField.DATA_TOKENS): for token in options[OptionField.DATA_TOKENS]: claims = { @@ -152,37 +203,30 @@ def get_signed_tokens(credentials_obj, options): JwtField.TOK: token, JwtField.IAT: int(time.time()), } - - if JwtField.CTX in options: - claims[JwtField.CTX] = options[JwtField.CTX] - + if resolved_ctx is not None: + claims[JwtField.CTX] = resolved_ctx private_key = credentials_obj.get(CredentialField.PRIVATE_KEY) - try: + try: signed_jwt = jwt.encode(claims, private_key, algorithm=JWT.ALGORITHM_RS256) except Exception: raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code) - - response_object = get_signed_data_token_response_object(prefix + signed_jwt, token) + results.append(get_signed_data_token_response_object(prefix + signed_jwt, token)) log_info(SkyflowMessages.Info.GET_SIGNED_DATA_TOKEN_SUCCESS.value) - return response_object + return results def generate_signed_data_tokens(credentials_file_path, options): log_info(SkyflowMessages.Info.GET_SIGNED_DATA_TOKENS_TRIGGERED.value) try: - credentials_file =open(credentials_file_path, 'r') + credentials_file = open(credentials_file_path, 'r') except Exception: raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value, invalid_input_error_code) - - try: - credentials = json.load(credentials_file) - except Exception: - raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path), - invalid_input_error_code) - - finally: - credentials_file.close() - + with credentials_file: + try: + credentials = json.load(credentials_file) + except Exception: + raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path), + invalid_input_error_code) return get_signed_tokens(credentials, options) def generate_signed_data_tokens_from_creds(credentials, options): @@ -195,9 +239,6 @@ def generate_signed_data_tokens_from_creds(credentials, options): raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS_STRING.value, invalid_input_error_code) return get_signed_tokens(json_credentials, options) + def get_signed_data_token_response_object(signed_token, actual_token): - response_object = { - ResponseField.TOKEN: actual_token, - ResponseField.SIGNED_TOKEN: signed_token - } - return response_object.get(ResponseField.TOKEN), response_object.get(ResponseField.SIGNED_TOKEN) + return actual_token, signed_token diff --git a/skyflow/utils/_helpers.py b/skyflow/utils/_helpers.py index 090f3a2b..12ff1257 100644 --- a/skyflow/utils/_helpers.py +++ b/skyflow/utils/_helpers.py @@ -13,6 +13,6 @@ def format_scope(scopes): def is_valid_url(url): try: result = urlparse(url) - return all([result.scheme in ("http", "https"), result.netloc]) + return all([result.scheme == "https", result.netloc]) except Exception: return False \ No newline at end of file diff --git a/skyflow/utils/_skyflow_messages.py b/skyflow/utils/_skyflow_messages.py index 989aa298..79f4c7ec 100644 --- a/skyflow/utils/_skyflow_messages.py +++ b/skyflow/utils/_skyflow_messages.py @@ -61,6 +61,8 @@ class Error(Enum): EMPTY_CONTEXT = f"{error_prefix} Initialization failed. Invalid context provided. Specify context as type Context." INVALID_CONTEXT_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid context for {{}} with id {{}}. Specify a valid context." INVALID_CONTEXT = f"{error_prefix} Initialization failed. Invalid context. Specify a valid context." + INVALID_CTX_TYPE = f"{error_prefix} Initialization failed. Invalid ctx type. Specify ctx as a string or a dict." + INVALID_CTX_MAP_KEY = f"{error_prefix} Initialization failed. Invalid key '{{}}' in ctx dict. Keys must contain only alphanumeric characters and underscores." INVALID_LOG_LEVEL = f"{error_prefix} Initialization failed. Invalid log level. Specify a valid log level." EMPTY_LOG_LEVEL = f"{error_prefix} Initialization failed. Specify a valid log level." @@ -88,7 +90,7 @@ class Error(Enum): INVALID_TABLE_NAME_IN_INSERT = f"{error_prefix} Validation error. Invalid table name in insert request. Specify a valid table name." INVALID_TYPE_OF_DATA_IN_INSERT = f"{error_prefix} Validation error. Invalid type of data in insert request. Specify data as a object array." EMPTY_DATA_IN_INSERT = f"{error_prefix} Validation error. Data array cannot be empty. Specify data in insert request." - INVALID_UPSERT_OPTIONS_TYPE = f"{error_prefix} Validation error. 'upsert' key cannot be empty in options. At least one object of table and column is required." + INVALID_UPSERT_OPTIONS_TYPE = f"{error_prefix} Validation error. Invalid 'upsert' value in options. Specify 'upsert' as a non-empty string containing the column name." INVALID_HOMOGENEOUS_TYPE = f"{error_prefix} Validation error. Invalid type of homogeneous. Specify homogeneous as a string." INVALID_TOKEN_MODE_TYPE = f"{error_prefix} Validation error. Invalid type of token mode. Specify token mode as a TokenMode enum." INVALID_RETURN_TOKENS_TYPE = f"{error_prefix} Validation error. Invalid type of return tokens. Specify return tokens as a boolean." diff --git a/skyflow/utils/_utils.py b/skyflow/utils/_utils.py index 5d83cbcc..e3b8eea9 100644 --- a/skyflow/utils/_utils.py +++ b/skyflow/utils/_utils.py @@ -32,26 +32,18 @@ invalid_input_error_code = SkyflowMessages.ErrorCodes.INVALID_INPUT.value def get_credentials(config_level_creds = None, common_skyflow_creds = None, logger = None): - dotenv.load_dotenv() + if config_level_creds is not None: + return config_level_creds + if common_skyflow_creds is not None: + return common_skyflow_creds dotenv_path = dotenv.find_dotenv(usecwd=True) if dotenv_path: load_dotenv(dotenv_path) env_skyflow_credentials = os.getenv("SKYFLOW_CREDENTIALS") - if config_level_creds: - return config_level_creds - if common_skyflow_creds: - return common_skyflow_creds if env_skyflow_credentials: - env_skyflow_credentials.strip() - try: - env_creds = env_skyflow_credentials.replace('\n', '\\n') - return { - CredentialField.CREDENTIALS_STRING: env_creds - } - except json.JSONDecodeError: - raise SkyflowError(SkyflowMessages.Error.INVALID_JSON_FORMAT_IN_CREDENTIALS_ENV.value, invalid_input_error_code) - else: - raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code) + env_creds = env_skyflow_credentials.strip().replace('\n', '\\n') + return {'credentials_string': env_creds} + raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code) def validate_api_key(api_key: str, logger = None) -> bool: if len(api_key) != ApiKey.LENGTH: @@ -80,9 +72,9 @@ def parse_path_params(url, path_params): return result -def to_lowercase_keys(dict): +def to_lowercase_keys(data): result = {} - for key, value in dict.items(): + for key, value in data.items(): result[key.lower()] = value return result @@ -136,7 +128,7 @@ def construct_invoke_connection_request(request, connection_url, logger) -> Prep raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_BODY.value, invalid_input_error_code) except SkyflowError: raise - except Exception as e: + except Exception: raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_BODY.value, invalid_input_error_code) if files and header and content_type == ContentType.FORMDATA.value: @@ -194,7 +186,6 @@ def get_data_from_content_type(data, content_type): if content_type == ContentType.URLENCODED.value: converted_data = http_build_query(data) elif content_type == ContentType.FORMDATA.value: - print("Hello") converted_data = None files = {} for key, value in data.items(): @@ -239,8 +230,11 @@ def build_xml(d, tag='item'): return ''.join(xml_parts) +_CACHED_METRICS: dict = {} + def get_metrics(): - sdk_name_version = SdkPrefix.SKYFLOW_PYTHON + SDK_VERSION + if _CACHED_METRICS: + return _CACHED_METRICS try: sdk_client_device_model = platform.node() @@ -257,13 +251,13 @@ def get_metrics(): except Exception: sdk_runtime_details = "" - details_dic = { - SdkMetricsKey.SDK_NAME_VERSION: sdk_name_version, + _CACHED_METRICS.update({ + SdkMetricsKey.SDK_NAME_VERSION: SdkPrefix.SKYFLOW_PYTHON + SDK_VERSION, SdkMetricsKey.SDK_CLIENT_DEVICE_MODEL: sdk_client_device_model, SdkMetricsKey.SDK_CLIENT_OS_DETAILS: sdk_client_os_details, SdkMetricsKey.SDK_RUNTIME_DETAILS: SdkPrefix.PYTHON_RUNTIME + sdk_runtime_details, - } - return details_dic + }) + return _CACHED_METRICS def parse_insert_response(api_response, continue_on_error): # Retrieve the headers and data from the API response @@ -427,22 +421,30 @@ def parse_invoke_connection_response(api_response: requests.Response): error_response = json.loads(content) error_from_client = api_response.headers.get(HttpHeader.ERROR_FROM_CLIENT) - status_code = error_response.get(ResponseField.ERROR, {}).get(ResponseField.HTTP_CODE, status_code) - http_status = error_response.get(ResponseField.ERROR, {}).get(ResponseField.HTTP_STATUS) - grpc_code = error_response.get(ResponseField.ERROR, {}).get(ResponseField.GRPC_CODE) - details = error_response.get(ResponseField.ERROR, {}).get(ResponseField.DETAILS) - message = error_response.get(ResponseField.ERROR, {}).get(ResponseField.MESSAGE, SkyflowMessages.Error.UNKNOWN_ERROR_DEFAULT_MESSAGE.value) - + http_status = None + grpc_code = None + details = None + + error_obj = error_response.get(ResponseField.ERROR) if isinstance(error_response, dict) else None + if isinstance(error_obj, dict): + status_code = error_obj.get(ResponseField.HTTP_CODE, status_code) + http_status = error_obj.get(ResponseField.HTTP_STATUS) + grpc_code = error_obj.get(ResponseField.GRPC_CODE) + details = error_obj.get(ResponseField.DETAILS) + message = error_obj.get(ResponseField.MESSAGE, message) + elif isinstance(error_obj, str) and error_obj: + message = error_obj + if error_from_client is not None: - if details is None: + if details is None: details = [] error_from_client_bool = error_from_client.lower() == BooleanString.TRUE details.append({ResponseField.ERROR_FROM_CLIENT: error_from_client_bool}) raise SkyflowError(message, status_code, request_id, grpc_code, http_status, details) - + except json.JSONDecodeError: - raise SkyflowError(content if content else message, status_code, request_id) + raise SkyflowError(message, status_code, request_id) def parse_deidentify_text_response(api_response: DeidentifyStringResponse): entities = [convert_detected_entity_to_entity_info(entity) for entity in api_response.entities] @@ -486,21 +488,46 @@ def handle_exception(error, logger): def handle_json_error(err, data, request_id, logger): try: - if isinstance(data, dict): # If data is already a dict + if isinstance(data, dict): description = data elif isinstance(data, ErrorResponse): description = data.dict() else: description = json.loads(data) - status_code = description.get(ResponseField.ERROR, {}).get(ResponseField.HTTP_CODE, 500) # Default to 500 if not found - http_status = description.get(ResponseField.ERROR, {}).get(ResponseField.HTTP_STATUS) - grpc_code = description.get(ResponseField.ERROR, {}).get(ResponseField.GRPC_CODE) - details = description.get(ResponseField.ERROR, {}).get(ResponseField.DETAILS, []) - description_message = description.get(ResponseField.ERROR, {}).get(ResponseField.MESSAGE, SkyflowMessages.Error.UNKNOWN_ERROR_DEFAULT_MESSAGE.value) - log_and_reject_error(description_message, status_code, request_id, http_status, grpc_code, details, logger = logger) + if ResponseField.ERROR in description: + error_obj = description.get(ResponseField.ERROR, {}) + status_code = error_obj.get(ResponseField.HTTP_CODE, HttpStatusCode.INTERNAL_SERVER_ERROR) + http_status = error_obj.get(ResponseField.HTTP_STATUS) + grpc_code = error_obj.get(ResponseField.GRPC_CODE) + details = error_obj.get(ResponseField.DETAILS, []) + description_message = error_obj.get(ResponseField.MESSAGE, SkyflowMessages.Error.UNKNOWN_ERROR_DEFAULT_MESSAGE.value) + elif ResponseField.RESPONSES in description: + responses = description.get(ResponseField.RESPONSES, []) + messages = [] + status_code = HttpStatusCode.INTERNAL_SERVER_ERROR + for resp in responses: + resp_status = resp.get(ResponseField.STATUS, HttpStatusCode.INTERNAL_SERVER_ERROR) + resp_body = resp.get(ResponseField.BODY, {}) + if isinstance(resp_status, int) and resp_status >= HttpStatusCode.BAD_REQUEST: + status_code = resp_status + error_msg = resp_body.get(ResponseField.ERROR) + if error_msg: + messages.append(str(error_msg)) + description_message = '; '.join(messages) if messages else SkyflowMessages.Error.UNKNOWN_ERROR_DEFAULT_MESSAGE.value + http_status = None + grpc_code = None + details = [] + else: + status_code = HttpStatusCode.INTERNAL_SERVER_ERROR + http_status = None + grpc_code = None + details = [] + description_message = SkyflowMessages.Error.UNKNOWN_ERROR_DEFAULT_MESSAGE.value + + log_and_reject_error(description_message, status_code, request_id, http_status, grpc_code, details, logger=logger) except json.JSONDecodeError: - log_and_reject_error(SkyflowMessages.Error.INVALID_JSON_RESPONSE.value, err, request_id, logger = logger) + log_and_reject_error(SkyflowMessages.Error.INVALID_JSON_RESPONSE.value, err, request_id, logger=logger) def handle_text_error(err, data, request_id, logger): log_and_reject_error(data, err.status, request_id, logger = logger) diff --git a/skyflow/utils/_version.py b/skyflow/utils/_version.py index bd8e63ec..bc50f210 100644 --- a/skyflow/utils/_version.py +++ b/skyflow/utils/_version.py @@ -1 +1 @@ -SDK_VERSION = '2.0.0.dev0+f7d26df' +SDK_VERSION = '2.0.2' diff --git a/skyflow/utils/constants.py b/skyflow/utils/constants.py index 401bffe5..17ba96e2 100644 --- a/skyflow/utils/constants.py +++ b/skyflow/utils/constants.py @@ -1,6 +1,7 @@ OPTIONAL_TOKEN='token' PROTOCOL='https' SKY_META_DATA_HEADER='sky-metadata' +CTX_KEY_REGEX=r'^[a-zA-Z0-9_]+$' class SKYFLOW: SKYFLOW_ID = 'skyflowId' @@ -116,6 +117,7 @@ class ResponseField: TYPE = 'type' TOKENIZED_DATA = 'tokenized_data' SIGNED_TOKEN = 'signed_token' + RESPONSES = 'responses' class CredentialField: @@ -123,6 +125,8 @@ class CredentialField: CLIENT_ID = 'clientID' KEY_ID = 'keyID' TOKEN_URI = 'tokenURI' + TOKEN_URI_OPTION = 'token_uri' + CLIENT_NAME = 'clientName' CREDENTIALS_STRING = 'credentials_string' API_KEY = 'api_key' TOKEN = 'token' @@ -192,6 +196,7 @@ class DeidentifyFileRequestField: OUTPUT_OCR_TEXT = 'output_ocr_text' MASKING_METHOD = 'masking_method' PIXEL_DENSITY = 'pixel_density' + DENSITY = 'density' MAX_RESOLUTION = 'max_resolution' OUTPUT_PROCESSED_AUDIO = 'output_processed_audio' OUTPUT_TRANSCRIPTION = 'output_transcription' @@ -227,6 +232,7 @@ class DeidentifyField: ENTITY_UNQ_COUNTER = 'entity_unq_counter' ENTITY_UNIQUE_COUNTER = 'entity_unique_counter' ENTITY_ONLY = 'entity_only' + VAULT_TOKEN = 'vault_token' ENTITIES = 'entities' MAX_DAYS = 'max_days' MIN_DAYS = 'min_days' diff --git a/skyflow/utils/enums/detect_output_transcriptions.py b/skyflow/utils/enums/detect_output_transcriptions.py index 4e14f911..a398a3d8 100644 --- a/skyflow/utils/enums/detect_output_transcriptions.py +++ b/skyflow/utils/enums/detect_output_transcriptions.py @@ -4,4 +4,5 @@ class DetectOutputTranscriptions(Enum): DIARIZED_TRANSCRIPTION = "diarized_transcription" MEDICAL_DIARIZED_TRANSCRIPTION = "medical_diarized_transcription" MEDICAL_TRANSCRIPTION = "medical_transcription" - TRANSCRIPTION = "transcription" \ No newline at end of file + TRANSCRIPTION = "transcription" + PLAINTEXT_TRANSCRIPTION = "plaintext_transcription" \ No newline at end of file diff --git a/skyflow/utils/validations/_validations.py b/skyflow/utils/validations/_validations.py index 08d4905b..f07398f8 100644 --- a/skyflow/utils/validations/_validations.py +++ b/skyflow/utils/validations/_validations.py @@ -42,32 +42,32 @@ def validate_required_field(logger, config, field_name, expected_type, empty_err if field_name not in config or not isinstance(field_value, expected_type): if field_name == ConfigField.VAULT_ID: - logger.error(SkyflowMessages.ErrorLogs.VAULTID_IS_REQUIRED.value) + log_error_log(SkyflowMessages.ErrorLogs.VAULTID_IS_REQUIRED.value, logger) if field_name == ConfigField.CLUSTER_ID: - logger.error(SkyflowMessages.ErrorLogs.CLUSTER_ID_IS_REQUIRED.value) + log_error_log(SkyflowMessages.ErrorLogs.CLUSTER_ID_IS_REQUIRED.value, logger) if field_name == OptionField.CONNECTION_ID: - logger.error(SkyflowMessages.ErrorLogs.CONNECTION_ID_IS_REQUIRED.value) + log_error_log(SkyflowMessages.ErrorLogs.CONNECTION_ID_IS_REQUIRED.value, logger) if field_name == OptionField.CONNECTION_URL: - logger.error(SkyflowMessages.ErrorLogs.INVALID_CONNECTION_URL.value) + log_error_log(SkyflowMessages.ErrorLogs.INVALID_CONNECTION_URL.value, logger) raise SkyflowError(invalid_error, invalid_input_error_code) if isinstance(field_value, str) and not field_value.strip(): if field_name == ConfigField.VAULT_ID: - logger.error(SkyflowMessages.ErrorLogs.EMPTY_VAULTID.value) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_VAULTID.value, logger) if field_name == ConfigField.CLUSTER_ID: - logger.error(SkyflowMessages.ErrorLogs.EMPTY_CLUSTER_ID.value) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_CLUSTER_ID.value, logger) if field_name == OptionField.CONNECTION_ID: - logger.error(SkyflowMessages.ErrorLogs.EMPTY_CONNECTION_ID.value) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_CONNECTION_ID.value, logger) if field_name == OptionField.CONNECTION_URL: - logger.error(SkyflowMessages.ErrorLogs.EMPTY_CONNECTION_URL.value) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_CONNECTION_URL.value, logger) if field_name == CredentialField.PATH: - logger.error(SkyflowMessages.ErrorLogs.EMPTY_CREDENTIALS_PATH.value) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_CREDENTIALS_PATH.value, logger) if field_name == CredentialField.CREDENTIALS_STRING: - logger.error(SkyflowMessages.ErrorLogs.EMPTY_CREDENTIALS_STRING.value) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_CREDENTIALS_STRING.value, logger) if field_name == CredentialField.TOKEN: - logger.error(SkyflowMessages.ErrorLogs.EMPTY_TOKEN_VALUE.value) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TOKEN_VALUE.value, logger) if field_name == CredentialField.API_KEY: - logger.error(SkyflowMessages.ErrorLogs.EMPTY_API_KEY_VALUE.value) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_API_KEY_VALUE.value, logger) raise SkyflowError(empty_error, invalid_input_error_code) def validate_api_key(api_key: str, logger = None) -> bool: @@ -90,6 +90,7 @@ def validate_credentials(logger, credentials, config_id_type=None, config_id=Non if config_id_type and config_id else SkyflowMessages.Error.INVALID_CREDENTIALS.value ) + log_error_log(error_message, logger) raise SkyflowError(error_message, invalid_input_error_code) elif len(key_present) > 1: error_message = ( @@ -97,6 +98,7 @@ def validate_credentials(logger, credentials, config_id_type=None, config_id=Non if config_id_type and config_id else SkyflowMessages.Error.MULTIPLE_CREDENTIALS_PASSED.value ) + log_error_log(error_message, logger) raise SkyflowError(error_message, invalid_input_error_code) if CredentialField.ROLES in credentials: @@ -142,6 +144,7 @@ def validate_credentials(logger, credentials, config_id_type=None, config_id=Non if config_id_type and config_id else SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value ) if is_expired(credentials.get(CredentialField.TOKEN), logger): + log_error_log(SkyflowMessages.ErrorLogs.INVALID_BEARER_TOKEN.value, logger) raise SkyflowError( SkyflowMessages.Error.EXPIRED_BEARER_TOKEN.value if config_id_type and config_id else SkyflowMessages.Error.EXPIRED_BEARER_TOKEN.value, @@ -160,8 +163,8 @@ def validate_credentials(logger, credentials, config_id_type=None, config_id=Non if config_id_type and config_id else SkyflowMessages.Error.INVALID_API_KEY.value, invalid_input_error_code) - if "token_uri" in credentials: - token_uri = credentials.get("token_uri") + if CredentialField.TOKEN_URI_OPTION in credentials: + token_uri = credentials.get(CredentialField.TOKEN_URI_OPTION) if ( token_uri is None or not isinstance(token_uri, str) @@ -171,10 +174,7 @@ def validate_credentials(logger, credentials, config_id_type=None, config_id=Non def validate_log_level(logger, log_level): if not isinstance(log_level, LogLevel): - raise SkyflowError( SkyflowMessages.Error.INVALID_LOG_LEVEL.value, invalid_input_error_code) - - if log_level is None: - raise SkyflowError(SkyflowMessages.Error.EMPTY_LOG_LEVEL.value, invalid_input_error_code) + raise SkyflowError(SkyflowMessages.Error.INVALID_LOG_LEVEL.value, invalid_input_error_code) def validate_keys(logger, config, config_keys): for key in config.keys(): @@ -208,7 +208,7 @@ def validate_vault_config(logger, config): # Validate env (optional, should be one of LogLevel values) if ConfigField.ENV in config and config.get(ConfigField.ENV) not in Env: - logger.error(SkyflowMessages.ErrorLogs.VAULTID_IS_REQUIRED.value) + log_error_log(SkyflowMessages.ErrorLogs.ENV_IS_REQUIRED.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_ENV.value.format(vault_id), invalid_input_error_code) return True @@ -232,8 +232,10 @@ def validate_update_vault_config(logger, config): if ConfigField.ENV in config and config.get(ConfigField.ENV) not in Env: raise SkyflowError(SkyflowMessages.Error.INVALID_ENV.value.format(vault_id), invalid_input_error_code) - if ConfigField.CREDENTIALS in config and config.get(ConfigField.CREDENTIALS): - validate_credentials(logger, config.get(ConfigField.CREDENTIALS), ConfigType.VAULT, vault_id) + if ConfigField.CREDENTIALS not in config: + raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format(ConfigType.VAULT, vault_id), invalid_input_error_code) + + validate_credentials(logger, config.get(ConfigField.CREDENTIALS), ConfigType.VAULT, vault_id) return True @@ -255,8 +257,10 @@ def validate_connection_config(logger, config): SkyflowMessages.Error.INVALID_CONNECTION_URL.value.format(connection_id) ) - if "credentials" in config: - validate_credentials(logger, config.get("credentials"), "connection", connection_id) + if ConfigField.CREDENTIALS not in config: + raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format(ConfigType.CONNECTION, connection_id), invalid_input_error_code) + + validate_credentials(logger, config.get(ConfigField.CREDENTIALS), ConfigType.CONNECTION, connection_id) return True @@ -298,7 +302,7 @@ def validate_file_from_request(file_input: FileInput): if has_file: file = file_input.file # Validate file object has required attributes - if not hasattr(file, FileUploadField.FILE_NAME) or not isinstance(file.name, str) or not file.name.strip(): + if not hasattr(file, FileUploadField.NAME) or not isinstance(file.name, str) or not file.name.strip(): raise SkyflowError(SkyflowMessages.Error.INVALID_FILE_TYPE.value, invalid_input_error_code) # Validate file name @@ -429,7 +433,7 @@ def validate_insert_request(logger, request): log_error_log(SkyflowMessages.ErrorLogs.VALUES_IS_REQUIRED.value.format(RequestOperation.INSERT), logger = logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TYPE_OF_DATA_IN_INSERT.value, invalid_input_error_code) - if not len(request.values): + if not request.values: log_error_log(SkyflowMessages.ErrorLogs.EMPTY_VALUES.value.format(RequestOperation.INSERT), logger=logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_DATA_IN_INSERT.value, invalid_input_error_code) @@ -439,7 +443,7 @@ def validate_insert_request(logger, request): log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_KEY_IN_VALUES.value.format(RequestOperation.INSERT), logger = logger) if request.upsert is not None and (not isinstance(request.upsert, str) or not request.upsert.strip()): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_UPSERT.value(RequestOperation.INSERT), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_UPSERT.value.format(RequestOperation.INSERT), logger=logger) raise SkyflowError(SkyflowMessages.Error.INVALID_UPSERT_OPTIONS_TYPE.value, invalid_input_error_code) if request.homogeneous is not None and not isinstance(request.homogeneous, bool): @@ -471,7 +475,7 @@ def validate_insert_request(logger, request): logger=logger) if not isinstance(request.tokens, list) or not request.tokens or not all( isinstance(t, dict) for t in request.tokens): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TOKENS.value(RequestOperation.INSERT), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TOKENS.value.format(RequestOperation.INSERT), logger=logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TYPE_OF_DATA_IN_INSERT.value, invalid_input_error_code) if request.token_mode == TokenMode.ENABLE and not request.tokens: @@ -503,21 +507,21 @@ def validate_delete_request(logger, request): raise SkyflowError(SkyflowMessages.Error.EMPTY_RECORD_IDS_IN_DELETE.value, invalid_input_error_code) def validate_query_request(logger, request): - if not request.query: - log_error_log(SkyflowMessages.ErrorLogs.QUERY_IS_REQUIRED.value.format(RequestOperation.QUERY), logger = logger) - raise SkyflowError(SkyflowMessages.Error.EMPTY_QUERY.value, invalid_input_error_code) - if not isinstance(request.query, str): query_type = str(type(request.query)) raise SkyflowError(SkyflowMessages.Error.INVALID_QUERY_TYPE.value.format(query_type), invalid_input_error_code) + if not request.query: + log_error_log(SkyflowMessages.ErrorLogs.QUERY_IS_REQUIRED.value.format(RequestOperation.QUERY), logger=logger) + raise SkyflowError(SkyflowMessages.Error.EMPTY_QUERY.value, invalid_input_error_code) + if not request.query.strip(): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_QUERY.value.format(RequestOperation.QUERY), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_QUERY.value.format(RequestOperation.QUERY), logger=logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_QUERY.value, invalid_input_error_code) if not request.query.upper().startswith(SqlCommand.SELECT): command = request.query - raise SkyflowError(SkyflowMessages.Error.INVALID_QUERY_COMMAND.value.format(command), invalid_input_error_code) + raise SkyflowError(SkyflowMessages.Error.INVALID_QUERY_COMMAND.value.format(command), invalid_input_error_code) def validate_get_request(logger, request): redaction_type = request.redaction_type @@ -565,13 +569,13 @@ def validate_get_request(logger, request): invalid_input_error_code) if offset is not None and not isinstance(offset, str): - raise SkyflowError(SkyflowMessages.Error.INVALID_OFF_SET_VALUE.value(type(offset)), invalid_input_error_code) + raise SkyflowError(SkyflowMessages.Error.INVALID_OFF_SET_VALUE.value.format(type(offset)), invalid_input_error_code) if limit is not None and not isinstance(limit, str): - raise SkyflowError(SkyflowMessages.Error.INVALID_LIMIT_VALUE.value(type(limit)), invalid_input_error_code) + raise SkyflowError(SkyflowMessages.Error.INVALID_LIMIT_VALUE.value.format(type(limit)), invalid_input_error_code) if download_url is not None and not isinstance(download_url, bool): - raise SkyflowError(SkyflowMessages.Error.INVALID_DOWNLOAD_URL_VALUE.value(type(download_url)), invalid_input_error_code) + raise SkyflowError(SkyflowMessages.Error.INVALID_DOWNLOAD_URL_VALUE.value.format(type(download_url)), invalid_input_error_code) if column_name is not None and (not isinstance(column_name, str) or not column_name.strip()): raise SkyflowError(SkyflowMessages.Error.INVALID_COLUMN_NAME.value.format(type(column_name)), invalid_input_error_code) @@ -603,33 +607,30 @@ def validate_get_request(logger, request): raise SkyflowError(SkyflowMessages.Error.BOTH_IDS_AND_COLUMN_DETAILS_SPECIFIED.value, invalid_input_error_code) def validate_update_request(logger, request): - skyflow_id = "" + if not isinstance(request.data, dict): + raise SkyflowError(SkyflowMessages.Error.INVALID_FIELDS_TYPE.value.format(type(request.data)), invalid_input_error_code) + + if not len(request.data.items()): + raise SkyflowError(SkyflowMessages.Error.UPDATE_FIELD_KEY_ERROR.value, invalid_input_error_code) + field = {key: value for key, value in request.data.items() if key != ResponseField.SKYFLOW_ID} - try: - skyflow_id = request.data.get(ResponseField.SKYFLOW_ID) - except Exception: + skyflow_id = request.data.get(ResponseField.SKYFLOW_ID) + if skyflow_id is None: log_error_log(SkyflowMessages.ErrorLogs.SKYFLOW_ID_IS_REQUIRED.value.format(RequestOperation.UPDATE), logger=logger) - - if not skyflow_id.strip(): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_SKYFLOW_ID.value.format(RequestOperation.UPDATE), logger = logger) + elif not skyflow_id.strip(): + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_SKYFLOW_ID.value.format(RequestOperation.UPDATE), logger=logger) if not isinstance(request.table, str): log_error_log(SkyflowMessages.ErrorLogs.TABLE_IS_REQUIRED.value.format(RequestOperation.UPDATE), logger=logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TABLE_VALUE.value, invalid_input_error_code) if not request.table.strip(): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TABLE_NAME.value.format(RequestOperation.UPDATE), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TABLE_NAME.value.format(RequestOperation.UPDATE), logger=logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_TABLE_VALUE.value, invalid_input_error_code) if not isinstance(request.return_tokens, bool): raise SkyflowError(SkyflowMessages.Error.INVALID_RETURN_TOKENS_TYPE.value, invalid_input_error_code) - if not isinstance(request.data, dict): - raise SkyflowError(SkyflowMessages.Error.INVALID_FIELDS_TYPE.value(type(request.data)), invalid_input_error_code) - - if not len(request.data.items()): - raise SkyflowError(SkyflowMessages.Error.UPDATE_FIELD_KEY_ERROR.value, invalid_input_error_code) - if request.token_mode is not None: if not isinstance(request.token_mode, TokenMode): raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_MODE_TYPE.value, invalid_input_error_code) @@ -667,9 +668,9 @@ def validate_detokenize_request(logger, request): raise SkyflowError(SkyflowMessages.Error.INVALID_CONTINUE_ON_ERROR_TYPE.value, invalid_input_error_code) if not isinstance(request.data, list): - raise SkyflowError(SkyflowMessages.Error.INVALID_TOKENS_LIST_VALUE.value(type(request.data)), invalid_input_error_code) + raise SkyflowError(SkyflowMessages.Error.INVALID_TOKENS_LIST_VALUE.value.format(type(request.data)), invalid_input_error_code) - if not len(request.data): + if not request.data: log_error_log(SkyflowMessages.ErrorLogs.TOKENS_REQUIRED.value.format(RequestOperation.DETOKENIZE), logger = logger) log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TOKENS.value.format(RequestOperation.DETOKENIZE), logger = logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_TOKENS_LIST_VALUE.value, invalid_input_error_code) @@ -695,7 +696,7 @@ def validate_tokenize_request(logger, request): if not isinstance(parameters, list): raise SkyflowError(SkyflowMessages.Error.INVALID_TOKENIZE_PARAMETERS.value.format(type(parameters)), invalid_input_error_code) - if not len(parameters): + if not parameters: raise SkyflowError(SkyflowMessages.Error.EMPTY_TOKENIZE_PARAMETERS.value, invalid_input_error_code) for i, param in enumerate(parameters): @@ -728,9 +729,7 @@ def validate_file_upload_request(logger, request): # Skyflow ID skyflow_id = getattr(request, FileUploadField.SKYFLOW_ID, None) - if skyflow_id is None: - raise SkyflowError(SkyflowMessages.Error.IDS_KEY_ERROR.value, invalid_input_error_code) - elif skyflow_id.strip() == "": + if skyflow_id is not None and skyflow_id.strip() == "": raise SkyflowError(SkyflowMessages.Error.EMPTY_SKYFLOW_ID.value.format(RequestOperation.FILE_UPLOAD), invalid_input_error_code) # Column Name @@ -797,7 +796,7 @@ def validate_invoke_connection_params(logger, query_params, path_params): except TypeError: raise SkyflowError(SkyflowMessages.Error.INVALID_QUERY_PARAMS.value, invalid_input_error_code) -def validate_deidentify_text_request(self, request: DeidentifyTextRequest): +def validate_deidentify_text_request(logger, request: DeidentifyTextRequest): if not request.text or not isinstance(request.text, str) or not request.text.strip(): raise SkyflowError(SkyflowMessages.Error.INVALID_TEXT_IN_DEIDENTIFY.value, invalid_input_error_code) @@ -821,7 +820,7 @@ def validate_deidentify_text_request(self, request: DeidentifyTextRequest): if request.transformations is not None and not isinstance(request.transformations, Transformations): raise SkyflowError(SkyflowMessages.Error.INVALID_TRANSFORMATIONS.value, invalid_input_error_code) -def validate_reidentify_text_request(self, request: ReidentifyTextRequest): +def validate_reidentify_text_request(logger, request: ReidentifyTextRequest): if not request.text or not isinstance(request.text, str) or not request.text.strip(): raise SkyflowError(SkyflowMessages.Error.INVALID_TEXT_IN_REIDENTIFY.value, invalid_input_error_code) @@ -837,6 +836,6 @@ def validate_reidentify_text_request(self, request: ReidentifyTextRequest): if request.plain_text_entities is not None and not isinstance(request.plain_text_entities, list): raise SkyflowError(SkyflowMessages.Error.INVALID_PLAIN_TEXT_ENTITIES_IN_REIDENTIFY.value, invalid_input_error_code) -def validate_get_detect_run_request(self, request: GetDetectRunRequest): +def validate_get_detect_run_request(logger, request: GetDetectRunRequest): if not request.run_id or not isinstance(request.run_id, str) or not request.run_id.strip(): raise SkyflowError(SkyflowMessages.Error.INVALID_RUN_ID.value, invalid_input_error_code) diff --git a/skyflow/vault/client/client.py b/skyflow/vault/client/client.py index c64e8c6a..8023646c 100644 --- a/skyflow/vault/client/client.py +++ b/skyflow/vault/client/client.py @@ -16,6 +16,9 @@ def __init__(self, config): self.__logger = None self.__is_config_updated = False self.__bearer_token = None + self.__credentials = None + self.__vault_url = None + self.__is_static_token = None def set_common_skyflow_credentials(self, credentials): self.__common_skyflow_credentials = credentials @@ -25,16 +28,27 @@ def set_logger(self, log_level, logger): self.__logger = logger def initialize_client_configuration(self): - credentials = get_credentials(self.__config.get(ConfigField.CREDENTIALS), self.__common_skyflow_credentials, logger = self.__logger) - token = self.get_bearer_token(credentials) - vault_url = get_vault_url(self.__config.get(ConfigField.CLUSTER_ID), - self.__config.get(ConfigField.ENV), - self.__config.get(ConfigField.VAULT_ID), - logger = self.__logger) - self.initialize_api_client(vault_url, token) - - def initialize_api_client(self, vault_url, token): - self.__api_client = Skyflow(base_url=vault_url, token=token) + if self.__api_client is not None and not self.__is_config_updated: + if self.__is_static_token: + return + if self.__bearer_token is not None and not is_expired(self.__bearer_token): + return + + needs_reinit = self.__api_client is None or self.__is_config_updated + if needs_reinit: + self.__credentials = get_credentials(self.__config.get(ConfigField.CREDENTIALS), self.__common_skyflow_credentials, logger=self.__logger) + self.__vault_url = get_vault_url(self.__config.get(ConfigField.CLUSTER_ID), + self.__config.get(ConfigField.ENV), + self.__config.get(ConfigField.VAULT_ID), + logger=self.__logger) + self.__is_static_token = CredentialField.TOKEN in self.__credentials or CredentialField.API_KEY in self.__credentials + bearer_token = self.get_bearer_token(self.__credentials) + if needs_reinit: + self.initialize_api_client(self.__vault_url, bearer_token) + + def initialize_api_client(self, vault_url, bearer_token): + token_provider = lambda: self.__bearer_token if self.__bearer_token is not None else bearer_token # noqa: E731 + self.__api_client = Skyflow(base_url=vault_url, token=token_provider) def get_records_api(self): return self.__api_client.records @@ -64,14 +78,13 @@ def get_bearer_token(self, credentials): OptionField.ROLE_IDS: self.__config.get(OptionField.ROLES), OptionField.CTX: self.__config.get(OptionField.CTX) } - if "token_uri" in credentials and credentials.get("token_uri"): - options["token_uri"] = credentials.get("token_uri") + if CredentialField.TOKEN_URI_OPTION in credentials and credentials.get(CredentialField.TOKEN_URI_OPTION): + options[CredentialField.TOKEN_URI_OPTION] = credentials.get(CredentialField.TOKEN_URI_OPTION) - if self.__bearer_token is None or self.__is_config_updated: + if self.__bearer_token is None or self.__is_config_updated or is_expired(self.__bearer_token): if CredentialField.PATH in credentials: - path = credentials.get(CredentialField.PATH) self.__bearer_token, _ = generate_bearer_token( - path, + credentials.get(CredentialField.PATH), options, self.__logger ) @@ -87,10 +100,6 @@ def get_bearer_token(self, credentials): else: log_info(SkyflowMessages.Info.REUSE_BEARER_TOKEN.value, self.__logger) - if is_expired(self.__bearer_token): - self.__is_config_updated = True - raise SkyflowError(SkyflowMessages.Error.EXPIRED_TOKEN.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) - return self.__bearer_token def update_config(self, config): diff --git a/skyflow/vault/controller/_connections.py b/skyflow/vault/controller/_connections.py index 76dbfaeb..2ce0c104 100644 --- a/skyflow/vault/controller/_connections.py +++ b/skyflow/vault/controller/_connections.py @@ -5,7 +5,7 @@ parse_invoke_connection_response from skyflow.utils.logger import log_info, log_error_log from skyflow.vault.connection import InvokeConnectionRequest -from skyflow.utils.constants import SKY_META_DATA_HEADER, SKYFLOW, HttpHeader +from skyflow.utils.constants import SKY_META_DATA_HEADER, SKYFLOW, HttpHeader, OptionField, ConfigField from skyflow.utils import get_credentials @@ -16,11 +16,11 @@ def __init__(self, vault_client): def invoke(self, request: InvokeConnectionRequest): log_info(SkyflowMessages.Info.VALIDATING_INVOKE_CONNECTION_REQUEST.value, self.__vault_client.get_logger()) config = self.__vault_client.get_config() - connection_url = config.get("connection_url") + connection_url = config.get(OptionField.CONNECTION_URL) invoke_connection_request = construct_invoke_connection_request(request, connection_url, self.__vault_client.get_logger()) log_info(SkyflowMessages.Info.INVOKE_CONNECTION_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) - credentials = get_credentials(config.get("credentials"), self.__vault_client.get_common_skyflow_credentials(), self.__vault_client.get_logger()) + credentials = get_credentials(config.get(ConfigField.CREDENTIALS), self.__vault_client.get_common_skyflow_credentials(), self.__vault_client.get_logger()) bearer_token = self.__vault_client.get_bearer_token(credentials) diff --git a/skyflow/vault/controller/_detect.py b/skyflow/vault/controller/_detect.py index c6ef2fb1..f12b6215 100644 --- a/skyflow/vault/controller/_detect.py +++ b/skyflow/vault/controller/_detect.py @@ -8,8 +8,8 @@ FileDataDeidentifyImage, Format, FileDataDeidentifyAudio, WordCharacterCount, DetectRunsResponse from skyflow.utils._skyflow_messages import SkyflowMessages from skyflow.utils._utils import get_attribute, get_metrics, handle_exception, parse_deidentify_text_response, parse_reidentify_text_response -from skyflow.utils.constants import (SKY_META_DATA_HEADER, DetectStatus, FileExtension, - FileProcessing, EncodingType, DeidentifyField, DeidentifyFileRequestField, FileUploadField, OptionField) +from skyflow.utils.constants import (SKY_META_DATA_HEADER, DetectStatus, FileExtension, + FileProcessing, EncodingType, DeidentifyField, DeidentifyFileRequestField, FileUploadField, OptionField, Detect as DetectConstants) from skyflow.utils.logger import log_info, log_error_log from skyflow.utils.validations import validate_deidentify_file_request, validate_get_detect_run_request from skyflow.utils.validations._validations import validate_deidentify_text_request, validate_reidentify_text_request @@ -30,7 +30,7 @@ def __get_headers(self): } return headers - def ___build_deidentify_text_body(self, request: DeidentifyTextRequest) -> Dict[str, Any]: + def __build_deidentify_text_body(self, request: DeidentifyTextRequest) -> Dict[str, Any]: deidentify_text_body = {} parsed_entity_types = request.entities @@ -43,7 +43,7 @@ def ___build_deidentify_text_body(self, request: DeidentifyTextRequest) -> Dict[ return deidentify_text_body - def ___build_reidentify_text_body(self, request: ReidentifyTextRequest) -> Dict[str, Any]: + def __build_reidentify_text_body(self, request: ReidentifyTextRequest) -> Dict[str, Any]: parsed_format = Format( redacted=request.redacted_entities, masked=request.masked_entities, @@ -57,13 +57,13 @@ def ___build_reidentify_text_body(self, request: ReidentifyTextRequest) -> Dict[ def _get_file_extension(self, filename: str): return filename.split('.')[-1].lower() if '.' in filename else '' - def __poll_for_processed_file(self, run_id, max_wait_time=64): - max_wait_time = 64 if max_wait_time is None else max_wait_time + def __poll_for_processed_file(self, run_id, max_wait_time=None): + max_wait_time = DetectConstants.WAIT_TIME if max_wait_time is None else max_wait_time files_api = self.__vault_client.get_detect_file_api().with_raw_response current_wait_time = 1 # Start with 1 second try: while True: - response = files_api.get_run(run_id, vault_id=self.__vault_client.get_vault_id(), request_options=self.__get_headers()).data + response = files_api.get_run(run_id, vault_id=self.__vault_client.get_vault_id(), request_options={'additional_headers': self.__get_headers()}).data status = response.status if status == DetectStatus.IN_PROGRESS: if current_wait_time >= max_wait_time: @@ -80,7 +80,7 @@ def __poll_for_processed_file(self, run_id, max_wait_time=64): elif status == DetectStatus.SUCCESS or status == DetectStatus.FAILED: return response except Exception as e: - raise e + handle_exception(e, self.__vault_client.get_logger()) def __save_deidentify_file_response_output(self, response: DetectRunsResponse, output_directory: str, original_file_name: str, name_without_ext: str): if not response or not hasattr(response, DeidentifyField.OUTPUT) or not response.output or not output_directory: @@ -94,6 +94,7 @@ def __save_deidentify_file_response_output(self, response: DetectRunsResponse, o base_original_filename = os.path.basename(original_file_name) base_name_without_ext = os.path.splitext(base_original_filename)[0] + real_output_dir = os.path.realpath(output_directory) for idx, output in enumerate(output_list): try: @@ -105,14 +106,25 @@ def __save_deidentify_file_response_output(self, response: DetectRunsResponse, o continue decoded_data = base64.b64decode(processed_file) - + + # Sanitize extension from API response to prevent path traversal (CWE-22). + # Avoid os.path.basename here to keep basename mock-free in tests. + safe_ext = None + if processed_file_extension: + raw_ext = str(processed_file_extension).replace('\\', '/').split('/')[-1].lstrip('.') + safe_ext = ''.join(c for c in raw_ext if c.isalnum() or c in ('-', '_')) or 'bin' + if idx == 0 or processed_file_type == DeidentifyField.REDACTED_FILE: output_file_name = os.path.join(output_directory, deidentify_file_prefix + base_original_filename) - if processed_file_extension: - output_file_name = os.path.join(output_directory, f"{deidentify_file_prefix}{base_name_without_ext}.{processed_file_extension}") + if safe_ext: + output_file_name = os.path.join(output_directory, f"{deidentify_file_prefix}{base_name_without_ext}.{safe_ext}") else: - output_file_name = os.path.join(output_directory, f"{deidentify_file_prefix}{base_name_without_ext}.{processed_file_extension}") - + output_file_name = os.path.join(output_directory, f"{deidentify_file_prefix}{base_name_without_ext}.{safe_ext or 'bin'}") + + if not os.path.realpath(output_file_name).startswith(real_output_dir + os.sep): + log_error_log(SkyflowMessages.ErrorLogs.SAVING_DEIDENTIFY_FILE_FAILED.value, self.__vault_client.get_logger()) + continue + with open(output_file_name, 'wb') as f: f.write(decoded_data) except Exception as e: @@ -166,16 +178,16 @@ def output_to_dict_list(output): extension = first_output.get(DeidentifyField.EXTENSION, None) if base64_string is not None: - file_bytes = base64.b64decode(base64_string) - file_obj = io.BytesIO(file_bytes) - file_obj.name = f"{FileProcessing.DEIDENTIFIED_PREFIX}{extension}" if extension else DeidentifyField.PROCESSED_FILE + file_bytes = base64.b64decode(base64_string) + file_obj = io.BytesIO(file_bytes) + file_obj.name = f"{FileProcessing.DEIDENTIFIED_PREFIX}{extension}" if extension else DeidentifyField.PROCESSED_FILE else: file_obj = None return DeidentifyFileResponse( file_base64=base64_string, file=file_obj, - type=first_output.get(DeidentifyField.TYPE, DetectStatus.UNKNOWN), + type=first_output.get(DeidentifyField.TYPE, None), extension=extension, word_count=word_count, char_count=char_count, @@ -195,6 +207,7 @@ def __get_token_format(self, request): DeidentifyField.DEFAULT: getattr(request.token_format, DeidentifyField.DEFAULT, None), DeidentifyField.ENTITY_UNQ_COUNTER: getattr(request.token_format, DeidentifyField.ENTITY_UNIQUE_COUNTER, None), DeidentifyField.ENTITY_ONLY: getattr(request.token_format, DeidentifyField.ENTITY_ONLY, None), + DeidentifyField.VAULT_TOKEN: getattr(request.token_format, DeidentifyField.VAULT_TOKEN, None) } def __get_transformations(self, request): @@ -217,7 +230,7 @@ def deidentify_text(self, request: DeidentifyTextRequest) -> DeidentifyTextRespo log_info(SkyflowMessages.Info.DEIDENTIFY_TEXT_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) self.__initialize() detect_api = self.__vault_client.get_detect_text_api() - deidentify_text_body = self.___build_deidentify_text_body(request) + deidentify_text_body = self.__build_deidentify_text_body(request) try: log_info(SkyflowMessages.Info.DEIDENTIFY_TEXT_TRIGGERED.value, self.__vault_client.get_logger()) @@ -229,7 +242,7 @@ def deidentify_text(self, request: DeidentifyTextRequest) -> DeidentifyTextRespo restrict_regex=deidentify_text_body[DeidentifyField.RESTRICT_REGEX], token_type=deidentify_text_body[DeidentifyField.TOKEN_TYPE], transformations=deidentify_text_body[DeidentifyField.TRANSFORMATIONS], - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) deidentify_text_response = parse_deidentify_text_response(api_response) log_info(SkyflowMessages.Info.DEIDENTIFY_TEXT_SUCCESS.value, self.__vault_client.get_logger()) @@ -245,7 +258,7 @@ def reidentify_text(self, request: ReidentifyTextRequest) -> ReidentifyTextRespo log_info(SkyflowMessages.Info.REIDENTIFY_TEXT_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) self.__initialize() detect_api = self.__vault_client.get_detect_text_api() - reidentify_text_body = self.___build_reidentify_text_body(request) + reidentify_text_body = self.__build_reidentify_text_body(request) try: log_info(SkyflowMessages.Info.REIDENTIFY_TEXT_TRIGGERED.value, self.__vault_client.get_logger()) @@ -253,7 +266,7 @@ def reidentify_text(self, request: ReidentifyTextRequest) -> ReidentifyTextRespo vault_id=self.__vault_client.get_vault_id(), text=reidentify_text_body[DeidentifyField.TEXT], format=reidentify_text_body[DeidentifyField.FORMAT], - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) reidentify_text_response = parse_reidentify_text_response(api_response) log_info(SkyflowMessages.Info.REIDENTIFY_TEXT_SUCCESS.value, self.__vault_client.get_logger()) @@ -265,14 +278,16 @@ def reidentify_text(self, request: ReidentifyTextRequest) -> ReidentifyTextRespo def __get_file_from_request(self, request: DeidentifyFileRequest): file_input = request.file - - # Check for file + if hasattr(file_input, FileUploadField.FILE) and file_input.file is not None: return file_input.file - - # Check for file_path if file is not provided + if hasattr(file_input, FileUploadField.FILE_PATH) and file_input.file_path is not None: - return open(file_input.file_path, 'rb') + with open(file_input.file_path, 'rb') as f: + content = f.read() + bio = io.BytesIO(content) + bio.name = file_input.file_path + return bio def deidentify_file(self, request: DeidentifyFileRequest): log_info(SkyflowMessages.Info.DETECT_FILE_TRIGGERED.value, self.__vault_client.get_logger()) @@ -297,12 +312,13 @@ def deidentify_file(self, request: DeidentifyFileRequest): DeidentifyField.ALLOW_REGEX: request.allow_regex_list, DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, DeidentifyField.TRANSFORMATIONS: self.__get_transformations(request), - DeidentifyField.REQUEST_OPTIONS: self.__get_headers() + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } elif file_extension in [FileExtension.MP3, FileExtension.WAV]: req_file = FileDataDeidentifyAudio(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_audio + bleep = request.bleep api_kwargs = { OptionField.VAULT_ID: self.__vault_client.get_vault_id(), DeidentifyField.FILE: req_file, @@ -313,11 +329,11 @@ def deidentify_file(self, request: DeidentifyFileRequest): DeidentifyField.TRANSFORMATIONS: self.__get_transformations(request), DeidentifyFileRequestField.OUTPUT_TRANSCRIPTION: getattr(request, DeidentifyFileRequestField.OUTPUT_TRANSCRIPTION, None), DeidentifyFileRequestField.OUTPUT_PROCESSED_AUDIO: getattr(request, DeidentifyFileRequestField.OUTPUT_PROCESSED_AUDIO, None), - DeidentifyField.BLEEP_GAIN: getattr(request, DeidentifyFileRequestField.BLEEP, None).gain if getattr(request, DeidentifyFileRequestField.BLEEP, None) is not None else None, - DeidentifyField.BLEEP_FREQUENCY: getattr(request, DeidentifyFileRequestField.BLEEP, None).frequency if getattr(request, DeidentifyFileRequestField.BLEEP, None) is not None else None, - DeidentifyField.BLEEP_START_PADDING: getattr(request, DeidentifyFileRequestField.BLEEP, None).start_padding if getattr(request, DeidentifyFileRequestField.BLEEP, None) is not None else None, - DeidentifyField.BLEEP_STOP_PADDING: getattr(request, DeidentifyFileRequestField.BLEEP, None).stop_padding if getattr(request, DeidentifyFileRequestField.BLEEP, None) is not None else None, - DeidentifyField.REQUEST_OPTIONS: self.__get_headers() + DeidentifyField.BLEEP_GAIN: bleep.gain if bleep is not None else None, + DeidentifyField.BLEEP_FREQUENCY: bleep.frequency if bleep is not None else None, + DeidentifyField.BLEEP_START_PADDING: bleep.start_padding if bleep is not None else None, + DeidentifyField.BLEEP_STOP_PADDING: bleep.stop_padding if bleep is not None else None, + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } elif file_extension == FileExtension.PDF: @@ -331,8 +347,8 @@ def deidentify_file(self, request: DeidentifyFileRequest): DeidentifyField.ALLOW_REGEX: request.allow_regex_list, DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, DeidentifyFileRequestField.MAX_RESOLUTION: getattr(request, DeidentifyFileRequestField.MAX_RESOLUTION, None), - DeidentifyFileRequestField.PIXEL_DENSITY: getattr(request, DeidentifyFileRequestField.PIXEL_DENSITY, None), - DeidentifyField.REQUEST_OPTIONS: self.__get_headers() + DeidentifyFileRequestField.DENSITY: getattr(request, DeidentifyFileRequestField.PIXEL_DENSITY, None), + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } elif file_extension in [FileExtension.JPEG, FileExtension.JPG, FileExtension.PNG, FileExtension.BMP, FileExtension.TIF, FileExtension.TIFF]: @@ -348,7 +364,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): DeidentifyFileRequestField.MASKING_METHOD: getattr(request, DeidentifyFileRequestField.MASKING_METHOD, None), DeidentifyFileRequestField.OUTPUT_OCR_TEXT: getattr(request, DeidentifyFileRequestField.OUTPUT_OCR_TEXT, None), DeidentifyFileRequestField.OUTPUT_PROCESSED_IMAGE: getattr(request, DeidentifyFileRequestField.OUTPUT_PROCESSED_IMAGE, None), - DeidentifyField.REQUEST_OPTIONS: self.__get_headers() + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } elif file_extension in [FileExtension.PPT, FileExtension.PPTX]: @@ -361,7 +377,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), DeidentifyField.ALLOW_REGEX: request.allow_regex_list, DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, - DeidentifyField.REQUEST_OPTIONS: self.__get_headers() + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } elif file_extension in [FileExtension.CSV, FileExtension.XLS, FileExtension.XLSX]: @@ -374,7 +390,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), DeidentifyField.ALLOW_REGEX: request.allow_regex_list, DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, - DeidentifyField.REQUEST_OPTIONS: self.__get_headers() + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } elif file_extension in [FileExtension.DOC, FileExtension.DOCX]: @@ -387,7 +403,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), DeidentifyField.ALLOW_REGEX: request.allow_regex_list, DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, - DeidentifyField.REQUEST_OPTIONS: self.__get_headers() + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } elif file_extension in [FileExtension.JSON, FileExtension.XML]: @@ -401,7 +417,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): DeidentifyField.ALLOW_REGEX: request.allow_regex_list, DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, DeidentifyField.TRANSFORMATIONS: self.__get_transformations(request), - DeidentifyField.REQUEST_OPTIONS: self.__get_headers() + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } else: @@ -415,7 +431,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): DeidentifyField.ALLOW_REGEX: request.allow_regex_list, DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, DeidentifyField.TRANSFORMATIONS: self.__get_transformations(request), - DeidentifyField.REQUEST_OPTIONS: self.__get_headers() + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } log_info(SkyflowMessages.Info.DETECT_FILE_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) @@ -424,7 +440,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): run_id = getattr(api_response.data, DeidentifyField.RUN_ID, None) processed_response = self.__poll_for_processed_file(run_id, request.wait_time) - if request.output_directory and processed_response.status == DetectStatus.SUCCESS: + if request.output_directory and processed_response.status == DetectStatus.SUCCESS and file_name: name_without_ext, _ = os.path.splitext(file_name) self.__save_deidentify_file_response_output(processed_response, request.output_directory, file_name, name_without_ext) @@ -449,10 +465,10 @@ def get_detect_run(self, request: GetDetectRunRequest): response = files_api.get_run( run_id, vault_id=self.__vault_client.get_vault_id(), - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) if response.data.status == DetectStatus.IN_PROGRESS: - parsed_response = self.__parse_deidentify_file_response(DeidentifyFileResponse(run_id=run_id, status=DetectStatus.IN_PROGRESS)) + parsed_response = DeidentifyFileResponse(run_id=run_id, status=DetectStatus.IN_PROGRESS) else: parsed_response = self.__parse_deidentify_file_response(response.data, run_id, response.data.status) log_info(SkyflowMessages.Info.GET_DETECT_RUN_SUCCESS.value,self.__vault_client.get_logger()) diff --git a/skyflow/vault/controller/_vault.py b/skyflow/vault/controller/_vault.py index 856a1961..7d51ee83 100644 --- a/skyflow/vault/controller/_vault.py +++ b/skyflow/vault/controller/_vault.py @@ -89,10 +89,7 @@ def __get_file_for_file_upload(self, request: FileUploadRequest) -> Optional[Fil return None def __get_headers(self): - headers = { - SKY_META_DATA_HEADER: json.dumps(get_metrics()) - } - return headers + return {SKY_META_DATA_HEADER: json.dumps(get_metrics())} def insert(self, request: InsertRequest): log_info(SkyflowMessages.Info.VALIDATE_INSERT_REQUEST.value, self.__vault_client.get_logger()) @@ -106,11 +103,11 @@ def insert(self, request: InsertRequest): log_info(SkyflowMessages.Info.INSERT_TRIGGERED.value, self.__vault_client.get_logger()) if request.continue_on_error: api_response = records_api.record_service_batch_operation(self.__vault_client.get_vault_id(), - records=insert_body, continue_on_error=request.continue_on_error, byot=request.token_mode.value, request_options=self.__get_headers()) + records=insert_body, continue_on_error=request.continue_on_error, byot=request.token_mode.value, request_options={'additional_headers': self.__get_headers()}) else: api_response = records_api.record_service_insert_record(self.__vault_client.get_vault_id(), - request.table, records=insert_body,tokenization= request.return_tokens, upsert=request.upsert, homogeneous=request.homogeneous, byot=request.token_mode.value, request_options=self.__get_headers()) + request.table, records=insert_body,tokenization= request.return_tokens, upsert=request.upsert, homogeneous=request.homogeneous, byot=request.token_mode.value, request_options={'additional_headers': self.__get_headers()}) insert_response = parse_insert_response(api_response, request.continue_on_error) log_info(SkyflowMessages.Info.INSERT_SUCCESS.value, self.__vault_client.get_logger()) @@ -138,7 +135,7 @@ def update(self, request: UpdateRequest): record=record, tokenization=request.return_tokens, byot=request.token_mode.value, - request_options = self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) log_info(SkyflowMessages.Info.UPDATE_SUCCESS.value, self.__vault_client.get_logger()) update_response = parse_update_record_response(api_response) @@ -159,7 +156,7 @@ def delete(self, request: DeleteRequest): self.__vault_client.get_vault_id(), request.table, skyflow_ids=request.ids, - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) log_info(SkyflowMessages.Info.DELETE_SUCCESS.value, self.__vault_client.get_logger()) delete_response = parse_delete_response(api_response) @@ -189,7 +186,7 @@ def get(self, request: GetRequest): download_url=request.download_url, column_name=request.column_name, column_values=request.column_values, - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) log_info(SkyflowMessages.Info.GET_SUCCESS.value, self.__vault_client.get_logger()) get_response = parse_get_response(api_response) @@ -209,7 +206,7 @@ def query(self, request: QueryRequest): api_response = query_api.query_service_execute_query( self.__vault_client.get_vault_id(), query=request.query, - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) log_info(SkyflowMessages.Info.QUERY_SUCCESS.value, self.__vault_client.get_logger()) query_response = parse_query_response(api_response) @@ -237,7 +234,7 @@ def detokenize(self, request: DetokenizeRequest): self.__vault_client.get_vault_id(), detokenization_parameters=tokens_list, continue_on_error = request.continue_on_error, - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) log_info(SkyflowMessages.Info.DETOKENIZE_SUCCESS.value, self.__vault_client.get_logger()) detokenize_response = parse_detokenize_response(api_response) @@ -262,7 +259,7 @@ def tokenize(self, request: TokenizeRequest): api_response = tokens_api.record_service_tokenize( self.__vault_client.get_vault_id(), tokenization_parameters=records_list, - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) tokenize_response = parse_tokenize_response(api_response) log_info(SkyflowMessages.Info.TOKENIZE_SUCCESS.value, self.__vault_client.get_logger()) @@ -285,7 +282,7 @@ def upload_file(self, request: FileUploadRequest): file=self.__get_file_for_file_upload(request), skyflow_id=request.skyflow_id, return_file_metadata= False, - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) log_info(SkyflowMessages.Info.FILE_UPLOAD_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) log_info(SkyflowMessages.Info.FILE_UPLOAD_SUCCESS.value, self.__vault_client.get_logger()) diff --git a/skyflow/vault/data/_file_upload_request.py b/skyflow/vault/data/_file_upload_request.py index d1bd4a44..2f82c3e2 100644 --- a/skyflow/vault/data/_file_upload_request.py +++ b/skyflow/vault/data/_file_upload_request.py @@ -1,10 +1,10 @@ -from typing import BinaryIO +from typing import BinaryIO, Optional class FileUploadRequest: def __init__(self, table: str, - skyflow_id: str, column_name: str, + skyflow_id: Optional[str] = None, file_path: str= None, base64: str= None, file_object: BinaryIO= None, diff --git a/skyflow/vault/data/_get_response.py b/skyflow/vault/data/_get_response.py index cf1b0805..a1640254 100644 --- a/skyflow/vault/data/_get_response.py +++ b/skyflow/vault/data/_get_response.py @@ -1,6 +1,6 @@ class GetResponse: def __init__(self, data=None, errors = None): - self.data = data if data else [] + self.data = data if data is not None else [] self.errors = errors def __repr__(self): diff --git a/skyflow/vault/detect/_deidentify_file_response.py b/skyflow/vault/detect/_deidentify_file_response.py index b340e21c..97b8df40 100644 --- a/skyflow/vault/detect/_deidentify_file_response.py +++ b/skyflow/vault/detect/_deidentify_file_response.py @@ -1,4 +1,5 @@ import io +from typing import Optional from skyflow.vault.detect._file import File class DeidentifyFileResponse: @@ -17,6 +18,7 @@ def __init__( entities: list = None, # list of dicts with keys 'file' and 'extension' run_id: str = None, status: str = None, + errors: Optional[list] = None, ): self.file_base64 = file_base64 self.file = File(file) if file else None @@ -31,6 +33,7 @@ def __init__( self.entities = entities if entities is not None else [] self.run_id = run_id self.status = status + self.errors = errors def __repr__(self): return ( @@ -40,7 +43,7 @@ def __repr__(self): f"char_count={self.char_count!r}, size_in_kb={self.size_in_kb!r}, " f"duration_in_seconds={self.duration_in_seconds!r}, page_count={self.page_count!r}, " f"slide_count={self.slide_count!r}, entities={self.entities!r}, " - f"run_id={self.run_id!r}, status={self.status!r})" + f"run_id={self.run_id!r}, status={self.status!r}, errors={self.errors!r})" ) def __str__(self): diff --git a/skyflow/vault/detect/_deidentify_text_response.py b/skyflow/vault/detect/_deidentify_text_response.py index cdb6632e..227b43bc 100644 --- a/skyflow/vault/detect/_deidentify_text_response.py +++ b/skyflow/vault/detect/_deidentify_text_response.py @@ -1,19 +1,21 @@ -from typing import List +from typing import List, Optional from ._entity_info import EntityInfo class DeidentifyTextResponse: - def __init__(self, + def __init__(self, processed_text: str, - entities: List[EntityInfo], + entities: List[EntityInfo], word_count: int, - char_count: int): + char_count: int, + errors: Optional[list] = None): self.processed_text = processed_text self.entities = entities self.word_count = word_count self.char_count = char_count + self.errors = errors def __repr__(self): - return f"DeidentifyTextResponse(processed_text='{self.processed_text}', entities={self.entities}, word_count={self.word_count}, char_count={self.char_count})" + return f"DeidentifyTextResponse(processed_text='{self.processed_text}', entities={self.entities}, word_count={self.word_count}, char_count={self.char_count}, errors={self.errors})" def __str__(self): return self.__repr__() \ No newline at end of file diff --git a/skyflow/vault/detect/_reidentify_text_response.py b/skyflow/vault/detect/_reidentify_text_response.py index 50c3876d..73ad3f5d 100644 --- a/skyflow/vault/detect/_reidentify_text_response.py +++ b/skyflow/vault/detect/_reidentify_text_response.py @@ -1,9 +1,12 @@ +from typing import Optional + class ReidentifyTextResponse: - def __init__(self, processed_text: str): + def __init__(self, processed_text: str, errors: Optional[list] = None): self.processed_text = processed_text + self.errors = errors def __repr__(self) -> str: - return f"ReidentifyTextResponse(processed_text='{self.processed_text}')" + return f"ReidentifyTextResponse(processed_text='{self.processed_text}', errors={self.errors})" def __str__(self) -> str: return self.__repr__() \ No newline at end of file diff --git a/tests/client/test_skyflow.py b/tests/client/test_skyflow.py index 3e3681bb..dcf80f1f 100644 --- a/tests/client/test_skyflow.py +++ b/tests/client/test_skyflow.py @@ -1,42 +1,41 @@ import unittest -from unittest.mock import patch +from unittest.mock import patch, Mock from skyflow import LogLevel, Env from skyflow.error import SkyflowError from skyflow.utils import SkyflowMessages from skyflow import Skyflow +from skyflow.vault.client.client import VaultClient VALID_VAULT_CONFIG = { "vault_id": "VAULT_ID", "cluster_id": "CLUSTER_ID", "env": Env.DEV, - "credentials": {"path": "/path/to/valid_credentials.json"} + "credentials": {"path": "/path/to/valid_credentials.json"}, } INVALID_VAULT_CONFIG = { "cluster_id": "CLUSTER_ID", # Missing vault_id "env": Env.DEV, - "credentials": {"path": "/path/to/valid_credentials.json"} + "credentials": {"path": "/path/to/valid_credentials.json"}, } VALID_CONNECTION_CONFIG = { "connection_id": "CONNECTION_ID", "connection_url": "https://CONNECTION_URL", - "credentials": {"path": "/path/to/valid_credentials.json"} + "credentials": {"path": "/path/to/valid_credentials.json"}, } INVALID_CONNECTION_CONFIG = { "connection_url": "https://CONNECTION_URL", # Missing connection_id - "credentials": {"path": "/path/to/valid_credentials.json"} + "credentials": {"path": "/path/to/valid_credentials.json"}, } -VALID_CREDENTIALS = { - "path": "/path/to/valid_credentials.json" -} +VALID_CREDENTIALS = {"path": "/path/to/valid_credentials.json"} -class TestSkyflow(unittest.TestCase): +class TestSkyflow(unittest.TestCase): def setUp(self): self.builder = Skyflow.builder() @@ -49,8 +48,10 @@ def test_add_already_exists_vault_config(self): builder = self.builder.add_vault_config(VALID_VAULT_CONFIG) with self.assertRaises(SkyflowError) as context: builder.add_vault_config(VALID_VAULT_CONFIG) - self.assertEqual(context.exception.message, SkyflowMessages.Error.VAULT_ID_ALREADY_EXISTS.value.format(VALID_VAULT_CONFIG.get("vault_id"))) - + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.VAULT_ID_ALREADY_EXISTS.value.format(VALID_VAULT_CONFIG.get("vault_id")), + ) def test_add_vault_config_invalid(self): with self.assertRaises(SkyflowError) as context: @@ -61,11 +62,11 @@ def test_add_vault_config_invalid(self): def test_remove_vault_config_valid(self): self.builder.add_vault_config(VALID_VAULT_CONFIG) self.builder.build() - result = self.builder.remove_vault_config(VALID_VAULT_CONFIG['vault_id']) + result = self.builder.remove_vault_config(VALID_VAULT_CONFIG["vault_id"]) - self.assertNotIn(VALID_VAULT_CONFIG['vault_id'], self.builder._Builder__vault_configs) + self.assertNotIn(VALID_VAULT_CONFIG["vault_id"], self.builder._Builder__vault_configs) - @patch('skyflow.utils.logger.log_error') + @patch("skyflow.utils.logger.log_error") def test_remove_vault_config_invalid(self, mock_log_error): self.builder.add_vault_config(VALID_VAULT_CONFIG) self.builder.build() @@ -73,8 +74,7 @@ def test_remove_vault_config_invalid(self, mock_log_error): self.builder.remove_vault_config("invalid_id") self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_VAULT_ID.value) - - @patch('skyflow.vault.client.client.VaultClient.update_config') + @patch("skyflow.vault.client.client.VaultClient.update_config") def test_update_vault_config_valid(self, mock_validate): self.builder.add_vault_config(VALID_VAULT_CONFIG) self.builder.build() @@ -94,7 +94,7 @@ def test_get_vault(self): def test_get_vault_with_vault_id_none(self): self.builder.add_vault_config(VALID_VAULT_CONFIG) self.builder.build() - vault = self.builder.get_vault_config(None) + vault = self.builder.get_vault_config(None) config = vault.get("vault_client").get_config() self.assertEqual(self.builder._Builder__vault_list[0], config) @@ -107,19 +107,23 @@ def test_get_vault_with_empty_vault_list_when_vault_id_is_none_raises_error(self def test_get_vault_with_invalid_vault_id_raises_error(self): self.builder.build() with self.assertRaises(SkyflowError) as context: - self.builder.get_vault_config('invalid_id') - self.assertEqual(context.exception.message, SkyflowMessages.Error.VAULT_ID_NOT_IN_CONFIG_LIST.value.format('invalid_id')) + self.builder.get_vault_config("invalid_id") + self.assertEqual( + context.exception.message, SkyflowMessages.Error.VAULT_ID_NOT_IN_CONFIG_LIST.value.format("invalid_id") + ) def test_get_vault_with_invalid_vault_id_and_non_empty_list_raises_error(self): self.builder.add_vault_config(VALID_VAULT_CONFIG) self.builder.build() with self.assertRaises(SkyflowError) as context: - self.builder.get_vault_config('invalid_vault_id') - - self.assertEqual(context.exception.message, SkyflowMessages.Error.VAULT_ID_NOT_IN_CONFIG_LIST.value.format("invalid_vault_id")) + self.builder.get_vault_config("invalid_vault_id") + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.VAULT_ID_NOT_IN_CONFIG_LIST.value.format("invalid_vault_id"), + ) - @patch('skyflow.client.skyflow.validate_vault_config') + @patch("skyflow.client.skyflow.validate_vault_config") def test_build_calls_validate_vault_config(self, mock_validate_vault_config): self.builder.add_vault_config(VALID_VAULT_CONFIG) self.builder.build() @@ -143,7 +147,9 @@ def test_add_already_exists_connection_config(self): with self.assertRaises(SkyflowError) as context: builder.add_connection_config(VALID_CONNECTION_CONFIG) - self.assertEqual(context.exception.message, SkyflowMessages.Error.CONNECTION_ID_ALREADY_EXISTS.value.format(connection_id)) + self.assertEqual( + context.exception.message, SkyflowMessages.Error.CONNECTION_ID_ALREADY_EXISTS.value.format(connection_id) + ) def test_add_connection_config_invalid(self): with self.assertRaises(SkyflowError) as context: @@ -158,8 +164,7 @@ def test_remove_connection_config_valid(self): self.assertNotIn(VALID_CONNECTION_CONFIG.get("connection_id"), self.builder._Builder__connection_configs) - - @patch('skyflow.utils.logger.log_error') + @patch("skyflow.utils.logger.log_error") def test_remove_connection_config_invalid(self, mock_log_error): self.builder.add_connection_config(VALID_CONNECTION_CONFIG) self.builder.build() @@ -167,7 +172,7 @@ def test_remove_connection_config_invalid(self, mock_log_error): self.builder.remove_connection_config("invalid_id") self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CONNECTION_ID.value) - @patch('skyflow.vault.client.client.VaultClient.update_config') + @patch("skyflow.vault.client.client.VaultClient.update_config") def test_update_connection_config_valid(self, mock_validate): self.builder.add_connection_config(VALID_CONNECTION_CONFIG) self.builder.build() @@ -194,16 +199,21 @@ def test_get_connection_config_with_connection_id_none(self): def test_get_connection_with_empty_connection_list_raises_error(self): self.builder.build() with self.assertRaises(SkyflowError) as context: - self.builder.get_connection_config('invalid_id') - self.assertEqual(context.exception.message, SkyflowMessages.Error.CONNECTION_ID_NOT_IN_CONFIG_LIST.value.format('invalid_id')) + self.builder.get_connection_config("invalid_id") + self.assertEqual( + context.exception.message, SkyflowMessages.Error.CONNECTION_ID_NOT_IN_CONFIG_LIST.value.format("invalid_id") + ) def test_get_connection_with_invalid_connection_id_raises_error(self): self.builder.add_connection_config(VALID_CONNECTION_CONFIG) self.builder.build() with self.assertRaises(SkyflowError) as context: - self.builder.get_connection_config('invalid_connection_id') + self.builder.get_connection_config("invalid_connection_id") - self.assertEqual(context.exception.message, SkyflowMessages.Error.CONNECTION_ID_NOT_IN_CONFIG_LIST.value.format('invalid_connection_id')) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.CONNECTION_ID_NOT_IN_CONFIG_LIST.value.format("invalid_connection_id"), + ) def test_get_connection_with_invalid_connection_id_and_empty_list_raises_Error(self): self.builder.build() @@ -212,13 +222,12 @@ def test_get_connection_with_invalid_connection_id_and_empty_list_raises_Error(s self.assertEqual(context.exception.message, SkyflowMessages.Error.EMPTY_CONNECTION_CONFIGS.value) - @patch('skyflow.client.skyflow.validate_connection_config') + @patch("skyflow.client.skyflow.validate_connection_config") def test_build_calls_validate_connection_config(self, mock_validate): self.builder.add_connection_config(VALID_CONNECTION_CONFIG) self.builder.build() mock_validate.assert_called_once_with(self.builder._Builder__logger, VALID_CONNECTION_CONFIG) - def test_build_valid(self): self.builder.add_vault_config(VALID_VAULT_CONFIG).add_connection_config(VALID_CONNECTION_CONFIG) client = self.builder.build() @@ -236,30 +245,31 @@ def test_invalid_credentials(self): self.assertEqual(VALID_CREDENTIALS, self.builder._Builder__skyflow_credentials) self.assertEqual(builder, self.builder) - @patch('skyflow.client.skyflow.validate_vault_config') + @patch("skyflow.client.skyflow.validate_vault_config") def test_skyflow_client_add_remove_vault_config(self, mock_validate_vault_config): skyflow_client = self.builder.add_vault_config(VALID_VAULT_CONFIG).build() new_config = VALID_VAULT_CONFIG.copy() - new_config['vault_id'] = "VAULT_ID" + new_config["vault_id"] = "VAULT_ID" skyflow_client.add_vault_config(new_config) assert mock_validate_vault_config.call_count == 2 - self.assertEqual("VAULT_ID", - skyflow_client.get_vault_config(new_config['vault_id']).get("vault_id")) + self.assertEqual("VAULT_ID", skyflow_client.get_vault_config(new_config["vault_id"]).get("vault_id")) - skyflow_client.remove_vault_config(new_config['vault_id']) + skyflow_client.remove_vault_config(new_config["vault_id"]) with self.assertRaises(SkyflowError) as context: - skyflow_client.get_vault_config(new_config['vault_id']).get("vault_id") + skyflow_client.get_vault_config(new_config["vault_id"]).get("vault_id") - self.assertEqual(context.exception.message, SkyflowMessages.Error.VAULT_ID_NOT_IN_CONFIG_LIST.value.format( - new_config['vault_id'])) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.VAULT_ID_NOT_IN_CONFIG_LIST.value.format(new_config["vault_id"]), + ) - @patch('skyflow.vault.client.client.VaultClient.update_config') + @patch("skyflow.vault.client.client.VaultClient.update_config") def test_skyflow_client_update_and_get_vault_config(self, mock_update_config): skyflow_client = self.builder.add_vault_config(VALID_VAULT_CONFIG).build() new_config = VALID_VAULT_CONFIG.copy() - new_config['env'] = Env.SANDBOX + new_config["env"] = Env.SANDBOX skyflow_client.update_vault_config(new_config) mock_update_config.assert_called_once() @@ -267,29 +277,33 @@ def test_skyflow_client_update_and_get_vault_config(self, mock_update_config): self.assertEqual(VALID_VAULT_CONFIG.get("vault_id"), vault.get("vault_id")) - @patch('skyflow.client.skyflow.validate_connection_config') + @patch("skyflow.client.skyflow.validate_connection_config") def test_skyflow_client_add_remove_connection_config(self, mock_validate_connection_config): skyflow_client = self.builder.add_connection_config(VALID_CONNECTION_CONFIG).build() new_config = VALID_CONNECTION_CONFIG.copy() - new_config['connection_id'] = "CONNECTION_ID" + new_config["connection_id"] = "CONNECTION_ID" skyflow_client.add_connection_config(new_config) assert mock_validate_connection_config.call_count == 2 - self.assertEqual("CONNECTION_ID", skyflow_client.get_connection_config(new_config['connection_id']).get("connection_id")) + self.assertEqual( + "CONNECTION_ID", skyflow_client.get_connection_config(new_config["connection_id"]).get("connection_id") + ) skyflow_client.remove_connection_config("CONNECTION_ID") with self.assertRaises(SkyflowError) as context: - skyflow_client.get_connection_config(new_config['connection_id']).get("connection_id") - - self.assertEqual(context.exception.message, SkyflowMessages.Error.CONNECTION_ID_NOT_IN_CONFIG_LIST.value.format(new_config['connection_id'])) + skyflow_client.get_connection_config(new_config["connection_id"]).get("connection_id") + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.CONNECTION_ID_NOT_IN_CONFIG_LIST.value.format(new_config["connection_id"]), + ) - @patch('skyflow.vault.client.client.VaultClient.update_config') + @patch("skyflow.vault.client.client.VaultClient.update_config") def test_skyflow_client_update_and_get_connection_config(self, mock_update_config): builder = self.builder skyflow_client = builder.add_connection_config(VALID_CONNECTION_CONFIG).build() new_config = VALID_CONNECTION_CONFIG.copy() - new_config['connection_url'] = 'updated_url' + new_config["connection_url"] = "updated_url" skyflow_client.update_connection_config(new_config) mock_update_config.assert_called_once() @@ -305,28 +319,56 @@ def test_skyflow_add_and_update_skyflow_credentials(self): self.assertEqual(VALID_CREDENTIALS, builder._Builder__skyflow_credentials) new_credentials = VALID_CREDENTIALS.copy() - new_credentials['path'] = 'path/to/new_credentials' + new_credentials["path"] = "path/to/new_credentials" skyflow_client.update_skyflow_credentials(new_credentials) self.assertEqual(new_credentials, builder._Builder__skyflow_credentials) - def test_skyflow_add_and_update_log_level(self): builder = self.builder - skyflow_client = builder.add_connection_config(VALID_CONNECTION_CONFIG).build() + skyflow_client = builder.add_connection_config(VALID_CONNECTION_CONFIG).build() skyflow_client.set_log_level(LogLevel.INFO) self.assertEqual(LogLevel.INFO, builder._Builder__log_level) - skyflow_client.update_log_level(LogLevel.ERROR) - self.assertEqual(LogLevel.ERROR, builder._Builder__log_level) - - - @patch('skyflow.client.Skyflow.Builder.get_vault_config') + @patch("skyflow.client.Skyflow.Builder.get_vault_config") def test_skyflow_vault_and_connection_method(self, mock_get_vault_config): builder = self.builder - skyflow_client = builder.add_connection_config(VALID_CONNECTION_CONFIG).add_vault_config(VALID_VAULT_CONFIG).build() + skyflow_client = ( + builder.add_connection_config(VALID_CONNECTION_CONFIG).add_vault_config(VALID_VAULT_CONFIG).build() + ) skyflow_client.vault() skyflow_client.connection() - mock_get_vault_config.assert_called_once() \ No newline at end of file + mock_get_vault_config.assert_called_once() + + +class TestVaultClient(unittest.TestCase): + def _make_client(self): + client = VaultClient({"vault_id": "test_vault"}) + client._VaultClient__api_client = Mock() + return client + + def test_get_detect_text_api_returns_strings(self): + client = self._make_client() + result = client.get_detect_text_api() + self.assertEqual(result, client._VaultClient__api_client.strings) + + def test_get_detect_file_api_returns_files(self): + client = self._make_client() + result = client.get_detect_file_api() + self.assertEqual(result, client._VaultClient__api_client.files) + + @patch("skyflow.vault.client.client.generate_bearer_token_from_creds") + @patch("skyflow.vault.client.client.is_expired", return_value=True) + def test_get_bearer_token_passes_token_uri_option(self, _mock_expired, mock_gen): + mock_gen.return_value = ("test_token", "bearer") + client = VaultClient({"vault_id": "test_vault"}) + credentials = { + "credentials_string": '{"clientID":"id","privateKey":"pk","keyID":"kid","tokenURI":"https://token.uri"}', + "token_uri": "https://custom-token-uri.com/token", + } + client.get_bearer_token(credentials) + options_passed = mock_gen.call_args[0][1] + self.assertIn("token_uri", options_passed) + self.assertEqual(options_passed["token_uri"], "https://custom-token-uri.com/token") diff --git a/tests/service_account/test__utils.py b/tests/service_account/test__utils.py index ca82527a..505a7261 100644 --- a/tests/service_account/test__utils.py +++ b/tests/service_account/test__utils.py @@ -5,35 +5,57 @@ from unittest.mock import patch import os from skyflow.error import SkyflowError -from skyflow.service_account import is_expired, generate_bearer_token, \ - generate_bearer_token_from_creds +from skyflow.service_account import is_expired, generate_bearer_token, generate_bearer_token_from_creds from skyflow.utils import SkyflowMessages -from skyflow.service_account._utils import get_service_account_token, get_signed_jwt, generate_signed_data_tokens, get_signed_data_token_response_object, generate_signed_data_tokens_from_creds +from skyflow.service_account._utils import ( + get_service_account_token, + get_signed_jwt, + generate_signed_data_tokens, + get_signed_data_token_response_object, + generate_signed_data_tokens_from_creds, + _validate_and_resolve_ctx, + _normalize_credentials, + get_signed_tokens, +) creds_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "credentials.json") -with open(creds_path, 'r') as file: +with open(creds_path, "r") as file: credentials = json.load(file) VALID_CREDENTIALS_STRING = json.dumps(credentials) -CREDENTIALS_WITHOUT_CLIENT_ID = { - 'privateKey': 'private_key' -} +CREDENTIALS_WITHOUT_CLIENT_ID = {"privateKey": "private_key"} -CREDENTIALS_WITHOUT_KEY_ID = { - 'privateKey': 'private_key', - 'clientID': 'client_id' -} +CREDENTIALS_WITHOUT_KEY_ID = {"privateKey": "private_key", "clientID": "client_id"} -CREDENTIALS_WITHOUT_TOKEN_URI = { - 'privateKey': 'private_key', - 'clientID': 'client_id', - 'keyID': 'key_id' -} +CREDENTIALS_WITHOUT_TOKEN_URI = {"privateKey": "private_key", "clientID": "client_id", "keyID": "key_id"} VALID_SERVICE_ACCOUNT_CREDS = credentials +# Snake-case version of the real credentials (keys remapped to snake_case) +SNAKE_CASE_CREDS = { + "private_key": credentials["privateKey"], + "client_id": credentials["clientID"], + "key_id": credentials["keyID"], + "token_uri": credentials["tokenURI"], +} + +SNAKE_CASE_CREDS_STRING = json.dumps( + { + "private_key": credentials["privateKey"], + "client_id": credentials["clientID"], + "key_id": credentials["keyID"], + "token_uri": credentials["tokenURI"], + } +) + + class TestServiceAccountUtils(unittest.TestCase): + # ── is_expired ──────────────────────────────────────────────────────────── + + def test_is_expired_none_token(self): + self.assertTrue(is_expired(None)) + def test_is_expired_empty_token(self): self.assertTrue(is_expired("")) @@ -44,7 +66,7 @@ def test_is_expired_non_expired_token(self): def test_is_expired_expired_token(self): past_time = time.time() - 1000 - token = jwt.encode({"exp": past_time}, key="test", algorithm="HS256") + token = jwt.encode({"exp": past_time}, key="test", algorithm="HS256") self.assertTrue(is_expired(token)) @patch("skyflow.utils.logger._log_helpers.log_error_log") @@ -53,6 +75,8 @@ def test_is_expired_general_exception(self, mock_jwt_decode, mock_log_error): token = jwt.encode({"exp": time.time() + 1000}, key="test", algorithm="HS256") self.assertTrue(is_expired(token)) + # ── generate_bearer_token ───────────────────────────────────────────────── + @patch("builtins.open", side_effect=FileNotFoundError) def test_generate_bearer_token_invalid_file_path(self, mock_open): with self.assertRaises(SkyflowError) as context: @@ -72,6 +96,8 @@ def test_generate_bearer_token_valid_file_path(self, mock_generate_bearer_token) generate_bearer_token(creds_path) mock_generate_bearer_token.assert_called_once() + # ── generate_bearer_token_from_creds ────────────────────────────────────── + @patch("skyflow.service_account._utils.get_service_account_token") def test_generate_bearer_token_from_creds_with_valid_json_string(self, mock_generate_bearer_token): generate_bearer_token_from_creds(VALID_CREDENTIALS_STRING) @@ -82,10 +108,11 @@ def test_generate_bearer_token_from_creds_invalid_json(self): generate_bearer_token_from_creds("invalid_json") self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS_STRING.value) + # ── get_service_account_token ───────────────────────────────────────────── + def test_get_service_account_token_missing_private_key(self): - incomplete_credentials = {} with self.assertRaises(SkyflowError) as context: - get_service_account_token(incomplete_credentials, {}, None) + get_service_account_token({}, {}, None) self.assertEqual(context.exception.message, SkyflowMessages.Error.MISSING_PRIVATE_KEY.value) def test_get_service_account_token_missing_client_id_key(self): @@ -107,43 +134,42 @@ def test_get_service_account_token_with_valid_credentials(self): access_token, _ = get_service_account_token(VALID_SERVICE_ACCOUNT_CREDS, {}, None) self.assertTrue(access_token) + def test_get_service_account_token_with_snake_case_creds(self): + access_token, _ = get_service_account_token(SNAKE_CASE_CREDS, {}, None) + self.assertTrue(access_token) - @patch("jwt.encode", side_effect=Exception) - def test_get_signed_jwt_invalid_format(self, mock_jwt_encode): + def test_get_service_account_token_missing_private_key_snake(self): + creds = { + "client_id": "id", + "key_id": "kid", + "token_uri": "https://example.com", + } with self.assertRaises(SkyflowError) as context: - get_signed_jwt({}, "client_id", "key_id", "token_uri", "private_key", None) - self.assertEqual(context.exception.message, SkyflowMessages.Error.JWT_INVALID_FORMAT.value) - - def test_get_signed_data_token_response_object(self): - token = "sample_token" - signed_token = "signed_sample_token" - response = get_signed_data_token_response_object(signed_token, token) - self.assertEqual(response[0], token) - self.assertEqual(response[1], signed_token) - - def test_generate_signed_data_tokens_from_file_path(self): - creds_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "credentials.json") - options = {"data_tokens": ["token1", "token2"], "ctx": 'ctx'} - result = generate_signed_data_tokens(creds_path, options) - self.assertEqual(len(result), 2) + get_service_account_token(creds, {}, None) + self.assertEqual(context.exception.message, SkyflowMessages.Error.MISSING_PRIVATE_KEY.value) - def test_generate_signed_data_tokens_from_invalid_file_path(self): - options = {"data_tokens": ["token1", "token2"]} + def test_get_service_account_token_invalid_token_uri(self): + creds = { + "privateKey": "key", + "clientID": "id", + "keyID": "kid", + "tokenURI": "not-a-url", + } with self.assertRaises(SkyflowError) as context: - result = generate_signed_data_tokens('credentials1.json', options) - self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value) - - def test_generate_signed_data_tokens_from_creds(self): - options = {"data_tokens": ["token1", "token2"]} - result = generate_signed_data_tokens_from_creds(VALID_CREDENTIALS_STRING, options) - self.assertEqual(len(result), 2) + get_service_account_token(creds, {}, None) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) - def test_generate_signed_data_tokens_from_creds_with_invalid_string(self): - options = {"data_tokens": ["token1", "token2"]} - credentials_string = '{' + def test_get_service_account_token_invalid_token_uri_in_options(self): + creds = { + "privateKey": "key", + "clientID": "id", + "keyID": "kid", + "tokenURI": "https://valid-url.com", + } + options = {"token_uri": "not-a-valid-url"} with self.assertRaises(SkyflowError) as context: - result = generate_signed_data_tokens_from_creds(credentials_string, options) - self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS_STRING.value) + get_service_account_token(creds, options, None) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) @patch("skyflow.service_account._utils.AuthClient") @patch("skyflow.service_account._utils.get_signed_jwt") @@ -152,13 +178,14 @@ def test_get_service_account_token_with_role_ids_formats_scope(self, mock_get_si "privateKey": "private_key", "clientID": "client_id", "keyID": "key_id", - "tokenURI": "https://valid-url.com" + "tokenURI": "https://valid-url.com", } options = {"role_ids": ["role1", "role2"]} mock_get_signed_jwt.return_value = "signed" mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value - mock_auth_api.authentication_service_get_auth_token.return_value = type("obj", (), {"access_token": "token", - "token_type": "bearer"}) + mock_auth_api.authentication_service_get_auth_token.return_value = type( + "obj", (), {"access_token": "token", "token_type": "bearer"} + ) access_token, token_type = get_service_account_token(creds, options, None) self.assertEqual(access_token, "token") self.assertEqual(token_type, "bearer") @@ -173,16 +200,18 @@ def test_get_service_account_token_unauthorized_error(self, mock_get_signed_jwt, "privateKey": "private_key", "clientID": "client_id", "keyID": "key_id", - "tokenURI": "https://valid-url.com" + "tokenURI": "https://valid-url.com", } mock_get_signed_jwt.return_value = "signed" mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value from skyflow.generated.rest.errors.unauthorized_error import UnauthorizedError + mock_auth_api.authentication_service_get_auth_token.side_effect = UnauthorizedError("unauthorized") with self.assertRaises(SkyflowError) as context: get_service_account_token(creds, {}, None) - self.assertEqual(context.exception.message, - SkyflowMessages.Error.UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN.value) + self.assertEqual( + context.exception.message, SkyflowMessages.Error.UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN.value + ) @patch("skyflow.service_account._utils.AuthClient") @patch("skyflow.service_account._utils.get_signed_jwt") @@ -191,7 +220,7 @@ def test_get_service_account_token_generic_exception(self, mock_get_signed_jwt, "privateKey": "private_key", "clientID": "client_id", "keyID": "key_id", - "tokenURI": "https://valid-url.com" + "tokenURI": "https://valid-url.com", } mock_get_signed_jwt.return_value = "signed" mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value @@ -200,16 +229,364 @@ def test_get_service_account_token_generic_exception(self, mock_get_signed_jwt, get_service_account_token(creds, {}, None) self.assertEqual(context.exception.message, SkyflowMessages.Error.FAILED_TO_GET_BEARER_TOKEN.value) + # ── get_signed_jwt ──────────────────────────────────────────────────────── + + @patch("jwt.encode", side_effect=Exception) + def test_get_signed_jwt_invalid_format(self, mock_jwt_encode): + with self.assertRaises(SkyflowError) as context: + get_signed_jwt({}, "client_id", "key_id", "token_uri", "private_key", None) + self.assertEqual(context.exception.message, SkyflowMessages.Error.JWT_INVALID_FORMAT.value) + + @patch("skyflow.service_account._utils.jwt.encode") + def test_get_signed_jwt_with_valid_string_ctx(self, mock_jwt_encode): + mock_jwt_encode.return_value = "mock_token" + get_signed_jwt({"ctx": "valid_ctx"}, "client_id", "key_id", "token_uri", "private_key", None) + payload = mock_jwt_encode.call_args.kwargs["payload"] + self.assertEqual(payload["ctx"], "valid_ctx") + + @patch("skyflow.service_account._utils.jwt.encode") + def test_get_signed_jwt_with_valid_dict_ctx(self, mock_jwt_encode): + mock_jwt_encode.return_value = "mock_token" + get_signed_jwt({"ctx": {"role": "admin"}}, "client_id", "key_id", "token_uri", "private_key", None) + payload = mock_jwt_encode.call_args.kwargs["payload"] + self.assertEqual(payload["ctx"], {"role": "admin"}) + + @patch("skyflow.service_account._utils.jwt.encode") + def test_get_signed_jwt_with_empty_string_ctx_not_added(self, mock_jwt_encode): + mock_jwt_encode.return_value = "mock_token" + get_signed_jwt({"ctx": ""}, "client_id", "key_id", "token_uri", "private_key", None) + payload = mock_jwt_encode.call_args.kwargs["payload"] + self.assertNotIn("ctx", payload) + + # ── get_signed_data_token_response_object ───────────────────────────────── + + def test_get_signed_data_token_response_object(self): + token = "sample_token" + signed_token = "signed_sample_token" + response = get_signed_data_token_response_object(signed_token, token) + self.assertIsInstance(response, tuple) + self.assertEqual(response[0], token) + self.assertEqual(response[1], signed_token) + + # ── get_signed_tokens ───────────────────────────────────────────────────── + @patch("jwt.encode", side_effect=Exception("jwt error")) def test_get_signed_tokens_jwt_encode_exception(self, mock_jwt_encode): creds = { "privateKey": "private_key", "clientID": "client_id", "keyID": "key_id", - "tokenURI": "https://valid-url.com" + "tokenURI": "https://valid-url.com", } options = {"data_tokens": ["token1"]} with self.assertRaises(SkyflowError) as context: - from skyflow.service_account._utils import get_signed_tokens get_signed_tokens(creds, options) - self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS.value) \ No newline at end of file + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS.value) + + def test_get_signed_tokens_returns_list_one_per_token(self): + result = generate_signed_data_tokens(creds_path, {"data_tokens": ["token1", "token2"]}) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 2) + + def test_get_signed_tokens_items_are_tuples_with_token_and_signed_token(self): + result = generate_signed_data_tokens(creds_path, {"data_tokens": ["token1", "token2"]}) + for item in result: + self.assertIsInstance(item, tuple) + self.assertEqual(result[0][0], "token1") + self.assertEqual(result[1][0], "token2") + self.assertTrue(result[0][1].startswith("signed_token_")) + self.assertTrue(result[1][1].startswith("signed_token_")) + + def test_get_signed_tokens_returns_list_single_token(self): + result = generate_signed_data_tokens(creds_path, {"data_tokens": ["token1"]}) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 1) + + def test_get_signed_tokens_empty_data_tokens_returns_empty_list(self): + result = generate_signed_data_tokens(creds_path, {"data_tokens": []}) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 0) + + @patch("skyflow.service_account._utils.jwt.encode") + def test_get_signed_tokens_with_string_ctx_in_claims(self, mock_jwt_encode): + mock_jwt_encode.return_value = "signed" + creds = { + "privateKey": "key", + "clientID": "id", + "keyID": "kid", + "tokenURI": "https://valid-url.com", + } + get_signed_tokens(creds, {"data_tokens": ["tok1"], "ctx": "my_ctx"}) + call_args = mock_jwt_encode.call_args + claims = call_args[0][0] if call_args[0] else call_args.kwargs.get("args", [None])[0] + # jwt.encode(claims, key, algorithm=...) — first positional arg is claims + claims_arg = mock_jwt_encode.call_args[0][0] + self.assertEqual(claims_arg["ctx"], "my_ctx") + + @patch("skyflow.service_account._utils.jwt.encode") + def test_get_signed_tokens_with_dict_ctx_in_claims(self, mock_jwt_encode): + mock_jwt_encode.return_value = "signed" + creds = { + "privateKey": "key", + "clientID": "id", + "keyID": "kid", + "tokenURI": "https://valid-url.com", + } + ctx_dict = {"role": "admin", "dept": "eng"} + get_signed_tokens(creds, {"data_tokens": ["tok1"], "ctx": ctx_dict}) + claims_arg = mock_jwt_encode.call_args[0][0] + self.assertEqual(claims_arg["ctx"], ctx_dict) + + @patch("skyflow.service_account._utils.jwt.encode") + def test_get_signed_tokens_with_empty_ctx_not_in_claims(self, mock_jwt_encode): + mock_jwt_encode.return_value = "signed" + creds = { + "privateKey": "key", + "clientID": "id", + "keyID": "kid", + "tokenURI": "https://valid-url.com", + } + get_signed_tokens(creds, {"data_tokens": ["tok1"], "ctx": ""}) + claims_arg = mock_jwt_encode.call_args[0][0] + self.assertNotIn("ctx", claims_arg) + + @patch("skyflow.service_account._utils.jwt.encode") + def test_get_signed_tokens_with_none_ctx_not_in_claims(self, mock_jwt_encode): + mock_jwt_encode.return_value = "signed" + creds = { + "privateKey": "key", + "clientID": "id", + "keyID": "kid", + "tokenURI": "https://valid-url.com", + } + get_signed_tokens(creds, {"data_tokens": ["tok1"], "ctx": None}) + claims_arg = mock_jwt_encode.call_args[0][0] + self.assertNotIn("ctx", claims_arg) + + def test_get_signed_tokens_invalid_token_uri(self): + creds = { + "privateKey": "key", + "clientID": "id", + "keyID": "kid", + "tokenURI": "not-a-url", + } + with self.assertRaises(SkyflowError) as context: + get_signed_tokens(creds, {"data_tokens": ["tok1"]}) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_get_signed_tokens_missing_token_uri(self): + creds = { + "privateKey": "key", + "clientID": "id", + "keyID": "kid", + } + with self.assertRaises(SkyflowError) as context: + get_signed_tokens(creds, {"data_tokens": ["tok1"]}) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_get_signed_tokens_with_snake_case_creds(self): + result = get_signed_tokens(SNAKE_CASE_CREDS, {"data_tokens": ["token1", "token2"]}) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 2) + + # ── generate_signed_data_tokens (file path) ─────────────────────────────── + + def test_generate_signed_data_tokens_from_file_path(self): + options = {"data_tokens": ["token1", "token2"], "ctx": "ctx"} + result = generate_signed_data_tokens(creds_path, options) + self.assertEqual(len(result), 2) + + def test_generate_signed_data_tokens_from_invalid_file_path(self): + options = {"data_tokens": ["token1", "token2"]} + with self.assertRaises(SkyflowError) as context: + generate_signed_data_tokens("credentials1.json", options) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value) + + def test_generate_signed_data_tokens_with_dict_ctx(self): + options = {"data_tokens": ["token1"], "ctx": {"role": "admin", "department": "finance"}} + result = generate_signed_data_tokens(creds_path, options) + self.assertEqual(len(result), 1) + + # ── generate_signed_data_tokens_from_creds (string) ────────────────────── + + def test_generate_signed_data_tokens_from_creds(self): + options = {"data_tokens": ["token1", "token2"]} + result = generate_signed_data_tokens_from_creds(VALID_CREDENTIALS_STRING, options) + self.assertEqual(len(result), 2) + + def test_generate_signed_data_tokens_from_creds_with_invalid_string(self): + options = {"data_tokens": ["token1", "token2"]} + with self.assertRaises(SkyflowError) as context: + generate_signed_data_tokens_from_creds("{", options) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS_STRING.value) + + def test_generate_signed_data_tokens_from_creds_with_dict_ctx(self): + options = {"data_tokens": ["token1"], "ctx": {"role": "admin", "level": 3}} + result = generate_signed_data_tokens_from_creds(VALID_CREDENTIALS_STRING, options) + self.assertEqual(len(result), 1) + + # ── snake_case end-to-end ───────────────────────────────────────────────── + + def test_generate_signed_data_tokens_with_snake_creds_file(self): + """generate_signed_data_tokens reads the file (camelCase) but the normalize fn is a no-op for camelCase.""" + options = {"data_tokens": ["token1", "token2"]} + result = generate_signed_data_tokens(creds_path, options) + self.assertEqual(len(result), 2) + + def test_generate_signed_data_tokens_from_creds_snake(self): + result = generate_signed_data_tokens_from_creds(SNAKE_CASE_CREDS_STRING, options={"data_tokens": ["t1"]}) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 1) + + # ── _normalize_credentials ──────────────────────────────────────────────── + + def test_normalize_credentials_snake_case(self): + snake = { + "private_key": "pk", + "client_id": "cid", + "key_id": "kid", + "token_uri": "https://uri", + "client_name": "name", + } + result = _normalize_credentials(snake) + self.assertEqual(result["privateKey"], "pk") + self.assertEqual(result["clientID"], "cid") + self.assertEqual(result["keyID"], "kid") + self.assertEqual(result["tokenURI"], "https://uri") + self.assertEqual(result["clientName"], "name") + self.assertNotIn("private_key", result) + self.assertNotIn("client_id", result) + self.assertNotIn("key_id", result) + self.assertNotIn("token_uri", result) + self.assertNotIn("client_name", result) + + def test_normalize_credentials_camel_case_unchanged(self): + camel = { + "privateKey": "pk", + "clientID": "cid", + "keyID": "kid", + "tokenURI": "https://uri", + } + result = _normalize_credentials(camel) + self.assertEqual(result, camel) + + def test_normalize_credentials_mixed_keys(self): + mixed = { + "private_key": "pk", + "clientID": "cid", + "key_id": "kid", + "tokenURI": "https://uri", + } + result = _normalize_credentials(mixed) + self.assertEqual(result["privateKey"], "pk") + self.assertEqual(result["clientID"], "cid") + self.assertEqual(result["keyID"], "kid") + self.assertEqual(result["tokenURI"], "https://uri") + self.assertNotIn("private_key", result) + self.assertNotIn("key_id", result) + + def test_normalize_credentials_unknown_key_passes_through(self): + creds = {"unknown_field": "value", "anotherField": "val2"} + result = _normalize_credentials(creds) + self.assertEqual(result["unknown_field"], "value") + self.assertEqual(result["anotherField"], "val2") + + def test_normalize_credentials_empty_dict(self): + self.assertEqual(_normalize_credentials({}), {}) + + # ── _validate_and_resolve_ctx ───────────────────────────────────────────── + + def test_validate_and_resolve_ctx_none(self): + self.assertIsNone(_validate_and_resolve_ctx(None)) + + def test_validate_and_resolve_ctx_empty_string(self): + self.assertIsNone(_validate_and_resolve_ctx("")) + self.assertIsNone(_validate_and_resolve_ctx(" ")) + + def test_validate_and_resolve_ctx_valid_string(self): + self.assertEqual(_validate_and_resolve_ctx("user_12345"), "user_12345") + + def test_validate_and_resolve_ctx_empty_dict(self): + self.assertIsNone(_validate_and_resolve_ctx({})) + + def test_validate_and_resolve_ctx_valid_dict(self): + ctx = {"role": "admin", "department": "finance"} + self.assertEqual(_validate_and_resolve_ctx(ctx), ctx) + + def test_validate_and_resolve_ctx_dict_with_alphanumeric_keys(self): + ctx = {"role_1": "admin", "dept2": "finance", "ABC_123": "value"} + self.assertEqual(_validate_and_resolve_ctx(ctx), ctx) + + def test_validate_and_resolve_ctx_dict_with_invalid_key_hyphen(self): + with self.assertRaises(SkyflowError): + _validate_and_resolve_ctx({"valid_key": "value", "invalid-key": "value"}) + + def test_validate_and_resolve_ctx_dict_with_invalid_key_space(self): + with self.assertRaises(SkyflowError): + _validate_and_resolve_ctx({"invalid key": "value"}) + + def test_validate_and_resolve_ctx_dict_with_invalid_key_dot(self): + with self.assertRaises(SkyflowError): + _validate_and_resolve_ctx({"invalid.key": "value"}) + + def test_validate_and_resolve_ctx_valid_type_int(self): + self.assertEqual(_validate_and_resolve_ctx(42), 42) + + def test_validate_and_resolve_ctx_valid_type_float(self): + self.assertEqual(_validate_and_resolve_ctx(3.14), 3.14) + + def test_validate_and_resolve_ctx_valid_type_bool_true(self): + self.assertEqual(_validate_and_resolve_ctx(True), True) + + def test_validate_and_resolve_ctx_valid_type_bool_false(self): + self.assertEqual(_validate_and_resolve_ctx(False), False) + + def test_validate_and_resolve_ctx_invalid_type_list(self): + with self.assertRaises(SkyflowError): + _validate_and_resolve_ctx(["a", "b"]) + + def test_validate_and_resolve_ctx_dict_with_mixed_value_types(self): + ctx = {"role": "admin", "level": 3, "active": True, "timestamp": "2025-12-25T10:30:00Z"} + self.assertEqual(_validate_and_resolve_ctx(ctx), ctx) + + def test_validate_and_resolve_ctx_dict_with_nested_objects(self): + ctx = {"role": "admin", "metadata": {"level": 2, "tags": ["a", "b"]}} + self.assertEqual(_validate_and_resolve_ctx(ctx), ctx) + + # ── additional coverage gaps ────────────────────────────────────────────── + + @patch("skyflow.service_account._utils.jwt.decode", side_effect=jwt.ExpiredSignatureError) + def test_is_expired_expired_signature_error(self, mock_decode): + token = jwt.encode({"exp": time.time() + 1000}, key="test", algorithm="HS256") + self.assertTrue(is_expired(token)) + + @patch("skyflow.service_account._utils.AuthClient") + @patch("skyflow.service_account._utils.get_signed_jwt") + def test_get_service_account_token_with_token_uri_option_override(self, mock_get_signed_jwt, mock_auth_client): + creds = { + "privateKey": "private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": "https://valid-url.com", + } + override_uri = "https://override-url.com" + options = {"token_uri": override_uri} + mock_get_signed_jwt.return_value = "signed" + mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value + mock_auth_api.authentication_service_get_auth_token.return_value = type( + "obj", (), {"access_token": "token", "token_type": "bearer"} + ) + get_service_account_token(creds, options, None) + mock_get_signed_jwt.assert_called_once() + call_args = mock_get_signed_jwt.call_args + self.assertEqual(call_args[0][3], override_uri) + + @patch("json.load", side_effect=json.JSONDecodeError("bad json", "", 0)) + def test_generate_signed_data_tokens_from_file_invalid_json(self, mock_load): + invalid_path = os.path.join(os.path.dirname(__file__), "invalid_creds.json") + with self.assertRaises(SkyflowError) as context: + generate_signed_data_tokens(invalid_path, {"data_tokens": ["t1"]}) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.FILE_INVALID_JSON.value.format(invalid_path), + ) diff --git a/tests/utils/test__helpers.py b/tests/utils/test__helpers.py index 6758b62e..6016c798 100644 --- a/tests/utils/test__helpers.py +++ b/tests/utils/test__helpers.py @@ -39,9 +39,10 @@ def test_format_scope_special_characters(self): def test_is_valid_url_valid(self): self.assertTrue(is_valid_url("https://example.com")) - self.assertTrue(is_valid_url("http://example.com/path")) + self.assertTrue(is_valid_url("https://example.com/path")) def test_is_valid_url_invalid(self): + self.assertFalse(is_valid_url("http://example.com")) self.assertFalse(is_valid_url("ftp://example.com")) self.assertFalse(is_valid_url("example.com")) self.assertFalse(is_valid_url("invalid-url")) diff --git a/tests/utils/test__utils.py b/tests/utils/test__utils.py index b0466498..95983058 100644 --- a/tests/utils/test__utils.py +++ b/tests/utils/test__utils.py @@ -1,40 +1,65 @@ import unittest -from unittest.mock import patch, Mock +from unittest.mock import patch, Mock, MagicMock, PropertyMock import os -from unittest.mock import MagicMock from urllib.parse import quote import tempfile, json from requests import PreparedRequest from requests.models import HTTPError from skyflow.error import SkyflowError from skyflow.generated.rest import ErrorResponse -from skyflow.service_account import generate_bearer_token, generate_signed_data_tokens, \ - generate_signed_data_tokens_from_creds, generate_bearer_token_from_creds -from skyflow.utils import get_credentials, SkyflowMessages, get_vault_url, construct_invoke_connection_request, \ - parse_insert_response, parse_update_record_response, parse_delete_response, parse_get_response, \ - parse_detokenize_response, parse_tokenize_response, parse_query_response, parse_invoke_connection_response, \ - handle_exception, validate_api_key, encode_column_values, parse_deidentify_text_response, \ - parse_reidentify_text_response, convert_detected_entity_to_entity_info -from skyflow.utils._utils import parse_path_params, to_lowercase_keys, get_metrics, handle_json_error +from skyflow.service_account import ( + generate_bearer_token, + generate_signed_data_tokens, + generate_signed_data_tokens_from_creds, + generate_bearer_token_from_creds, +) +from skyflow.utils import ( + get_credentials, + SkyflowMessages, + get_vault_url, + construct_invoke_connection_request, + parse_insert_response, + parse_update_record_response, + parse_delete_response, + parse_get_response, + parse_detokenize_response, + parse_tokenize_response, + parse_query_response, + parse_invoke_connection_response, + handle_exception, + validate_api_key, + encode_column_values, + parse_deidentify_text_response, + parse_reidentify_text_response, + convert_detected_entity_to_entity_info, +) +from skyflow.utils._utils import parse_path_params, to_lowercase_keys, get_metrics, handle_json_error, r_urlencode from skyflow.utils.enums import EnvUrls, Env, ContentType from skyflow.vault.connection import InvokeConnectionResponse from skyflow.vault.data import InsertResponse, DeleteResponse, GetResponse, QueryResponse from skyflow.vault.tokens import DetokenizeResponse, TokenizeResponse creds_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "credentials.json") -with open(creds_path, 'r') as file: +with open(creds_path, "r") as file: credentials = json.load(file) TEST_ERROR_MESSAGE = "Test error message." VALID_ENV_CREDENTIALS = credentials -class TestUtils(unittest.TestCase): +class TestUtils(unittest.TestCase): @patch.dict(os.environ, {"SKYFLOW_CREDENTIALS": json.dumps(VALID_ENV_CREDENTIALS)}) def test_get_credentials_env_variable(self): credentials = get_credentials() - credentials_string = credentials.get('credentials_string') - self.assertEqual(credentials_string, json.dumps(VALID_ENV_CREDENTIALS).replace('\n', '\\n')) + credentials_string = credentials.get("credentials_string") + self.assertEqual(credentials_string, json.dumps(VALID_ENV_CREDENTIALS).replace("\n", "\\n")) + + @patch("skyflow.utils._utils.dotenv.find_dotenv", return_value=None) + @patch.dict(os.environ, {}, clear=True) + def test_get_credentials_no_credentials_raises(self, mock_find_dotenv): + with self.assertRaises(SkyflowError) as context: + get_credentials(config_level_creds=None, common_skyflow_creds=None) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS.value) def test_get_credentials_with_config_level_creds(self): test_creds = {"authToken": "test_token"} @@ -60,11 +85,13 @@ def test_get_vault_url_with_invalid_cluster_id(self): valid_vault_id = "vault123" with self.assertRaises(SkyflowError) as context: url = get_vault_url(valid_cluster_id, valid_env, valid_vault_id) - self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CLUSTER_ID.value.format(valid_vault_id)) + self.assertEqual( + context.exception.message, SkyflowMessages.Error.INVALID_CLUSTER_ID.value.format(valid_vault_id) + ) def test_get_vault_url_with_invalid_env(self): valid_cluster_id = "cluster_id" - valid_env =EnvUrls.DEV + valid_env = EnvUrls.DEV valid_vault_id = "vault123" with self.assertRaises(SkyflowError) as context: url = get_vault_url(valid_cluster_id, valid_env, valid_vault_id) @@ -79,7 +106,7 @@ def test_handle_json_error_with_dict_data(self, mock_log_and_reject_error): "http_code": 400, "http_status": "Bad Request", "grpc_code": 3, - "details": ["detail1"] + "details": ["detail1"], } } @@ -90,13 +117,7 @@ def test_handle_json_error_with_dict_data(self, mock_log_and_reject_error): handle_json_error(mock_error, error_dict, request_id, mock_logger) mock_log_and_reject_error.assert_called_once_with( - "Dict error message", - 400, - request_id, - "Bad Request", - 3, - ["detail1"], - logger=mock_logger + "Dict error message", 400, request_id, "Bad Request", 3, ["detail1"], logger=mock_logger ) @patch("skyflow.utils._utils.log_and_reject_error") @@ -109,7 +130,7 @@ def test_handle_json_error_with_error_response_object(self, mock_log_and_reject_ "http_code": 403, "http_status": "Forbidden", "grpc_code": 7, - "details": ["detail2"] + "details": ["detail2"], } } @@ -120,13 +141,7 @@ def test_handle_json_error_with_error_response_object(self, mock_log_and_reject_ handle_json_error(mock_error, mock_error_response, request_id, mock_logger) mock_log_and_reject_error.assert_called_once_with( - "ErrorResponse message", - 403, - request_id, - "Forbidden", - 7, - ["detail2"], - logger=mock_logger + "ErrorResponse message", 403, request_id, "Forbidden", 7, ["detail2"], logger=mock_logger ) def test_parse_path_params(self): @@ -140,13 +155,56 @@ def test_to_lowercase_keys(self): expected_output = {"key1": "value1", "key2": "value2"} self.assertEqual(to_lowercase_keys(input_dict), expected_output) + def test_r_urlencode_with_list_input(self): + pairs = {} + r_urlencode([], pairs, ["a", "b"]) + self.assertIn("[0]", pairs) + self.assertIn("[1]", pairs) + self.assertEqual(pairs["[0]"], "a") + self.assertEqual(pairs["[1]"], "b") + + def test_r_urlencode_with_tuple_input(self): + pairs = {} + r_urlencode([], pairs, ("x", "y")) + self.assertIn("[0]", pairs) + self.assertEqual(pairs["[0]"], "x") + def test_get_metrics(self): metrics = get_metrics() - self.assertIn('sdk_name_version', metrics) - self.assertIn('sdk_client_device_model', metrics) - self.assertIn('sdk_client_os_details', metrics) - self.assertIn('sdk_runtime_details', metrics) + self.assertIn("sdk_name_version", metrics) + self.assertIn("sdk_client_device_model", metrics) + self.assertIn("sdk_client_os_details", metrics) + self.assertIn("sdk_runtime_details", metrics) + def test_get_metrics_platform_node_exception(self): + import skyflow.utils._utils as utils_module + + utils_module._CACHED_METRICS.clear() + with patch("skyflow.utils._utils.platform") as mock_platform: + mock_platform.node.side_effect = OSError("no node") + metrics = utils_module.get_metrics() + self.assertEqual(metrics["sdk_client_device_model"], "") + utils_module._CACHED_METRICS.clear() + + def test_get_metrics_sys_attribute_exception(self): + import skyflow.utils._utils as utils_module + + utils_module._CACHED_METRICS.clear() + + class _RaisingSys: + @property + def platform(self): + raise RuntimeError("no platform") + + @property + def version(self): + raise RuntimeError("no version") + + with patch("skyflow.utils._utils.sys", _RaisingSys()): + metrics = utils_module.get_metrics() + self.assertEqual(metrics["sdk_client_os_details"], "") + self.assertIn("sdk_runtime_details", metrics) + utils_module._CACHED_METRICS.clear() def test_construct_invoke_connection_request_valid(self): mock_connection_request = Mock() @@ -166,7 +224,7 @@ def test_construct_invoke_connection_request_valid(self): self.assertEqual(result.url, expected_url) self.assertEqual(result.method, "POST") - self.assertEqual(result.headers['Content-Type'], ContentType.JSON.value) + self.assertEqual(result.headers["Content-Type"], ContentType.JSON.value) self.assertEqual(result.body, json.dumps(mock_connection_request.body)) @@ -232,9 +290,7 @@ def test_construct_invoke_connection_request_with_form_date_content_type(self): mock_connection_request = Mock() mock_connection_request.path_params = {"param1": "value1"} mock_connection_request.headers = {"Content-Type": ContentType.FORMDATA.value} - mock_connection_request.body = { - "name": (None, "John Doe") - } + mock_connection_request.body = {"name": (None, "John Doe")} mock_connection_request.method.value = "POST" mock_connection_request.query_params = {"query": "test"} @@ -244,13 +300,27 @@ def test_construct_invoke_connection_request_with_form_date_content_type(self): self.assertIsInstance(result, PreparedRequest) + def test_parse_insert_response_with_tokens_continue_on_error(self): + api_response = Mock() + api_response.headers = {"x-request-id": "req-1"} + api_response.data = Mock( + responses=[ + {"Status": 200, "Body": {"records": [{"skyflow_id": "id1", "tokens": {"col1": "tok1"}}]}}, + ] + ) + result = parse_insert_response(api_response, continue_on_error=True) + self.assertEqual(result.inserted_fields[0]["col1"], "tok1") + self.assertEqual(result.inserted_fields[0]["skyflow_id"], "id1") + def test_parse_insert_response(self): api_response = Mock() api_response.headers = {"x-request-id": "12345", "content-type": "application/json"} - api_response.data = Mock(responses=[ - {"Status": 200, "Body": {"records": [{"skyflow_id": "id1"}]}}, - {"Status": 400, "Body": {"error": TEST_ERROR_MESSAGE}} - ]) + api_response.data = Mock( + responses=[ + {"Status": 200, "Body": {"records": [{"skyflow_id": "id1"}]}}, + {"Status": 400, "Body": {"error": TEST_ERROR_MESSAGE}}, + ] + ) result = parse_insert_response(api_response, continue_on_error=True) self.assertEqual(len(result.inserted_fields), 1) self.assertEqual(len(result.errors), 1) @@ -264,17 +334,19 @@ def test_parse_insert_response(self): def test_parse_insert_response_continue_on_error_false(self): mock_api_response = Mock() mock_api_response.headers = {"x-request-id": "12345", "content-type": "application/json"} - mock_api_response.data = Mock(records=[ - Mock(skyflow_id="id_1", tokens={"token1": "token_value1"}), - Mock(skyflow_id="id_2", tokens={"token2": "token_value2"}) - ]) + mock_api_response.data = Mock( + records=[ + Mock(skyflow_id="id_1", tokens={"token1": "token_value1"}), + Mock(skyflow_id="id_2", tokens={"token2": "token_value2"}), + ] + ) result = parse_insert_response(mock_api_response, continue_on_error=False) self.assertIsInstance(result, InsertResponse) expected_inserted_fields = [ {"skyflow_id": "id_1", "token1": "token_value1"}, - {"skyflow_id": "id_2", "token2": "token_value2"} + {"skyflow_id": "id_2", "token2": "token_value2"}, ] self.assertEqual(result.inserted_fields, expected_inserted_fields) @@ -285,8 +357,8 @@ def test_parse_update_record_response(self): api_response.skyflow_id = "id1" api_response.tokens = {"token1": "value1"} result = parse_update_record_response(api_response) - self.assertEqual(result.updated_field['skyflow_id'], "id1") - self.assertEqual(result.updated_field['token1'], "value1") + self.assertEqual(result.updated_field["skyflow_id"], "id1") + self.assertEqual(result.updated_field["token1"], "value1") def test_parse_delete_response_successful(self): mock_api_response = Mock() @@ -304,18 +376,15 @@ def test_parse_delete_response_successful(self): def test_parse_get_response_successful(self): mock_api_response = Mock() mock_api_response.records = [ - Mock(fields={'field1': 'value1', 'field2': 'value2'}), - Mock(fields={'field1': 'value3', 'field2': 'value4'}) + Mock(fields={"field1": "value1", "field2": "value2"}), + Mock(fields={"field1": "value3", "field2": "value4"}), ] result = parse_get_response(mock_api_response) self.assertIsInstance(result, GetResponse) - expected_data = [ - {'field1': 'value1', 'field2': 'value2'}, - {'field1': 'value3', 'field2': 'value4'} - ] + expected_data = [{"field1": "value1", "field2": "value2"}, {"field1": "value3", "field2": "value4"}] self.assertEqual(result.data, expected_data) # self.assertEqual(result.errors, None) @@ -323,23 +392,23 @@ def test_parse_get_response_successful(self): def test_parse_detokenize_response_with_mixed_records(self): mock_api_response = Mock() mock_api_response.headers = {"x-request-id": "12345", "content-type": "application/json"} - mock_api_response.data = Mock(records=[ - Mock(token="token1", value="value1", value_type="Type1", error=None), - Mock(token="token2", value=None, value_type=None, error="Some error"), - Mock(token="token3", value="value3", value_type="Type2", error=None), - ]) + mock_api_response.data = Mock( + records=[ + Mock(token="token1", value="value1", value_type="Type1", error=None), + Mock(token="token2", value=None, value_type=None, error="Some error"), + Mock(token="token3", value="value3", value_type="Type2", error=None), + ] + ) result = parse_detokenize_response(mock_api_response) self.assertIsInstance(result, DetokenizeResponse) expected_detokenized_fields = [ {"token": "token1", "value": "value1", "type": "Type1"}, - {"token": "token3", "value": "value3", "type": "Type2"} + {"token": "token3", "value": "value3", "type": "Type2"}, ] - expected_errors = [ - {"token": "token2", "error": "Some error", "request_id": "12345"} - ] + expected_errors = [{"token": "token2", "error": "Some error", "request_id": "12345"}] self.assertEqual(result.detokenized_fields, expected_detokenized_fields) self.assertEqual(result.errors, expected_errors) @@ -355,11 +424,7 @@ def test_parse_tokenize_response_with_valid_records(self): result = parse_tokenize_response(mock_api_response) self.assertIsInstance(result, TokenizeResponse) - expected_tokenized_fields = [ - {"token": "token1"}, - {"token": "token2"}, - {"token": "token3"} - ] + expected_tokenized_fields = [{"token": "token1"}, {"token": "token2"}, {"token": "token3"}] self.assertEqual(result.tokenized_fields, expected_tokenized_fields) @@ -367,7 +432,7 @@ def test_parse_query_response_with_valid_records(self): mock_api_response = Mock() mock_api_response.records = [ Mock(fields={"field1": "value1", "field2": "value2"}), - Mock(fields={"field1": "value3", "field2": "value4"}) + Mock(fields={"field1": "value3", "field2": "value4"}), ] result = parse_query_response(mock_api_response) @@ -376,7 +441,7 @@ def test_parse_query_response_with_valid_records(self): expected_fields = [ {"field1": "value1", "field2": "value2", "tokenized_data": {}}, - {"field1": "value3", "field2": "value4", "tokenized_data": {}} + {"field1": "value3", "field2": "value4", "tokenized_data": {}}, ] self.assertEqual(result.fields, expected_fields) @@ -384,7 +449,7 @@ def test_parse_query_response_with_valid_records(self): @patch("requests.Response") def test_parse_invoke_connection_response_successful(self, mock_response): mock_response.status_code = 200 - mock_response.content = json.dumps({"key": "value"}).encode('utf-8') + mock_response.content = json.dumps({"key": "value"}).encode("utf-8") mock_response.headers = {"x-request-id": "1234"} result = parse_invoke_connection_response(mock_response) @@ -398,7 +463,7 @@ def test_parse_invoke_connection_response_successful(self, mock_response): def test_parse_invoke_connection_response_json_decode_error(self, mock_response): """Test that non-JSON content in successful response is returned as string.""" mock_response.status_code = 200 - mock_response.content = "Non-JSON Content".encode('utf-8') + mock_response.content = "Non-JSON Content".encode("utf-8") mock_response.headers = {"x-request-id": "1234"} mock_response.raise_for_status = Mock() @@ -412,7 +477,7 @@ def test_parse_invoke_connection_response_json_decode_error(self, mock_response) @patch("requests.Response") def test_parse_invoke_connection_response_http_error_with_json_error_message(self, mock_response): mock_response.status_code = 404 - mock_response.content = json.dumps({"error": {"message": "Not Found"}}).encode('utf-8') + mock_response.content = json.dumps({"error": {"message": "Not Found"}}).encode("utf-8") mock_response.headers = {"x-request-id": "1234"} mock_response.raise_for_status.side_effect = HTTPError("404 Error") @@ -423,10 +488,38 @@ def test_parse_invoke_connection_response_http_error_with_json_error_message(sel self.assertEqual(context.exception.message, "Not Found") self.assertEqual(context.exception.request_id, "1234") + @patch("requests.Response") + def test_parse_invoke_connection_response_with_error_from_client_header(self, mock_response): + from requests.models import HTTPError + + mock_response.status_code = 400 + mock_response.content = json.dumps( + { + "error": { + "message": "Client error", + "http_code": 400, + "http_status": "Bad Request", + "grpc_code": 3, + "details": None, + } + } + ).encode("utf-8") + mock_response.headers = { + "x-request-id": "rid-1", + "error-from-client": "true", + } + mock_response.raise_for_status.side_effect = HTTPError("400") + with self.assertRaises(SkyflowError) as context: + parse_invoke_connection_response(mock_response) + err = context.exception + self.assertEqual(err.message, "Client error") + self.assertIsNotNone(err.details) + self.assertTrue(any(d.get("error_from_client") is True for d in err.details)) + @patch("requests.Response") def test_parse_invoke_connection_response_http_error_without_json_error_message(self, mock_response): mock_response.status_code = 500 - mock_response.content = "Internal Server Error".encode('utf-8') + mock_response.content = "Internal Server Error".encode("utf-8") mock_response.headers = {"x-request-id": "1234"} mock_response.raise_for_status.side_effect = HTTPError("500 Error") @@ -434,7 +527,7 @@ def test_parse_invoke_connection_response_http_error_without_json_error_message( with self.assertRaises(SkyflowError) as context: parse_invoke_connection_response(mock_response) - self.assertEqual(context.exception.message, "Internal Server Error") + self.assertEqual(context.exception.message, SkyflowMessages.Error.API_ERROR.value.format(500)) self.assertEqual(context.exception.http_code, 500) self.assertEqual(context.exception.request_id, "1234") @@ -442,31 +535,24 @@ def test_parse_invoke_connection_response_http_error_without_json_error_message( def test_handle_exception_json_error(self, mock_log_and_reject_error): mock_error = Mock() - mock_error.headers = { - 'x-request-id': '1234', - 'content-type': 'application/json' - } - mock_error.body = json.dumps({ - "error": { - "message": "JSON error occurred.", - "http_code": 400, - "http_status": "Bad Request", - "grpc_code": "8", - "details": "Detailed message" + mock_error.headers = {"x-request-id": "1234", "content-type": "application/json"} + mock_error.body = json.dumps( + { + "error": { + "message": "JSON error occurred.", + "http_code": 400, + "http_status": "Bad Request", + "grpc_code": "8", + "details": "Detailed message", + } } - }).encode('utf-8') + ).encode("utf-8") mock_logger = Mock() handle_exception(mock_error, mock_logger) mock_log_and_reject_error.assert_called_once_with( - "JSON error occurred.", - 400, - "1234", - "Bad Request", - "8", - "Detailed message", - logger=mock_logger + "JSON error occurred.", 400, "1234", "Bad Request", "8", "Detailed message", logger=mock_logger ) def test_validate_api_key_valid_key(self): @@ -502,12 +588,7 @@ def test_parse_deidentify_text_response(self): mock_entity.value = "sensitive_value" mock_entity.entity_type = "EMAIL" mock_entity.entity_scores = {"EMAIL": 0.95} - mock_entity.location = Mock( - start_index=10, - end_index=20, - start_index_processed=15, - end_index_processed=25 - ) + mock_entity.location = Mock(start_index=10, end_index=20, start_index_processed=15, end_index_processed=25) mock_api_response = Mock() mock_api_response.processed_text = "Sample processed text" @@ -564,10 +645,7 @@ def test__convert_detected_entity_to_entity_info(self): mock_detected_entity.entity_type = "EMAIL" mock_detected_entity.entity_scores = {"EMAIL": 0.95} mock_detected_entity.location = Mock( - start_index=10, - end_index=20, - start_index_processed=15, - end_index_processed=25 + start_index=10, end_index=20, start_index_processed=15, end_index_processed=25 ) result = convert_detected_entity_to_entity_info(mock_detected_entity) @@ -588,12 +666,7 @@ def test__convert_detected_entity_to_entity_info_with_minimal_data(self): mock_detected_entity.value = None mock_detected_entity.entity_type = "UNKNOWN" mock_detected_entity.entity_scores = {} - mock_detected_entity.location = Mock( - start_index=0, - end_index=0, - start_index_processed=0, - end_index_processed=0 - ) + mock_detected_entity.location = Mock(start_index=0, end_index=0, start_index_processed=0, end_index_processed=0) result = convert_detected_entity_to_entity_info(mock_detected_entity) @@ -606,21 +679,18 @@ def test__convert_detected_entity_to_entity_info_with_minimal_data(self): self.assertEqual(result.processed_index.start, 0) self.assertEqual(result.processed_index.end, 0) - @patch("skyflow.utils._utils.log_and_reject_error") def test_handle_exception_connect_error(self, mock_log_and_reject_error): """Test handling httpx.ConnectError.""" import httpx + mock_error = httpx.ConnectError("Connection refused") mock_logger = Mock() handle_exception(mock_error, mock_logger) mock_log_and_reject_error.assert_called_once_with( - 'Connection refused', - SkyflowMessages.ErrorCodes.INVALID_INPUT.value, - None, - logger=mock_logger + "Connection refused", SkyflowMessages.ErrorCodes.INVALID_INPUT.value, None, logger=mock_logger ) @patch("skyflow.utils._utils.log_and_reject_error") @@ -632,10 +702,7 @@ def test_handle_exception_no_headers_attribute(self, mock_log_and_reject_error): handle_exception(mock_error, mock_logger) mock_log_and_reject_error.assert_called_once_with( - "Generic error", - SkyflowMessages.ErrorCodes.SERVER_ERROR.value, - None, - logger=mock_logger + "Generic error", SkyflowMessages.ErrorCodes.SERVER_ERROR.value, None, logger=mock_logger ) @patch("skyflow.utils._utils.log_and_reject_error") @@ -643,89 +710,67 @@ def test_handle_exception_no_body_attribute(self, mock_log_and_reject_error): """Test handling error without body attribute.""" mock_error = Mock() mock_error.headers = {"x-request-id": "12345"} - delattr(mock_error, 'body') + delattr(mock_error, "body") mock_logger = Mock() handle_exception(mock_error, mock_logger) mock_log_and_reject_error.assert_called_once() - self.assertEqual( - mock_log_and_reject_error.call_args[0][1], - SkyflowMessages.ErrorCodes.SERVER_ERROR.value - ) + self.assertEqual(mock_log_and_reject_error.call_args[0][1], SkyflowMessages.ErrorCodes.SERVER_ERROR.value) @patch("skyflow.utils._utils.log_and_reject_error") def test_handle_exception_text_plain_error(self, mock_log_and_reject_error): """Test handling text/plain content type error.""" mock_error = Mock() - mock_error.headers = { - 'x-request-id': '1234', - 'content-type': 'text/plain' - } + mock_error.headers = {"x-request-id": "1234", "content-type": "text/plain"} mock_error.body = "Plain text error message" mock_error.status = 500 mock_logger = Mock() handle_exception(mock_error, mock_logger) - mock_log_and_reject_error.assert_called_once_with( - "Plain text error message", - 500, - "1234", - logger=mock_logger - ) + mock_log_and_reject_error.assert_called_once_with("Plain text error message", 500, "1234", logger=mock_logger) @patch("skyflow.utils._utils.log_and_reject_error") def test_handle_exception_generic_error_with_status(self, mock_log_and_reject_error): """Test handling generic error with unknown content type.""" mock_error = Mock() - mock_error.headers = { - 'x-request-id': '1234', - 'content-type': 'application/xml' - } + mock_error.headers = {"x-request-id": "1234", "content-type": "application/xml"} mock_error.body = "XML error" mock_error.status = 503 mock_logger = Mock() handle_exception(mock_error, mock_logger) - mock_log_and_reject_error.assert_called_once_with( - str(mock_error), - 503, - "1234", - logger=mock_logger - ) + mock_log_and_reject_error.assert_called_once_with(str(mock_error), 503, "1234", logger=mock_logger) @patch("skyflow.utils._utils.log_and_reject_error") def test_handle_exception_no_content_type(self, mock_log_and_reject_error): """Test handling error without content-type header.""" mock_error = Mock() - mock_error.headers = {'x-request-id': '1234'} + mock_error.headers = {"x-request-id": "1234"} mock_error.body = "Some error" mock_error.status = 500 mock_logger = Mock() handle_exception(mock_error, mock_logger) - mock_log_and_reject_error.assert_called_once_with( - str(mock_error), - 500, - "1234", - logger=mock_logger - ) + mock_log_and_reject_error.assert_called_once_with(str(mock_error), 500, "1234", logger=mock_logger) @patch("skyflow.utils._utils.log_and_reject_error") def test_handle_json_error_with_json_string(self, mock_log_and_reject_error): """Test handling JSON error when data is a JSON string.""" - error_json_string = json.dumps({ - "error": { - "message": "String JSON error", - "http_code": 422, - "http_status": "Unprocessable Entity", - "grpc_code": 3, - "details": ["validation failed"] + error_json_string = json.dumps( + { + "error": { + "message": "String JSON error", + "http_code": 422, + "http_status": "Unprocessable Entity", + "grpc_code": 3, + "details": ["validation failed"], + } } - }) + ) mock_error = Mock() mock_logger = Mock() @@ -734,13 +779,7 @@ def test_handle_json_error_with_json_string(self, mock_log_and_reject_error): handle_json_error(mock_error, error_json_string, request_id, mock_logger) mock_log_and_reject_error.assert_called_once_with( - "String JSON error", - 422, - request_id, - "Unprocessable Entity", - 3, - ["validation failed"], - logger=mock_logger + "String JSON error", 422, request_id, "Unprocessable Entity", 3, ["validation failed"], logger=mock_logger ) @patch("skyflow.utils._utils.log_and_reject_error") @@ -756,17 +795,12 @@ def test_handle_json_error_with_invalid_json(self, mock_log_and_reject_error): # Should call with INVALID_JSON_RESPONSE error mock_log_and_reject_error.assert_called_once() - self.assertEqual( - mock_log_and_reject_error.call_args[0][0], - SkyflowMessages.Error.INVALID_JSON_RESPONSE.value - ) + self.assertEqual(mock_log_and_reject_error.call_args[0][0], SkyflowMessages.Error.INVALID_JSON_RESPONSE.value) @patch("skyflow.utils._utils.log_and_reject_error") def test_handle_json_error_missing_error_field(self, mock_log_and_reject_error): """Test handling JSON error with missing error field.""" - error_dict = { - "message": "Error without error wrapper" - } + error_dict = {"message": "Error without error wrapper"} mock_error = Mock() mock_logger = Mock() @@ -793,14 +827,10 @@ def test_handle_text_error_with_status(self, mock_log_and_reject_error): error_data = "Resource not found" from skyflow.utils._utils import handle_text_error + handle_text_error(mock_error, error_data, request_id, mock_logger) - mock_log_and_reject_error.assert_called_once_with( - "Resource not found", - 404, - request_id, - logger=mock_logger - ) + mock_log_and_reject_error.assert_called_once_with("Resource not found", 404, request_id, logger=mock_logger) @patch("skyflow.utils._utils.log_and_reject_error") def test_handle_generic_error_with_status(self, mock_log_and_reject_error): @@ -811,14 +841,10 @@ def test_handle_generic_error_with_status(self, mock_log_and_reject_error): status = 503 from skyflow.utils._utils import handle_generic_error_with_status + handle_generic_error_with_status(mock_error, request_id, status, mock_logger) - mock_log_and_reject_error.assert_called_once_with( - str(mock_error), - 503, - request_id, - logger=mock_logger - ) + mock_log_and_reject_error.assert_called_once_with(str(mock_error), 503, request_id, logger=mock_logger) @patch("skyflow.utils._utils.log_and_reject_error") def test_handle_exception_with_none_error(self, mock_log_and_reject_error): @@ -831,10 +857,10 @@ def test_handle_exception_with_none_error(self, mock_log_and_reject_error): SkyflowMessages.Error.GENERIC_API_ERROR.value, SkyflowMessages.ErrorCodes.SERVER_ERROR.value, None, - logger=mock_logger + logger=mock_logger, ) - #failed + # failed @patch("skyflow.utils._utils.log_and_reject_error") def test_handle_exception_with_empty_string_error(self, mock_log_and_reject_error): """Test handling empty string error.""" @@ -847,22 +873,54 @@ def test_handle_exception_with_empty_string_error(self, mock_log_and_reject_erro mock_log_and_reject_error.assert_called_once() # Should use str(error) or default message - self.assertEqual( - mock_log_and_reject_error.call_args[0][1], - SkyflowMessages.ErrorCodes.SERVER_ERROR.value - ) + self.assertEqual(mock_log_and_reject_error.call_args[0][1], SkyflowMessages.ErrorCodes.SERVER_ERROR.value) @patch("skyflow.utils._utils.log_and_reject_error") - def test_handle_json_error_with_bytes_data(self, mock_log_and_reject_error): - """Test handling JSON error when data is bytes.""" + def test_handle_json_error_with_responses_key(self, mock_log_and_reject_error): + """Test handle_json_error when body has 'responses' key (batch/continue_on_error path).""" error_dict = { - "error": { - "message": "Bytes error", - "http_code": 401, - "http_status": "Unauthorized" - } + "responses": [ + {"Status": 400, "Body": {"error": "record not found"}}, + {"Status": 400, "Body": {"error": "invalid field"}}, + ] } - error_bytes = json.dumps(error_dict).encode('utf-8') + mock_error = Mock() + mock_logger = Mock() + request_id = "test-request-id-responses" + + handle_json_error(mock_error, error_dict, request_id, mock_logger) + + mock_log_and_reject_error.assert_called_once() + args = mock_log_and_reject_error.call_args[0] + self.assertIn("record not found", args[0]) + self.assertIn("invalid field", args[0]) + self.assertEqual(args[1], 400) + self.assertIsNone(args[3]) # http_status + self.assertIsNone(args[4]) # grpc_code + self.assertEqual(args[5], []) # details + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_json_error_responses_no_error_messages(self, mock_log_and_reject_error): + """Test handle_json_error with responses key but no error body — falls back to default message.""" + error_dict = { + "responses": [ + {"Status": 200, "Body": {"records": [{"skyflow_id": "abc"}]}}, + ] + } + mock_error = Mock() + request_id = "test-request-id-responses-empty" + + handle_json_error(mock_error, error_dict, request_id, None) + + mock_log_and_reject_error.assert_called_once() + args = mock_log_and_reject_error.call_args[0] + self.assertEqual(args[0], SkyflowMessages.Error.UNKNOWN_ERROR_DEFAULT_MESSAGE.value) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_json_error_with_bytes_data(self, mock_log_and_reject_error): + """Test handling JSON error when data is bytes.""" + error_dict = {"error": {"message": "Bytes error", "http_code": 401, "http_status": "Unauthorized"}} + error_bytes = json.dumps(error_dict).encode("utf-8") mock_error = Mock() mock_logger = Mock() @@ -871,13 +929,7 @@ def test_handle_json_error_with_bytes_data(self, mock_log_and_reject_error): handle_json_error(mock_error, error_bytes, request_id, mock_logger) mock_log_and_reject_error.assert_called_once_with( - "Bytes error", - 401, - request_id, - "Unauthorized", - None, - [], - logger=mock_logger + "Bytes error", 401, request_id, "Unauthorized", None, [], logger=mock_logger ) # Add these new test methods to the TestUtils class: @@ -897,7 +949,7 @@ def test_construct_invoke_connection_request_with_no_headers(self): self.assertIsInstance(result, PreparedRequest) # Headers should be None when not provided - self.assertIsNone(result.headers.get('Content-Type')) + self.assertIsNone(result.headers.get("Content-Type")) def test_construct_invoke_connection_request_with_xml_content_type(self): """Test construct_invoke_connection_request with XML content type.""" @@ -913,10 +965,10 @@ def test_construct_invoke_connection_request_with_xml_content_type(self): result = construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) self.assertIsInstance(result, PreparedRequest) - self.assertEqual(result.headers['content-type'], 'application/xml') + self.assertEqual(result.headers["content-type"], "application/xml") # Body should be converted to XML - self.assertIn('', result.body) - self.assertIn('value', result.body) + self.assertIn("", result.body) + self.assertIn("value", result.body) def test_construct_invoke_connection_request_with_html_content_type(self): """Test construct_invoke_connection_request with HTML content type.""" @@ -932,7 +984,7 @@ def test_construct_invoke_connection_request_with_html_content_type(self): result = construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) self.assertIsInstance(result, PreparedRequest) - self.assertEqual(result.headers['content-type'], 'text/html') + self.assertEqual(result.headers["content-type"], "text/html") # Body should be JSON string for HTML self.assertEqual(result.body, json.dumps({"message": "Hello"})) @@ -951,8 +1003,8 @@ def test_construct_invoke_connection_request_multipart_removes_content_type(self self.assertIsInstance(result, PreparedRequest) # Content-Type should be auto-generated by requests library - self.assertIn('multipart/form-data', result.headers.get('Content-Type', '')) - self.assertIn('boundary=', result.headers.get('Content-Type', '')) + self.assertIn("multipart/form-data", result.headers.get("Content-Type", "")) + self.assertIn("boundary=", result.headers.get("Content-Type", "")) def test_construct_invoke_connection_request_with_no_body(self): """Test construct_invoke_connection_request when body is None.""" @@ -1119,10 +1171,7 @@ def test_parse_invoke_connection_response_xml_content(self, mock_response): """Test parsing XML response content.""" mock_response.status_code = 200 mock_response.content = b"success" - mock_response.headers = { - "x-request-id": "1234", - "content-type": "application/xml" - } + mock_response.headers = {"x-request-id": "1234", "content-type": "application/xml"} mock_response.raise_for_status = Mock() result = parse_invoke_connection_response(mock_response) @@ -1137,10 +1186,7 @@ def test_parse_invoke_connection_response_url_encoded_content(self, mock_respons """Test parsing URL encoded response content.""" mock_response.status_code = 200 mock_response.content = b"card_number=4111111111111111&cvv=123" - mock_response.headers = { - "x-request-id": "1234", - "content-type": "application/x-www-form-urlencoded" - } + mock_response.headers = {"x-request-id": "1234", "content-type": "application/x-www-form-urlencoded"} mock_response.raise_for_status = Mock() result = parse_invoke_connection_response(mock_response) @@ -1155,10 +1201,7 @@ def test_parse_invoke_connection_response_html_content(self, mock_response): """Test parsing HTML response content.""" mock_response.status_code = 200 mock_response.content = b"Success" - mock_response.headers = { - "x-request-id": "1234", - "content-type": "text/html" - } + mock_response.headers = {"x-request-id": "1234", "content-type": "text/html"} mock_response.raise_for_status = Mock() result = parse_invoke_connection_response(mock_response) @@ -1173,17 +1216,14 @@ def test_parse_invoke_connection_response_html_error(self, mock_response): """Test parsing HTML error response.""" html_error = "

Error 500

" mock_response.status_code = 500 - mock_response.content = html_error.encode('utf-8') - mock_response.headers = { - "x-request-id": "1234", - "content-type": "text/html" - } + mock_response.content = html_error.encode("utf-8") + mock_response.headers = {"x-request-id": "1234", "content-type": "text/html"} mock_response.raise_for_status = Mock(side_effect=HTTPError("500 Error")) with self.assertRaises(SkyflowError) as context: parse_invoke_connection_response(mock_response) - self.assertEqual(context.exception.message, html_error) + self.assertEqual(context.exception.message, SkyflowMessages.Error.API_ERROR.value.format(500)) self.assertEqual(context.exception.http_code, 500) self.assertEqual(context.exception.request_id, "1234") @@ -1192,10 +1232,7 @@ def test_parse_invoke_connection_response_json_decode_falls_back_to_string(self, """Test that JSON decode error falls back to returning string content.""" mock_response.status_code = 200 mock_response.content = b"Not valid JSON but still success" - mock_response.headers = { - "x-request-id": "1234", - "content-type": "application/json" - } + mock_response.headers = {"x-request-id": "1234", "content-type": "application/json"} mock_response.raise_for_status = Mock() result = parse_invoke_connection_response(mock_response) @@ -1209,7 +1246,7 @@ def test_parse_invoke_connection_response_json_decode_falls_back_to_string(self, def test_parse_invoke_connection_response_no_content_type_with_json(self, mock_response): """Test parsing response with no content-type but valid JSON.""" mock_response.status_code = 200 - mock_response.content = json.dumps({"success": True}).encode('utf-8') + mock_response.content = json.dumps({"success": True}).encode("utf-8") mock_response.headers = {"x-request-id": "1234"} mock_response.raise_for_status = Mock() @@ -1240,10 +1277,7 @@ def test_parse_invoke_connection_response_bytes_content(self, mock_response): """Test parsing response with bytes content.""" mock_response.status_code = 200 mock_response.content = b"Binary data response" - mock_response.headers = { - "x-request-id": "1234", - "content-type": "application/octet-stream" - } + mock_response.headers = {"x-request-id": "1234", "content-type": "application/octet-stream"} mock_response.raise_for_status = Mock() result = parse_invoke_connection_response(mock_response) @@ -1269,7 +1303,7 @@ def __repr__(self): connection_url = "https://example.com/endpoint" - with patch('json.dumps', side_effect=TypeError("Object is not JSON serializable")): + with patch("json.dumps", side_effect=TypeError("Object is not JSON serializable")): with self.assertRaises(SkyflowError) as context: construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) @@ -1287,7 +1321,7 @@ def test_construct_invoke_connection_request_headers_generic_exception(self): connection_url = "https://example.com/endpoint" - with patch('skyflow.utils._utils.to_lowercase_keys', side_effect=Exception("Generic error")): + with patch("skyflow.utils._utils.to_lowercase_keys", side_effect=Exception("Generic error")): with self.assertRaises(SkyflowError) as context: construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) @@ -1305,7 +1339,7 @@ def test_construct_invoke_connection_request_body_processing_exception(self): connection_url = "https://example.com/endpoint" - with patch('skyflow.utils._utils.get_data_from_content_type', side_effect=Exception("Body processing error")): + with patch("skyflow.utils._utils.get_data_from_content_type", side_effect=Exception("Body processing error")): with self.assertRaises(SkyflowError) as context: construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) @@ -1344,7 +1378,7 @@ def test_construct_invoke_connection_request_invalid_url_exception(self): connection_url = "https://example.com/endpoint" - with patch('requests.Request') as mock_request_class: + with patch("requests.Request") as mock_request_class: mock_request_instance = Mock() mock_request_instance.prepare.side_effect = Exception("Invalid URL structure") mock_request_class.return_value = mock_request_instance @@ -1352,10 +1386,7 @@ def test_construct_invoke_connection_request_invalid_url_exception(self): with self.assertRaises(SkyflowError) as context: construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) - self.assertEqual( - context.exception.message, - SkyflowMessages.Error.INVALID_URL.value.format(connection_url) - ) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_URL.value.format(connection_url)) self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) def test_construct_invoke_connection_request_prepare_exception(self): @@ -1369,7 +1400,7 @@ def test_construct_invoke_connection_request_prepare_exception(self): connection_url = "https://example.com/endpoint" - with patch('requests.Request') as mock_request_class: + with patch("requests.Request") as mock_request_class: mock_request_instance = Mock() mock_request_instance.prepare.side_effect = Exception("Prepare failed") mock_request_class.return_value = mock_request_instance @@ -1377,10 +1408,7 @@ def test_construct_invoke_connection_request_prepare_exception(self): with self.assertRaises(SkyflowError) as context: construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) - self.assertEqual( - context.exception.message, - SkyflowMessages.Error.INVALID_URL.value.format(connection_url) - ) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_URL.value.format(connection_url)) self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) def test_construct_invoke_connection_request_body_not_dict_raises_error(self): @@ -1400,7 +1428,7 @@ def test_construct_invoke_connection_request_body_not_dict_raises_error(self): self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REQUEST_BODY.value) self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) - @patch('skyflow.utils._utils.validate_invoke_connection_params') + @patch("skyflow.utils._utils.validate_invoke_connection_params") def test_construct_invoke_connection_request_validation_exception(self, mock_validate): """Test that validation exceptions are properly propagated.""" mock_connection_request = Mock() @@ -1419,15 +1447,16 @@ def test_construct_invoke_connection_request_validation_exception(self, mock_val self.assertEqual(context.exception.message, "Validation failed") self.assertEqual(context.exception.http_code, 400) + def test_generate_bearer_token_invalid_token_uri_type(self): creds = { - 'privateKey': 'private_key', - 'clientID': 'client_id', - 'keyID': 'key_id', - 'tokenURI': 12345 # invalid type + "privateKey": "private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": 12345, # invalid type } - with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tmp: + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as tmp: json.dump(creds, tmp) tmp.flush() with self.assertRaises(SkyflowError) as context: @@ -1435,13 +1464,8 @@ def test_generate_bearer_token_invalid_token_uri_type(self): self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) def test_generate_bearer_token_invalid_token_uri_url(self): - creds = { - 'privateKey': 'private_key', - 'clientID': 'client_id', - 'keyID': 'key_id', - 'tokenURI': 'not_a_url' - } - with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tmp: + creds = {"privateKey": "private_key", "clientID": "client_id", "keyID": "key_id", "tokenURI": "not_a_url"} + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as tmp: json.dump(creds, tmp) tmp.flush() with self.assertRaises(SkyflowError) as context: @@ -1450,13 +1474,13 @@ def test_generate_bearer_token_invalid_token_uri_url(self): def test_generate_bearer_token_options_override_token_uri(self): creds = { - 'privateKey': 'private_key', - 'clientID': 'client_id', - 'keyID': 'key_id', - 'tokenURI': 'https://valid-url.com' + "privateKey": "private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": "https://valid-url.com", } options = {"token_uri": "https://another-valid-url.com"} - with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tmp: + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as tmp: json.dump(creds, tmp) tmp.flush() # Patch AuthClient and jwt.encode to avoid real HTTP and signing @@ -1464,32 +1488,22 @@ def test_generate_bearer_token_options_override_token_uri(self): mock_get_signed_jwt.return_value = "signed" with patch("skyflow.service_account._utils.AuthClient") as mock_auth_client: mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value - mock_auth_api.authentication_service_get_auth_token.return_value = type("obj", (), - {"access_token": "token", - "token_type": "bearer"}) + mock_auth_api.authentication_service_get_auth_token.return_value = type( + "obj", (), {"access_token": "token", "token_type": "bearer"} + ) generate_bearer_token(tmp.name, options) args, kwargs = mock_get_signed_jwt.call_args self.assertEqual(args[3], options["token_uri"]) def test_generate_bearer_token_from_creds_invalid_token_uri_type(self): - creds = { - 'privateKey': 'private_key', - 'clientID': 'client_id', - 'keyID': 'key_id', - 'tokenURI': 12345 - } + creds = {"privateKey": "private_key", "clientID": "client_id", "keyID": "key_id", "tokenURI": 12345} creds_str = json.dumps(creds) with self.assertRaises(SkyflowError) as context: generate_bearer_token_from_creds(creds_str) self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) def test_generate_bearer_token_from_creds_invalid_token_uri_url(self): - creds = { - 'privateKey': 'private_key', - 'clientID': 'client_id', - 'keyID': 'key_id', - 'tokenURI': 'not_a_url' - } + creds = {"privateKey": "private_key", "clientID": "client_id", "keyID": "key_id", "tokenURI": "not_a_url"} creds_str = json.dumps(creds) with self.assertRaises(SkyflowError) as context: generate_bearer_token_from_creds(creds_str) @@ -1497,10 +1511,10 @@ def test_generate_bearer_token_from_creds_invalid_token_uri_url(self): def test_generate_bearer_token_from_creds_options_override_token_uri(self): creds = { - 'privateKey': 'private_key', - 'clientID': 'client_id', - 'keyID': 'key_id', - 'tokenURI': 'https://valid-url.com' + "privateKey": "private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": "https://valid-url.com", } options = {"token_uri": "https://another-valid-url.com"} creds_str = json.dumps(creds) @@ -1508,22 +1522,17 @@ def test_generate_bearer_token_from_creds_options_override_token_uri(self): mock_get_signed_jwt.return_value = "signed" with patch("skyflow.service_account._utils.AuthClient") as mock_auth_client: mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value - mock_auth_api.authentication_service_get_auth_token.return_value = type("obj", (), - {"access_token": "token", - "token_type": "bearer"}) + mock_auth_api.authentication_service_get_auth_token.return_value = type( + "obj", (), {"access_token": "token", "token_type": "bearer"} + ) generate_bearer_token_from_creds(creds_str, options) args, kwargs = mock_get_signed_jwt.call_args self.assertEqual(args[3], options["token_uri"]) def test_generate_signed_data_tokens_invalid_token_uri_type(self): - creds = { - 'privateKey': 'private_key', - 'clientID': 'client_id', - 'keyID': 'key_id', - 'tokenURI': 12345 - } + creds = {"privateKey": "private_key", "clientID": "client_id", "keyID": "key_id", "tokenURI": 12345} options = {"data_tokens": ["token1"]} - with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tmp: + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as tmp: json.dump(creds, tmp) tmp.flush() with self.assertRaises(SkyflowError) as context: @@ -1531,14 +1540,9 @@ def test_generate_signed_data_tokens_invalid_token_uri_type(self): self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) def test_generate_signed_data_tokens_invalid_token_uri_url(self): - creds = { - 'privateKey': 'private_key', - 'clientID': 'client_id', - 'keyID': 'key_id', - 'tokenURI': 'not_a_url' - } + creds = {"privateKey": "private_key", "clientID": "client_id", "keyID": "key_id", "tokenURI": "not_a_url"} options = {"data_tokens": ["token1"]} - with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tmp: + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as tmp: json.dump(creds, tmp) tmp.flush() with self.assertRaises(SkyflowError) as context: @@ -1546,12 +1550,7 @@ def test_generate_signed_data_tokens_invalid_token_uri_url(self): self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) def test_generate_signed_data_tokens_from_creds_invalid_token_uri_type(self): - creds = { - 'privateKey': 'private_key', - 'clientID': 'client_id', - 'keyID': 'key_id', - 'tokenURI': 12345 - } + creds = {"privateKey": "private_key", "clientID": "client_id", "keyID": "key_id", "tokenURI": 12345} options = {"data_tokens": ["token1"]} creds_str = json.dumps(creds) with self.assertRaises(SkyflowError) as context: @@ -1559,12 +1558,7 @@ def test_generate_signed_data_tokens_from_creds_invalid_token_uri_type(self): self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) def test_generate_signed_data_tokens_from_creds_invalid_token_uri_url(self): - creds = { - 'privateKey': 'private_key', - 'clientID': 'client_id', - 'keyID': 'key_id', - 'tokenURI': 'not_a_url' - } + creds = {"privateKey": "private_key", "clientID": "client_id", "keyID": "key_id", "tokenURI": "not_a_url"} options = {"data_tokens": ["token1"]} creds_str = json.dumps(creds) with self.assertRaises(SkyflowError) as context: @@ -1573,34 +1567,36 @@ def test_generate_signed_data_tokens_from_creds_invalid_token_uri_url(self): def test_generate_signed_data_tokens_options_override_token_uri(self): creds = { - 'privateKey': 'private_key', - 'clientID': 'client_id', - 'keyID': 'key_id', - 'tokenURI': 'https://valid-url.com' + "privateKey": "private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": "https://valid-url.com", } options = {"data_tokens": ["token1"], "token_uri": "https://another-valid-url.com"} - with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tmp: + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as tmp: json.dump(creds, tmp) tmp.flush() with patch("jwt.encode") as mock_jwt_encode: mock_jwt_encode.return_value = "signed" result = generate_signed_data_tokens(tmp.name, options) - self.assertIsInstance(result, tuple) - self.assertEqual(result[0], "token1") - self.assertEqual(result[1], "signed_token_signed") + self.assertIsInstance(result, list) + self.assertEqual(len(result), 1) + self.assertEqual(result[0][0], "token1") + self.assertEqual(result[0][1], "signed_token_signed") def test_generate_signed_data_tokens_from_creds_options_override_token_uri(self): creds = { - 'privateKey': 'private_key', - 'clientID': 'client_id', - 'keyID': 'key_id', - 'tokenURI': 'https://valid-url.com' + "privateKey": "private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": "https://valid-url.com", } options = {"data_tokens": ["token1"], "token_uri": "https://another-valid-url.com"} creds_str = json.dumps(creds) with patch("jwt.encode") as mock_jwt_encode: mock_jwt_encode.return_value = "signed" result = generate_signed_data_tokens_from_creds(creds_str, options) - self.assertIsInstance(result, tuple) - self.assertEqual(result[0], "token1") - self.assertEqual(result[1], "signed_token_signed") + self.assertIsInstance(result, list) + self.assertEqual(len(result), 1) + self.assertEqual(result[0][0], "token1") + self.assertEqual(result[0][1], "signed_token_signed") diff --git a/tests/utils/validations/test__validations.py b/tests/utils/validations/test__validations.py index 8de9b219..ec4d5bec 100644 --- a/tests/utils/validations/test__validations.py +++ b/tests/utils/validations/test__validations.py @@ -12,13 +12,15 @@ validate_insert_request, validate_delete_request, validate_query_request, validate_get_detect_run_request, validate_get_request, validate_update_request, validate_detokenize_request, validate_tokenize_request, validate_invoke_connection_params, - validate_deidentify_text_request, validate_reidentify_text_request, validate_deidentify_file_request + validate_deidentify_text_request, validate_reidentify_text_request, validate_deidentify_file_request, + validate_file_upload_request ) from skyflow.utils import SkyflowMessages from skyflow.utils.enums import DetectEntities, RedactionType from skyflow.vault.data import GetRequest, UpdateRequest from skyflow.vault.detect import DeidentifyTextRequest, Transformations, DateTransformation, ReidentifyTextRequest, \ - FileInput, DeidentifyFileRequest + FileInput, DeidentifyFileRequest, Bleep +from skyflow.vault.data._file_upload_request import FileUploadRequest from skyflow.vault.tokens import DetokenizeRequest from skyflow.vault.connection._invoke_connection_request import InvokeConnectionRequest @@ -217,6 +219,18 @@ def test_validate_update_vault_config_invalid_cluster_id(self): validate_update_vault_config(self.logger, config) self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CLUSTER_ID.value.format("vault123")) + def test_validate_update_vault_config_missing_credentials(self): + config = { + "vault_id": "vault123", + "cluster_id": "cluster123", + } + with self.assertRaises(SkyflowError) as context: + validate_update_vault_config(self.logger, config) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format("vault", "vault123") + ) + def test_validate_connection_config_valid(self): config = { "connection_id": "conn123", @@ -250,6 +264,18 @@ def test_validate_connection_config_empty_connection_id(self): validate_connection_config(self.logger, config) self.assertEqual(context.exception.message, SkyflowMessages.Error.EMPTY_CONNECTION_ID.value) + def test_validate_connection_config_missing_credentials(self): + config = { + "connection_id": "conn123", + "connection_url": "https://example.com", + } + with self.assertRaises(SkyflowError) as context: + validate_connection_config(self.logger, config) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format("connection", "conn123") + ) + def test_validate_update_connection_config_valid(self): config = { "connection_id": "conn123", @@ -1163,3 +1189,279 @@ def test_validate_update_vault_config_with_invalid_token_uri_url(self): with self.assertRaises(SkyflowError) as context: validate_update_vault_config(self.logger, config) self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + # --- validate_file_from_request --- + + def test_validate_file_from_request_none_input(self): + with self.assertRaises(SkyflowError) as context: + validate_file_from_request(None) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_FILE_INPUT.value) + + def test_validate_file_from_request_file_without_name_attr(self): + file_obj = MagicMock(spec=[]) # no attributes at all + file_input = MagicMock() + file_input.file = file_obj + file_input.file_path = None + with self.assertRaises(SkyflowError) as context: + validate_file_from_request(file_input) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_FILE_TYPE.value) + + def test_validate_file_from_request_file_with_empty_name(self): + file_obj = MagicMock() + file_obj.name = " " # whitespace-only name + file_input = MagicMock() + file_input.file = file_obj + file_input.file_path = None + with self.assertRaises(SkyflowError) as context: + validate_file_from_request(file_input) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_FILE_TYPE.value) + + def test_validate_file_from_request_extension_only_name(self): + file_obj = MagicMock() + # A trailing-slash path gives os.path.basename() == "", so splitext returns ("", "") + file_obj.name = "/some/directory/" + file_input = MagicMock() + file_input.file = file_obj + file_input.file_path = None + with self.assertRaises(SkyflowError) as context: + validate_file_from_request(file_input) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_FILE_NAME.value) + + def test_validate_file_from_request_empty_string_file_path(self): + file_input = MagicMock() + file_input.file = None + file_input.file_path = "" # empty string — has_file_path=True, so goes to elif branch + with self.assertRaises(SkyflowError) as context: + validate_file_from_request(file_input) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_DEIDENTIFY_FILE_PATH.value) + + # --- validate_deidentify_file_request bleep sub-fields --- + + def test_validate_deidentify_file_request_invalid_bleep_type(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest(file=file_input, bleep="not_a_bleep") + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_BLEEP_TYPE.value) + + def test_validate_deidentify_file_request_invalid_bleep_gain(self): + file_input = FileInput(file_path=self.temp_file_path) + bleep = Bleep(gain="loud") + request = DeidentifyFileRequest(file=file_input, bleep=bleep) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_BLEEP_GAIN.value) + + def test_validate_deidentify_file_request_invalid_bleep_frequency(self): + file_input = FileInput(file_path=self.temp_file_path) + bleep = Bleep(frequency="high") + request = DeidentifyFileRequest(file=file_input, bleep=bleep) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_BLEEP_FREQUENCY.value) + + def test_validate_deidentify_file_request_invalid_bleep_start_padding(self): + file_input = FileInput(file_path=self.temp_file_path) + bleep = Bleep(start_padding="early") + request = DeidentifyFileRequest(file=file_input, bleep=bleep) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_BLEEP_START_PADDING.value) + + def test_validate_deidentify_file_request_invalid_bleep_stop_padding(self): + file_input = FileInput(file_path=self.temp_file_path) + bleep = Bleep(stop_padding="late") + request = DeidentifyFileRequest(file=file_input, bleep=bleep) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_BLEEP_STOP_PADDING.value) + + # --- validate_deidentify_file_request output_directory --- + + def test_validate_deidentify_file_request_invalid_output_directory_type(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest(file=file_input, output_directory=123) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_OUTPUT_DIRECTORY_VALUE.value) + + def test_validate_deidentify_file_request_output_directory_not_found(self): + file_input = FileInput(file_path=self.temp_file_path) + nonexistent = "/tmp/skyflow_nonexistent_dir_12345" + request = DeidentifyFileRequest(file=file_input, output_directory=nonexistent) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.OUTPUT_DIRECTORY_NOT_FOUND.value.format(nonexistent) + ) + + def test_validate_deidentify_file_request_valid_output_directory(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest(file=file_input, output_directory=self.temp_dir_path) + validate_deidentify_file_request(self.logger, request) + + # --- validate_file_upload_request --- + + def test_validate_file_upload_request_none(self): + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, None) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TABLE_VALUE.value) + + def test_validate_file_upload_request_none_table(self): + request = MagicMock() + request.table = None + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TABLE_VALUE.value) + + def test_validate_file_upload_request_empty_table(self): + request = MagicMock() + request.table = " " + request.column_name = "file_col" + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.EMPTY_TABLE_VALUE.value) + + def test_validate_file_upload_request_none_column_name(self): + request = MagicMock() + request.table = "test_table" + request.skyflow_id = None + request.column_name = None + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, request) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.INVALID_FILE_COLUMN_NAME.value.format(type(None)) + ) + + def test_validate_file_upload_request_empty_column_name(self): + request = MagicMock() + request.table = "test_table" + request.skyflow_id = None + request.column_name = "" + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, request) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.INVALID_FILE_COLUMN_NAME.value.format(type("")) + ) + + def test_validate_file_upload_request_empty_skyflow_id(self): + request = FileUploadRequest( + table="test_table", + column_name="file_col", + skyflow_id=" ", + file_path=self.temp_file_path + ) + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, request) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.EMPTY_SKYFLOW_ID.value.format("FILE_UPLOAD") + ) + + def test_validate_file_upload_request_invalid_file_object_seek(self): + file_obj = MagicMock() + file_obj.seek.side_effect = OSError("seek failed") + request = FileUploadRequest( + table="test_table", + column_name="file_col", + file_object=file_obj + ) + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_FILE_OBJECT.value) + + def test_validate_file_upload_request_valid_file_path(self): + request = FileUploadRequest( + table="test_table", + column_name="file_col", + file_path=self.temp_file_path + ) + validate_file_upload_request(self.logger, request) + + def test_validate_file_upload_request_invalid_file_path(self): + request = FileUploadRequest( + table="test_table", + column_name="file_col", + file_path="/nonexistent/path/file.txt" + ) + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_FILE_PATH.value) + + def test_validate_file_upload_request_valid_base64(self): + import base64 + encoded = base64.b64encode(b"file content").decode("utf-8") + request = FileUploadRequest( + table="test_table", + column_name="file_col", + base64=encoded, + file_name="sample.txt" + ) + validate_file_upload_request(self.logger, request) + + def test_validate_file_upload_request_base64_without_file_name(self): + import base64 + encoded = base64.b64encode(b"file content").decode("utf-8") + request = FileUploadRequest( + table="test_table", + column_name="file_col", + base64=encoded + ) + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_FILE_NAME.value) + + def test_validate_file_upload_request_invalid_base64_string(self): + request = FileUploadRequest( + table="test_table", + column_name="file_col", + base64="not-valid-base64!!!", + file_name="sample.txt" + ) + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_BASE64_STRING.value) + + def test_validate_file_upload_request_valid_file_object(self): + with open(self.temp_file_path, "rb") as f: + request = FileUploadRequest( + table="test_table", + column_name="file_col", + file_object=f + ) + validate_file_upload_request(self.logger, request) + + def test_validate_file_upload_request_missing_file_source(self): + request = FileUploadRequest( + table="test_table", + column_name="file_col" + ) + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.MISSING_FILE_SOURCE.value) + + # --- validate_deidentify_text_request transformations --- + + def test_validate_deidentify_text_request_invalid_transformations(self): + request = DeidentifyTextRequest( + text="test text", + transformations="invalid_type" + ) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_text_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TRANSFORMATIONS.value) + + # --- validate_reidentify_text_request masked_entities --- + + def test_validate_reidentify_text_request_invalid_masked_entities(self): + request = ReidentifyTextRequest( + text="test text", + masked_entities="invalid_type" + ) + with self.assertRaises(SkyflowError) as context: + validate_reidentify_text_request(self.logger, request) + self.assertEqual(context.exception.message, + SkyflowMessages.Error.INVALID_MASKED_ENTITIES_IN_REIDENTIFY.value) diff --git a/tests/vault/client/test__client.py b/tests/vault/client/test__client.py index 6fa31e67..75826128 100644 --- a/tests/vault/client/test__client.py +++ b/tests/vault/client/test__client.py @@ -15,11 +15,19 @@ } CREDENTIALS_WITH_API_KEY = {"api_key": "dummy_api_key"} +CREDENTIALS_WITH_TOKEN = {"token": "dummy_static_token"} +CREDENTIALS_WITH_PATH = {"path": "/some/path/credentials.json"} +CREDENTIALS_WITH_STRING = {"credentials_string": '{"clientID": "x"}'} + class TestVaultClient(unittest.TestCase): def setUp(self): self.vault_client = VaultClient(CONFIG) + # ------------------------------------------------------------------ # + # Basic setters / getters # + # ------------------------------------------------------------------ # + def test_set_common_skyflow_credentials(self): credentials = {"api_key": "dummy_api_key"} self.vault_client.set_common_skyflow_credentials(credentials) @@ -31,173 +39,289 @@ def test_set_logger(self): self.assertEqual(self.vault_client.get_log_level(), "INFO") self.assertEqual(self.vault_client.get_logger(), mock_logger) + def test_get_vault_id(self): + self.assertEqual(self.vault_client.get_vault_id(), CONFIG["vault_id"]) + + def test_get_config(self): + self.assertEqual(self.vault_client.get_config(), CONFIG) + + def test_get_common_skyflow_credentials(self): + credentials = {"api_key": "dummy_api_key"} + self.vault_client.set_common_skyflow_credentials(credentials) + self.assertEqual(self.vault_client.get_common_skyflow_credentials(), credentials) + + def test_get_log_level(self): + self.vault_client.set_logger("DEBUG", MagicMock()) + self.assertEqual(self.vault_client.get_log_level(), "DEBUG") + + def test_get_logger(self): + mock_logger = MagicMock() + self.vault_client.set_logger("INFO", mock_logger) + self.assertEqual(self.vault_client.get_logger(), mock_logger) + + # ------------------------------------------------------------------ # + # initialize_client_configuration — first call (slow path) # + # ------------------------------------------------------------------ # + @patch("skyflow.vault.client.client.get_credentials") @patch("skyflow.vault.client.client.get_vault_url") @patch("skyflow.vault.client.client.VaultClient.initialize_api_client") - def test_initialize_client_configuration(self, mock_init_api_client, mock_get_vault_url, mock_get_credentials): - mock_get_credentials.return_value = (CREDENTIALS_WITH_API_KEY) + def test_initialize_client_configuration_first_call( + self, mock_init_api_client, mock_get_vault_url, mock_get_credentials + ): + mock_get_credentials.return_value = CREDENTIALS_WITH_API_KEY mock_get_vault_url.return_value = "https://test-vault-url.com" self.vault_client.initialize_client_configuration() - mock_get_credentials.assert_called_once_with(CONFIG["credentials"], None, logger=None) - mock_get_vault_url.assert_called_once_with(CONFIG["cluster_id"], CONFIG["env"], CONFIG["vault_id"], logger=None) + mock_get_credentials.assert_called_once_with( + CONFIG["credentials"], None, logger=None + ) + mock_get_vault_url.assert_called_once_with( + CONFIG["cluster_id"], CONFIG["env"], CONFIG["vault_id"], logger=None + ) mock_init_api_client.assert_called_once() - @patch("skyflow.vault.client.client.Skyflow") - def test_initialize_api_client(self, mock_api_client): - self.vault_client.initialize_api_client("https://test-vault-url.com", "dummy_token") - mock_api_client.assert_called_once_with(base_url="https://test-vault-url.com", token="dummy_token") + # ------------------------------------------------------------------ # + # initialize_client_configuration — fast path (static token) # + # ------------------------------------------------------------------ # - def test_get_records_api(self): - self.vault_client._VaultClient__api_client = MagicMock() - self.vault_client._VaultClient__api_client.records = MagicMock() - records_api = self.vault_client.get_records_api() - self.assertIsNotNone(records_api) + @patch("skyflow.vault.client.client.get_credentials") + @patch("skyflow.vault.client.client.get_vault_url") + @patch("skyflow.vault.client.client.VaultClient.initialize_api_client") + def test_initialize_client_configuration_fast_path_api_key( + self, mock_init_api_client, mock_get_vault_url, mock_get_credentials + ): + """Once initialized with api_key, subsequent calls skip all work.""" + mock_get_credentials.return_value = CREDENTIALS_WITH_API_KEY + mock_get_vault_url.return_value = "https://test-vault-url.com" + # Side-effect simulates initialize_api_client actually setting __api_client + mock_init_api_client.side_effect = lambda *_: setattr( + self.vault_client, "_VaultClient__api_client", MagicMock() + ) - def test_get_tokens_api(self): + self.vault_client.initialize_client_configuration() # first call — slow path + mock_get_credentials.reset_mock() + mock_get_vault_url.reset_mock() + mock_init_api_client.reset_mock() + + self.vault_client.initialize_client_configuration() # second call — fast path + + mock_get_credentials.assert_not_called() + mock_get_vault_url.assert_not_called() + mock_init_api_client.assert_not_called() + + @patch("skyflow.vault.client.client.get_credentials") + @patch("skyflow.vault.client.client.get_vault_url") + @patch("skyflow.vault.client.client.VaultClient.initialize_api_client") + def test_initialize_client_configuration_fast_path_static_token( + self, mock_init_api_client, mock_get_vault_url, mock_get_credentials + ): + """Once initialized with a static token, subsequent calls skip all work.""" + mock_get_credentials.return_value = CREDENTIALS_WITH_TOKEN + mock_get_vault_url.return_value = "https://test-vault-url.com" + mock_init_api_client.side_effect = lambda *_: setattr( + self.vault_client, "_VaultClient__api_client", MagicMock() + ) + + self.vault_client.initialize_client_configuration() + mock_get_credentials.reset_mock() + mock_get_vault_url.reset_mock() + mock_init_api_client.reset_mock() + + self.vault_client.initialize_client_configuration() + + mock_get_credentials.assert_not_called() + mock_get_vault_url.assert_not_called() + mock_init_api_client.assert_not_called() + + # ------------------------------------------------------------------ # + # initialize_client_configuration — fast path (service account) # + # ------------------------------------------------------------------ # + + @patch("skyflow.vault.client.client.is_expired", return_value=False) + @patch("skyflow.vault.client.client.get_credentials") + @patch("skyflow.vault.client.client.get_vault_url") + @patch("skyflow.vault.client.client.VaultClient.initialize_api_client") + def test_initialize_client_configuration_fast_path_valid_sa_token( + self, mock_init_api_client, mock_get_vault_url, mock_get_credentials, mock_is_expired + ): + """Service account with a still-valid token skips get_bearer_token entirely.""" + mock_get_credentials.return_value = CREDENTIALS_WITH_PATH + mock_get_vault_url.return_value = "https://test-vault-url.com" + + # Seed the cached bearer token as if first call already ran self.vault_client._VaultClient__api_client = MagicMock() - self.vault_client._VaultClient__api_client.tokens = MagicMock() - tokens_api = self.vault_client.get_tokens_api() - self.assertIsNotNone(tokens_api) + self.vault_client._VaultClient__is_static_token = False + self.vault_client._VaultClient__bearer_token = "cached_sa_token" + self.vault_client._VaultClient__credentials = CREDENTIALS_WITH_PATH - def test_get_query_api(self): + self.vault_client.initialize_client_configuration() + + mock_get_credentials.assert_not_called() + mock_get_vault_url.assert_not_called() + mock_init_api_client.assert_not_called() + + # ------------------------------------------------------------------ # + # initialize_client_configuration — token expiry (no client reinit) # + # ------------------------------------------------------------------ # + + @patch("skyflow.vault.client.client.generate_bearer_token", return_value=("new_sa_token", None)) + @patch("skyflow.vault.client.client.is_expired", return_value=True) + @patch("skyflow.vault.client.client.get_credentials") + @patch("skyflow.vault.client.client.get_vault_url") + @patch("skyflow.vault.client.client.VaultClient.initialize_api_client") + def test_initialize_client_configuration_expired_token_no_reinit( + self, mock_init_api_client, mock_get_vault_url, mock_get_credentials, + mock_is_expired, mock_generate_bearer_token + ): + """Expired service account token is regenerated in-place; httpx client is NOT recreated.""" + mock_get_credentials.return_value = CREDENTIALS_WITH_PATH + mock_get_vault_url.return_value = "https://test-vault-url.com" + + # Client already initialized — simulate warm state with an expired token self.vault_client._VaultClient__api_client = MagicMock() - self.vault_client._VaultClient__api_client.query = MagicMock() - query_api = self.vault_client.get_query_api() - self.assertIsNotNone(query_api) + self.vault_client._VaultClient__is_static_token = False + self.vault_client._VaultClient__bearer_token = "expired_sa_token" + self.vault_client._VaultClient__credentials = CREDENTIALS_WITH_PATH - def test_get_vault_id(self): - self.assertEqual(self.vault_client.get_vault_id(), CONFIG["vault_id"]) + self.vault_client.initialize_client_configuration() - @patch("skyflow.vault.client.client.generate_bearer_token") - @patch("skyflow.vault.client.client.generate_bearer_token_from_creds") - @patch("skyflow.vault.client.client.log_info") - def test_get_bearer_token_with_api_key(self, mock_log_info, mock_generate_bearer_token, - mock_generate_bearer_token_from_creds): - token = self.vault_client.get_bearer_token(CREDENTIALS_WITH_API_KEY) - self.assertEqual(token, CREDENTIALS_WITH_API_KEY["api_key"]) - - def test_update_config(self): - new_config = {"credentials": "new_credentials"} - self.vault_client.update_config(new_config) - self.assertTrue(self.vault_client._VaultClient__is_config_updated) - self.assertEqual(self.vault_client.get_config()["credentials"], "new_credentials") + # Token was regenerated + mock_generate_bearer_token.assert_called_once() + self.assertEqual( + self.vault_client._VaultClient__bearer_token, "new_sa_token" + ) + # httpx client was NOT recreated + mock_init_api_client.assert_not_called() - def test_get_config(self): - self.assertEqual(self.vault_client.get_config(), CONFIG) + # ------------------------------------------------------------------ # + # initialize_client_configuration — config update forces reinit # + # ------------------------------------------------------------------ # - def test_get_common_skyflow_credentials(self): - credentials = {"api_key": "dummy_api_key"} - self.vault_client.set_common_skyflow_credentials(credentials) - self.assertEqual(self.vault_client.get_common_skyflow_credentials(), credentials) + @patch("skyflow.vault.client.client.get_credentials") + @patch("skyflow.vault.client.client.get_vault_url") + @patch("skyflow.vault.client.client.VaultClient.initialize_api_client") + def test_initialize_client_configuration_reinit_after_update_config( + self, mock_init_api_client, mock_get_vault_url, mock_get_credentials + ): + """update_config() marks the client stale; next call must recreate it.""" + mock_get_credentials.return_value = CREDENTIALS_WITH_API_KEY + mock_get_vault_url.return_value = "https://test-vault-url.com" - def test_get_log_level(self): - log_level = "DEBUG" - self.vault_client.set_logger(log_level, MagicMock()) - self.assertEqual(self.vault_client.get_log_level(), log_level) + # Simulate already-initialized client + self.vault_client._VaultClient__api_client = MagicMock() + self.vault_client._VaultClient__is_static_token = True - def test_get_logger(self): - mock_logger = MagicMock() - self.vault_client.set_logger("INFO", mock_logger) - self.assertEqual(self.vault_client.get_logger(), mock_logger) + self.vault_client.update_config({"cluster_id": "new_cluster"}) + self.vault_client.initialize_client_configuration() - @patch("skyflow.vault.client.client.is_expired") - @patch("skyflow.vault.client.client.generate_bearer_token") - def test_get_bearer_token_expired_token_raises_error(self, mock_generate_bearer_token, mock_is_expired): - """Test that expired token raises SkyflowError.""" - credentials = {"path": "/path/to/credentials.json"} - mock_generate_bearer_token.return_value = ("expired_token", None) - mock_is_expired.return_value = True + mock_get_credentials.assert_called_once() + mock_get_vault_url.assert_called_once() + mock_init_api_client.assert_called_once() - with self.assertRaises(SkyflowError) as context: - self.vault_client.get_bearer_token(credentials) + # ------------------------------------------------------------------ # + # initialize_api_client — lambda token provider # + # ------------------------------------------------------------------ # - self.assertEqual(context.exception.message, SkyflowMessages.Error.EXPIRED_TOKEN.value) - self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) - self.assertTrue(self.vault_client._VaultClient__is_config_updated) + @patch("skyflow.vault.client.client.Skyflow") + def test_initialize_api_client_passes_callable_token(self, mock_skyflow): + """initialize_api_client must pass a callable (lambda) as token, not a string.""" + self.vault_client.initialize_api_client("https://test-vault-url.com", "initial_token") - @patch("skyflow.vault.client.client.is_expired") - @patch("skyflow.vault.client.client.generate_bearer_token_from_creds") - def test_get_bearer_token_expired_token_from_creds_string_raises_error(self, mock_generate_bearer_token_from_creds, mock_is_expired): - """Test that expired token from credentials string raises SkyflowError.""" - credentials = {"credentials_string": '{"key": "value"}'} - mock_generate_bearer_token_from_creds.return_value = ("expired_token", None) - mock_is_expired.return_value = True + args, kwargs = mock_skyflow.call_args + self.assertEqual(kwargs["base_url"], "https://test-vault-url.com") + self.assertTrue(callable(kwargs["token"]), "token must be a callable (lambda)") - with self.assertRaises(SkyflowError) as context: - self.vault_client.get_bearer_token(credentials) + @patch("skyflow.vault.client.client.Skyflow") + def test_initialize_api_client_lambda_returns_cached_bearer_token(self, mock_skyflow): + """Lambda returns __bearer_token when it is set (interceptor behaviour).""" + self.vault_client._VaultClient__bearer_token = "refreshed_token" + self.vault_client.initialize_api_client("https://test-vault-url.com", "initial_token") - self.assertEqual(context.exception.message, SkyflowMessages.Error.EXPIRED_TOKEN.value) - self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) - self.assertTrue(self.vault_client._VaultClient__is_config_updated) + _, kwargs = mock_skyflow.call_args + self.assertEqual(kwargs["token"](), "refreshed_token") - @patch("skyflow.vault.client.client.is_expired") - @patch("skyflow.vault.client.client.generate_bearer_token") - def test_get_bearer_token_reuses_valid_token(self, mock_generate_bearer_token, mock_is_expired): - """Test that valid bearer token is reused.""" - credentials = {"path": "/path/to/credentials.json"} - mock_generate_bearer_token.return_value = ("valid_token", None) - mock_is_expired.return_value = False - - token1 = self.vault_client.get_bearer_token(credentials) - self.assertEqual(token1, "valid_token") - mock_generate_bearer_token.assert_called_once() + @patch("skyflow.vault.client.client.Skyflow") + def test_initialize_api_client_lambda_falls_back_to_initial_token(self, mock_skyflow): + """Lambda falls back to the initial token when __bearer_token is None.""" + self.vault_client._VaultClient__bearer_token = None + self.vault_client.initialize_api_client("https://test-vault-url.com", "initial_token") + + _, kwargs = mock_skyflow.call_args + self.assertEqual(kwargs["token"](), "initial_token") + + # ------------------------------------------------------------------ # + # get_bearer_token # + # ------------------------------------------------------------------ # + + def test_get_bearer_token_with_api_key(self): + result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_API_KEY) + self.assertEqual(result, "dummy_api_key") + + def test_get_bearer_token_with_static_token(self): + result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_TOKEN) + self.assertEqual(result, "dummy_static_token") + + @patch("skyflow.vault.client.client.generate_bearer_token", return_value=("sa_token", None)) + def test_get_bearer_token_generates_from_path_on_first_call(self, mock_generate): + result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_PATH) + mock_generate.assert_called_once() + self.assertEqual(result, "sa_token") + self.assertEqual(self.vault_client._VaultClient__bearer_token, "sa_token") + + @patch("skyflow.vault.client.client.generate_bearer_token_from_creds", return_value=("sa_token_str", None)) + @patch("skyflow.vault.client.client.log_info") + def test_get_bearer_token_generates_from_credentials_string(self, mock_log, mock_generate): + result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_STRING) + mock_generate.assert_called_once() + self.assertEqual(result, "sa_token_str") - token2 = self.vault_client.get_bearer_token(credentials) - self.assertEqual(token2, "valid_token") - mock_generate_bearer_token.assert_called_once() + @patch("skyflow.vault.client.client.generate_bearer_token", return_value=("new_token", None)) + @patch("skyflow.vault.client.client.is_expired", return_value=True) + @patch("skyflow.vault.client.client.log_info") + def test_get_bearer_token_regenerates_on_expiry(self, mock_log, mock_is_expired, mock_generate): + """Expired token is regenerated silently — no exception raised.""" + self.vault_client._VaultClient__bearer_token = "expired_token" + result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_PATH) + mock_generate.assert_called_once() + self.assertEqual(result, "new_token") - @patch("skyflow.vault.client.client.is_expired") @patch("skyflow.vault.client.client.generate_bearer_token") - def test_get_bearer_token_regenerates_after_config_update(self, mock_generate_bearer_token, mock_is_expired): - """Test that bearer token is regenerated after config update.""" - credentials = {"path": "/path/to/credentials.json"} - mock_generate_bearer_token.side_effect = [("first_token", None), ("second_token", None)] - mock_is_expired.return_value = False + @patch("skyflow.vault.client.client.is_expired", return_value=False) + @patch("skyflow.vault.client.client.log_info") + def test_get_bearer_token_reuses_valid_cached_token(self, mock_log, mock_is_expired, mock_generate): + """Valid cached token is reused without calling generate_bearer_token.""" + self.vault_client._VaultClient__bearer_token = "valid_token" + result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_PATH) + mock_generate.assert_not_called() + self.assertEqual(result, "valid_token") + + # ------------------------------------------------------------------ # + # update_config # + # ------------------------------------------------------------------ # + + def test_update_config_sets_flag(self): + self.vault_client.update_config({"credentials": "new_credentials"}) + self.assertTrue(self.vault_client._VaultClient__is_config_updated) + self.assertEqual(self.vault_client.get_config()["credentials"], "new_credentials") + + # ------------------------------------------------------------------ # + # API accessor stubs # + # ------------------------------------------------------------------ # + + def test_get_records_api(self): + self.vault_client._VaultClient__api_client = MagicMock() + self.assertIsNotNone(self.vault_client.get_records_api()) - token1 = self.vault_client.get_bearer_token(credentials) - self.assertEqual(token1, "first_token") + def test_get_tokens_api(self): + self.vault_client._VaultClient__api_client = MagicMock() + self.assertIsNotNone(self.vault_client.get_tokens_api()) - self.vault_client.update_config({"new_key": "new_value"}) + def test_get_query_api(self): + self.vault_client._VaultClient__api_client = MagicMock() + self.assertIsNotNone(self.vault_client.get_query_api()) - token2 = self.vault_client.get_bearer_token(credentials) - self.assertEqual(token2, "second_token") - self.assertEqual(mock_generate_bearer_token.call_count, 2) - @patch("skyflow.vault.client.client.is_expired") - @patch("skyflow.vault.client.client.generate_bearer_token_from_creds") - @patch("skyflow.vault.client.client.log_info") - def test_get_bearer_token_with_credentials_string(self, mock_log_info, mock_generate_bearer_token_from_creds, mock_is_expired): - """Test get_bearer_token with credentials_string.""" - credentials = {"credentials_string": '{"clientID": "test", "clientName": "test"}'} - mock_generate_bearer_token_from_creds.return_value = ("token_from_creds", None) - mock_is_expired.return_value = False - - token = self.vault_client.get_bearer_token(credentials) - - self.assertEqual(token, "token_from_creds") - mock_generate_bearer_token_from_creds.assert_called_once() - mock_log_info.assert_called_with( - SkyflowMessages.Info.GENERATE_BEARER_TOKEN_FROM_CREDENTIALS_STRING_TRIGGERED.value, - None - ) - def test_get_bearer_token_with_token(self): - credentials = {"token": "dummy_token"} - token = self.vault_client.get_bearer_token(credentials) - self.assertEqual(token, "dummy_token") - - def test_get_bearer_token_with_token_uri_in_credentials(self): - credentials = { - "path": "dummy_path", - "token_uri": "https://valid-url.com" - } - with patch("skyflow.vault.client.client.generate_bearer_token") as mock_generate_bearer_token, \ - patch("skyflow.vault.client.client.is_expired", return_value=False): - mock_generate_bearer_token.return_value = ("bearer_token", "bearer") - token = self.vault_client.get_bearer_token(credentials) - mock_generate_bearer_token.assert_called_once() - args, kwargs = mock_generate_bearer_token.call_args - self.assertIn("token_uri", args[1]) - self.assertEqual(args[1]["token_uri"], "https://valid-url.com") - self.assertEqual(token, "bearer_token") +if __name__ == "__main__": + unittest.main() diff --git a/tests/vault/controller/test__connection.py b/tests/vault/controller/test__connection.py index 35a13716..f073264c 100644 --- a/tests/vault/controller/test__connection.py +++ b/tests/vault/controller/test__connection.py @@ -1,9 +1,11 @@ +import json import unittest from unittest.mock import Mock, patch, MagicMock import requests from skyflow.error import SkyflowError from skyflow.utils import SkyflowMessages, parse_invoke_connection_response -from skyflow.utils.enums import RequestMethod +from skyflow.utils._utils import get_data_from_content_type, construct_invoke_connection_request +from skyflow.utils.enums import RequestMethod, ContentType from skyflow.utils._version import SDK_VERSION from skyflow.vault.connection import InvokeConnectionRequest from skyflow.vault.controller import Connection @@ -146,8 +148,9 @@ def test_invoke_request_error(self, mock_send, mock_get_credentials): with self.assertRaises(SkyflowError) as context: self.connection.invoke(request) - - self.assertEqual(context.exception.message, ERROR_RESPONSE_CONTENT) + + expected_message = SkyflowMessages.Error.API_ERROR.value.format(FAILURE_STATUS_CODE) + self.assertEqual(context.exception.message, expected_message) self.assertEqual(context.exception.http_code, FAILURE_STATUS_CODE) self.assertEqual(context.exception.request_id, "test-request-id") @@ -290,5 +293,383 @@ def test_invoke_construct_request_called(self, mock_construct, mock_get_credenti ) +class TestGetDataFromContentType(unittest.TestCase): + """Tests for get_data_from_content_type covering all supported content types.""" + + DATA = {'key': 'value', 'num': 42} + + # ── JSON ────────────────────────────────────────────────────────────────── + def test_json_content_type_returns_json_string(self): + data, files = get_data_from_content_type(self.DATA, ContentType.JSON.value) + self.assertEqual(data, json.dumps(self.DATA)) + self.assertEqual(files, {}) + + # ── URL-encoded ─────────────────────────────────────────────────────────── + def test_urlencoded_content_type_returns_encoded_string(self): + data, files = get_data_from_content_type({'k': 'v'}, ContentType.URLENCODED.value) + self.assertIn('k=v', data) + self.assertEqual(files, {}) + + def test_urlencoded_nested_dict(self): + payload = {'a': {'b': 'c'}} + data, files = get_data_from_content_type(payload, ContentType.URLENCODED.value) + self.assertIsInstance(data, str) + self.assertIn('c', data) + self.assertEqual(files, {}) + + # ── Form-data ───────────────────────────────────────────────────────────── + def test_formdata_content_type_returns_files_dict(self): + data, files = get_data_from_content_type({'f1': 'v1', 'f2': 'v2'}, ContentType.FORMDATA.value) + self.assertIsNone(data) + self.assertEqual(files, {'f1': (None, 'v1'), 'f2': (None, 'v2')}) + + def test_formdata_converts_values_to_str(self): + data, files = get_data_from_content_type({'num': 99}, ContentType.FORMDATA.value) + self.assertEqual(files['num'], (None, '99')) + + def test_formdata_single_key(self): + data, files = get_data_from_content_type({'only': 'one'}, ContentType.FORMDATA.value) + self.assertIsNone(data) + self.assertIn('only', files) + + # ── XML ─────────────────────────────────────────────────────────────────── + def test_xml_text_xml_content_type_wraps_in_root(self): + data, files = get_data_from_content_type({'key': 'value'}, 'text/xml') + self.assertIn('', data) + self.assertIn('value', data) + self.assertIn('', data) + self.assertEqual(files, {}) + + def test_xml_application_xml_content_type(self): + data, files = get_data_from_content_type({'key': 'value'}, 'application/xml') + self.assertIn('', data) + self.assertIn('value', data) + self.assertEqual(files, {}) + + def test_xml_content_type_enum_value(self): + data, files = get_data_from_content_type({'key': 'value'}, ContentType.XML.value) + self.assertIn('value', data) + self.assertEqual(files, {}) + + def test_xml_non_dict_data_returns_str(self): + data, files = get_data_from_content_type('raw_string', 'text/xml') + self.assertEqual(data, 'raw_string') + self.assertEqual(files, {}) + + # ── HTML ────────────────────────────────────────────────────────────────── + def test_html_content_type_dict_returns_json_string(self): + data, files = get_data_from_content_type(self.DATA, ContentType.HTML.value) + self.assertEqual(data, json.dumps(self.DATA)) + self.assertEqual(files, {}) + + def test_html_text_html_content_type(self): + data, files = get_data_from_content_type(self.DATA, 'text/html') + self.assertEqual(data, json.dumps(self.DATA)) + self.assertEqual(files, {}) + + def test_html_non_dict_data_returns_str(self): + data, files = get_data_from_content_type('raw', ContentType.HTML.value) + self.assertEqual(data, 'raw') + self.assertEqual(files, {}) + + # ── None / unknown ──────────────────────────────────────────────────────── + def test_none_content_type_falls_back_to_json(self): + data, files = get_data_from_content_type(self.DATA, None) + self.assertEqual(data, json.dumps(self.DATA)) + self.assertEqual(files, {}) + + def test_unknown_content_type_falls_back_to_json(self): + data, files = get_data_from_content_type(self.DATA, 'application/octet-stream') + self.assertEqual(data, json.dumps(self.DATA)) + self.assertEqual(files, {}) + + def test_unknown_content_type_non_dict_returns_str(self): + data, files = get_data_from_content_type('blob', 'application/octet-stream') + self.assertEqual(data, 'blob') + self.assertEqual(files, {}) + + +class TestParseInvokeConnectionResponse(unittest.TestCase): + """Tests for parse_invoke_connection_response covering all success and error paths.""" + + def _make_response(self, status_code, content, headers=None, raise_http_error=False): + mock_resp = Mock(spec=requests.Response) + mock_resp.status_code = status_code + if isinstance(content, str): + mock_resp.content = content.encode('utf-8') + else: + mock_resp.content = content + mock_resp.headers = headers or {} + if raise_http_error: + mock_resp.raise_for_status.side_effect = requests.HTTPError() + else: + mock_resp.raise_for_status.return_value = None + return mock_resp + + # ── Success paths ───────────────────────────────────────────────────────── + def test_success_json_content_type_parses_body(self): + resp = self._make_response( + 200, + '{"result": "ok"}', + {'content-type': 'application/json', 'x-request-id': 'req-1'} + ) + result = parse_invoke_connection_response(resp) + self.assertEqual(result.data, {'result': 'ok'}) + self.assertEqual(result.metadata.get('request_id'), 'req-1') + self.assertIsNone(result.errors) + + def test_success_plain_text_content_type_returns_string(self): + resp = self._make_response( + 200, + 'plain text response', + {'content-type': 'text/plain'} + ) + result = parse_invoke_connection_response(resp) + self.assertEqual(result.data, 'plain text response') + + def test_success_no_content_type_tries_json_parse(self): + resp = self._make_response(200, '{"a": 1}', {}) + result = parse_invoke_connection_response(resp) + self.assertEqual(result.data, {'a': 1}) + + def test_success_no_content_type_invalid_json_returns_string(self): + resp = self._make_response(200, 'not json', {}) + result = parse_invoke_connection_response(resp) + self.assertEqual(result.data, 'not json') + + def test_success_no_x_request_id_metadata_is_empty(self): + resp = self._make_response(200, '{}', {'content-type': 'application/json'}) + result = parse_invoke_connection_response(resp) + self.assertEqual(result.metadata, {}) + + def test_success_invalid_json_with_json_content_type_returns_raw_string(self): + resp = self._make_response( + 200, + 'not-json', + {'content-type': 'application/json'} + ) + result = parse_invoke_connection_response(resp) + self.assertEqual(result.data, 'not-json') + + def test_success_bytes_content_decoded(self): + resp = self._make_response(200, b'{"x": 1}', {'content-type': 'application/json'}) + result = parse_invoke_connection_response(resp) + self.assertEqual(result.data, {'x': 1}) + + # ── Error paths — standard Skyflow format ──────────────────────────────── + def test_error_standard_skyflow_format_extracts_message(self): + body = json.dumps({'error': {'message': 'bad input', 'http_code': 400, 'http_status': 'BAD_REQUEST', 'grpc_code': 3, 'details': []}}) + resp = self._make_response(400, body, {'x-request-id': 'r1'}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + e = ctx.exception + self.assertEqual(e.message, 'bad input') + self.assertEqual(e.http_code, 400) + self.assertEqual(e.request_id, 'r1') + self.assertEqual(e.http_status, 'BAD_REQUEST') + self.assertEqual(e.grpc_code, 3) + + def test_error_standard_format_falls_back_to_http_code_when_missing(self): + body = json.dumps({'error': {'message': 'oops'}}) + resp = self._make_response(500, body, {}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + self.assertEqual(ctx.exception.http_code, 500) + + def test_error_standard_format_falls_back_to_sdk_message_when_missing(self): + body = json.dumps({'error': {}}) + resp = self._make_response(503, body, {}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + expected = SkyflowMessages.Error.API_ERROR.value.format(503) + self.assertEqual(ctx.exception.message, expected) + + # ── Error paths — string error value ───────────────────────────────────── + def test_error_string_error_value_used_as_message(self): + body = json.dumps({'error': 'gateway timed out'}) + resp = self._make_response(502, body, {}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + self.assertEqual(ctx.exception.message, 'gateway timed out') + + def test_error_empty_string_error_value_falls_back_to_sdk_message(self): + body = json.dumps({'error': ''}) + resp = self._make_response(502, body, {}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + expected = SkyflowMessages.Error.API_ERROR.value.format(502) + self.assertEqual(ctx.exception.message, expected) + + # ── Error paths — non-standard JSON ────────────────────────────────────── + def test_error_no_error_key_uses_sdk_message(self): + body = json.dumps({'message': 'something went wrong'}) + resp = self._make_response(500, body, {}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + expected = SkyflowMessages.Error.API_ERROR.value.format(500) + self.assertEqual(ctx.exception.message, expected) + + def test_error_non_dict_json_body_uses_sdk_message(self): + body = json.dumps(['list', 'not', 'dict']) + resp = self._make_response(500, body, {}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + expected = SkyflowMessages.Error.API_ERROR.value.format(500) + self.assertEqual(ctx.exception.message, expected) + + def test_error_numeric_error_value_uses_sdk_message(self): + body = json.dumps({'error': 12345}) + resp = self._make_response(500, body, {}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + expected = SkyflowMessages.Error.API_ERROR.value.format(500) + self.assertEqual(ctx.exception.message, expected) + + # ── Error paths — non-JSON / empty body ────────────────────────────────── + def test_error_empty_body_uses_sdk_message(self): + resp = self._make_response(502, '', {}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + expected = SkyflowMessages.Error.API_ERROR.value.format(502) + self.assertEqual(ctx.exception.message, expected) + self.assertEqual(ctx.exception.http_code, 502) + + def test_error_html_body_uses_sdk_message(self): + resp = self._make_response(502, 'Bad Gateway', {}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + expected = SkyflowMessages.Error.API_ERROR.value.format(502) + self.assertEqual(ctx.exception.message, expected) + + def test_error_plain_text_body_uses_sdk_message(self): + resp = self._make_response(503, 'Service Unavailable', {}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + expected = SkyflowMessages.Error.API_ERROR.value.format(503) + self.assertEqual(ctx.exception.message, expected) + + # ── error-from-client header ────────────────────────────────────────────── + def test_error_from_client_true_appended_to_details(self): + body = json.dumps({'error': {'message': 'client error', 'http_code': 400, 'details': []}}) + resp = self._make_response(400, body, {'error-from-client': 'true', 'x-request-id': 'r2'}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + self.assertTrue(any(d.get('error_from_client') is True for d in ctx.exception.details)) + + def test_error_from_client_false_appended_to_details(self): + body = json.dumps({'error': {'message': 'server error', 'http_code': 500}}) + resp = self._make_response(500, body, {'error-from-client': 'false'}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + self.assertTrue(any(d.get('error_from_client') is False for d in ctx.exception.details)) + + def test_error_from_client_initialises_details_when_none(self): + body = json.dumps({'error': {'message': 'err', 'http_code': 400}}) + resp = self._make_response(400, body, {'error-from-client': 'true'}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + self.assertIsNotNone(ctx.exception.details) + self.assertTrue(len(ctx.exception.details) > 0) + + +class TestConstructInvokeConnectionRequest(unittest.TestCase): + """Tests for construct_invoke_connection_request covering method, body, headers, path/query params.""" + + BASE_URL = 'https://example.com/api' + LOGGER = Mock() + + def _make_request(self, method=RequestMethod.POST, body=None, headers=None, + path_params=None, query_params=None): + return InvokeConnectionRequest( + method=method, + body=body, + headers=headers, + path_params=path_params or {}, + query_params=query_params or {} + ) + + def test_post_with_json_body_prepares_request(self): + req = self._make_request(body={'k': 'v'}, headers={'Content-Type': 'application/json'}) + prepared = construct_invoke_connection_request(req, self.BASE_URL, self.LOGGER) + self.assertEqual(prepared.method, 'POST') + self.assertIn('k', prepared.body) + + def test_get_with_no_body(self): + req = self._make_request(method=RequestMethod.GET) + prepared = construct_invoke_connection_request(req, self.BASE_URL, self.LOGGER) + self.assertEqual(prepared.method, 'GET') + + def test_urlencoded_body_is_form_encoded(self): + req = self._make_request( + body={'field': 'val'}, + headers={'Content-Type': 'application/x-www-form-urlencoded'} + ) + prepared = construct_invoke_connection_request(req, self.BASE_URL, self.LOGGER) + self.assertIn('field=val', prepared.body) + + def test_formdata_body_produces_multipart_request(self): + req = self._make_request( + body={'file_field': 'data'}, + headers={'Content-Type': 'multipart/form-data'} + ) + prepared = construct_invoke_connection_request(req, self.BASE_URL, self.LOGGER) + self.assertEqual(prepared.method, 'POST') + self.assertIsNotNone(prepared.body) + + def test_xml_body_contains_xml_tags(self): + req = self._make_request( + body={'item': 'data'}, + headers={'Content-Type': 'text/xml'} + ) + prepared = construct_invoke_connection_request(req, self.BASE_URL, self.LOGGER) + self.assertIn('', prepared.body) + + def test_path_params_substituted_in_url(self): + req = self._make_request( + method=RequestMethod.GET, + path_params={'id': '123'} + ) + url_with_placeholder = 'https://example.com/api/{id}/resource' + prepared = construct_invoke_connection_request(req, url_with_placeholder, self.LOGGER) + self.assertIn('123', prepared.url) + self.assertNotIn('{id}', prepared.url) + + def test_query_params_appear_in_url(self): + req = self._make_request( + method=RequestMethod.GET, + query_params={'page': '1', 'limit': '10'} + ) + prepared = construct_invoke_connection_request(req, self.BASE_URL, self.LOGGER) + self.assertIn('page=1', prepared.url) + self.assertIn('limit=10', prepared.url) + + def test_invalid_headers_raises_skyflow_error(self): + req = InvokeConnectionRequest(method=RequestMethod.POST, headers='bad-headers') + with self.assertRaises(SkyflowError) as ctx: + construct_invoke_connection_request(req, self.BASE_URL, self.LOGGER) + self.assertEqual(ctx.exception.message, SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value) + + def test_invalid_body_raises_skyflow_error(self): + req = InvokeConnectionRequest( + method=RequestMethod.POST, + body='not-a-dict', + headers={'Content-Type': 'application/json'} + ) + with self.assertRaises(SkyflowError) as ctx: + construct_invoke_connection_request(req, self.BASE_URL, self.LOGGER) + self.assertEqual(ctx.exception.message, SkyflowMessages.Error.INVALID_REQUEST_BODY.value) + + def test_invalid_method_raises_skyflow_error(self): + req = InvokeConnectionRequest(method='INVALID_METHOD') + with self.assertRaises(SkyflowError) as ctx: + construct_invoke_connection_request(req, self.BASE_URL, self.LOGGER) + self.assertEqual(ctx.exception.message, SkyflowMessages.Error.INVALID_REQUEST_METHOD.value) + + def test_trailing_slash_stripped_from_url(self): + req = self._make_request(method=RequestMethod.GET) + prepared = construct_invoke_connection_request(req, self.BASE_URL + '/', self.LOGGER) + self.assertNotIn('//', prepared.url.replace('https://', '')) + + if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/tests/vault/controller/test__detect.py b/tests/vault/controller/test__detect.py index c2f9a861..b86087f5 100644 --- a/tests/vault/controller/test__detect.py +++ b/tests/vault/controller/test__detect.py @@ -2,6 +2,7 @@ from unittest.mock import Mock, patch, MagicMock import base64 import os +import tempfile from skyflow.error import SkyflowError from skyflow.generated.rest import WordCharacterCount from skyflow.utils import SkyflowMessages @@ -513,16 +514,12 @@ def test_get_detect_run_in_progress_status(self, mock_validate): self.vault_client.get_detect_file_api.return_value = files_api - # Execute - with patch.object(self.detect, "_Detect__parse_deidentify_file_response") as mock_parse: - result = self.detect.get_detect_run(req) + # Execute — IN_PROGRESS is returned directly without going through the parser + result = self.detect.get_detect_run(req) - # Verify IN_PROGRESS handling - mock_parse.assert_called_once() - args = mock_parse.call_args[0][0] - self.assertIsInstance(args, DeidentifyFileResponse) - self.assertEqual(args.status, 'IN_PROGRESS') - self.assertEqual(args.run_id, run_id) + self.assertIsInstance(result, DeidentifyFileResponse) + self.assertEqual(result.status, 'IN_PROGRESS') + self.assertEqual(result.run_id, run_id) def test_get_transformations_with_shift_dates(self): @@ -711,3 +708,98 @@ def test_deidentify_file_using_file_path(self, mock_open, mock_basename, mock_ba self.assertIsNone(result.page_count) self.assertIsNone(result.slide_count) self.assertEqual(result.entities, []) + + def test_poll_for_processed_file_exception(self): + files_api = Mock() + files_api.with_raw_response = files_api + files_api.get_run.side_effect = Exception("poll error") + self.vault_client.get_detect_file_api.return_value = files_api + with self.assertRaises(Exception): + self.detect._Detect__poll_for_processed_file("runid", max_wait_time=5) + + def test_save_output_directory_not_exists(self): + output = Mock() + output.processedFile = base64.b64encode(b"data").decode() + output.processedFileType = "redacted_file" + output.processedFileExtension = "txt" + response = Mock() + response.output = [output] + with patch("skyflow.vault.controller._detect.os.path.exists", return_value=False): + self.detect._Detect__save_deidentify_file_response_output( + response, "/nonexistent_dir", "file.txt", "file" + ) + + def test_save_output_second_non_redacted_item(self): + with tempfile.TemporaryDirectory() as tmp_dir: + output1 = Mock() + output1.processedFile = base64.b64encode(b"data1").decode() + output1.processedFileType = "redacted_file" + output1.processedFileExtension = "txt" + output2 = Mock() + output2.processedFile = base64.b64encode(b"data2").decode() + output2.processedFileType = "entities" + output2.processedFileExtension = "json" + response = Mock() + response.output = [output1, output2] + self.detect._Detect__save_deidentify_file_response_output( + response, tmp_dir, "original.txt", "original" + ) + + def test_save_output_path_traversal_blocked(self): + output = Mock() + output.processedFile = base64.b64encode(b"data").decode() + output.processedFileType = "redacted_file" + output.processedFileExtension = "txt" + response = Mock() + response.output = [output] + call_count = [0] + + def fake_realpath(p): + call_count[0] += 1 + if call_count[0] == 1: + return "/safe_dir" + return "/outside/path" + + with patch("skyflow.vault.controller._detect.os.path.exists", return_value=True), \ + patch("skyflow.vault.controller._detect.os.path.realpath", side_effect=fake_realpath): + self.detect._Detect__save_deidentify_file_response_output( + response, "/safe_dir", "file.txt", "file" + ) + + def test_save_output_write_exception(self): + with tempfile.TemporaryDirectory() as tmp_dir: + output = Mock() + output.processedFile = base64.b64encode(b"data").decode() + output.processedFileType = "redacted_file" + output.processedFileExtension = "txt" + response = Mock() + response.output = [output] + with patch("skyflow.vault.controller._detect.base64.b64decode", + side_effect=Exception("decode error")), \ + self.assertRaises(Exception): + self.detect._Detect__save_deidentify_file_response_output( + response, tmp_dir, "file.txt", "file" + ) + + @patch("skyflow.vault.controller._detect.validate_deidentify_file_request") + @patch("skyflow.vault.controller._detect.base64") + def test_deidentify_file_api_error_inside_try(self, mock_base64, mock_validate): + file_content = b"test content" + file_obj = Mock() + file_obj.read.return_value = file_content + file_obj.name = "test.txt" + mock_base64.b64encode.return_value.decode.return_value = "encoded" + req = DeidentifyFileRequest(file=FileInput(file=file_obj)) + req.entities = [] + req.token_format = None + req.allow_regex_list = [] + req.restrict_regex_list = [] + req.transformations = None + req.output_directory = None + req.wait_time = None + files_api = Mock() + files_api.with_raw_response = files_api + files_api.deidentify_text.side_effect = Exception("API error inside try") + self.vault_client.get_detect_file_api.return_value = files_api + with self.assertRaises(Exception): + self.detect.deidentify_file(req) diff --git a/tests/vault/controller/test__vault.py b/tests/vault/controller/test__vault.py index 4e1a0dda..993cd72a 100644 --- a/tests/vault/controller/test__vault.py +++ b/tests/vault/controller/test__vault.py @@ -722,6 +722,26 @@ def test_upload_file_with_missing_file_source(self, mock_validate): self.assertEqual(error.exception.message, SkyflowMessages.Error.MISSING_FILE_SOURCE.value) mock_validate.assert_called_once_with(self.vault_client.get_logger(), request) + @patch("skyflow.vault.controller._vault.validate_file_upload_request") + def test_upload_file_without_skyflow_id_successful(self, mock_validate): + """Test upload_file succeeds when skyflow_id is None (it is optional).""" + request = FileUploadRequest( + table="test_table", + column_name="file_column", + file_path="/path/to/test.txt", + ) + mocked_open = mock_open_func(read_data=b"test file content") + mock_api_response = Mock() + mock_api_response.data = Mock(skyflow_id="generated-id-123") + records_api = self.vault_client.get_records_api.return_value + records_api.with_raw_response.upload_file_v_2.return_value = mock_api_response + with patch('builtins.open', mocked_open): + result = self.vault.upload_file(request) + mock_validate.assert_called_once_with(self.vault_client.get_logger(), request) + self.assertIsNone(request.skyflow_id) + self.assertEqual(result.skyflow_id, "generated-id-123") + self.assertIsNone(result.errors) + class TestFileUploadValidation(unittest.TestCase): def setUp(self): self.logger = Mock() @@ -874,3 +894,38 @@ def test_validate_missing_file_source(self): with self.assertRaises(SkyflowError) as error: validate_file_upload_request(self.logger, request) self.assertEqual(error.exception.message, SkyflowMessages.Error.MISSING_FILE_SOURCE.value) + + def test_validate_none_skyflow_id_is_allowed(self): + """Test that skyflow_id=None passes validation (it is optional).""" + request = FileUploadRequest( + table="test_table", + column_name="file_column", + base64="dGVzdCBmaWxlIGNvbnRlbnQ=", + file_name="test.txt" + ) + self.assertIsNone(request.skyflow_id) + validate_file_upload_request(self.logger, request) + + @patch('os.path.exists') + @patch('os.path.isfile') + def test_validate_file_path_without_skyflow_id(self, mock_isfile, mock_exists): + """Test validation succeeds with file_path and no skyflow_id.""" + mock_exists.return_value = True + mock_isfile.return_value = True + request = FileUploadRequest( + table="test_table", + column_name="file_column", + file_path="/path/to/file.txt" + ) + validate_file_upload_request(self.logger, request) + + def test_validate_file_object_without_skyflow_id(self): + """Test validation succeeds with file_object and no skyflow_id.""" + mock_file = Mock() + mock_file.seek = Mock() + request = FileUploadRequest( + table="test_table", + column_name="file_column", + file_object=mock_file + ) + validate_file_upload_request(self.logger, request)