99import click
1010from click .core import ParameterSource # type: ignore[attr-defined]
1111from rich import print as rprint
12+ from rich .json import JSON
1213from tabulate import tabulate
1314
1415from together import Together
15- from together .cli .api .utils import BOOL_WITH_AUTO , INT_WITH_MAX
16+ from together .cli .api .utils import BOOL_WITH_AUTO , INT_WITH_MAX , generate_progress_bar
1617from together .types .finetune import (
1718 DownloadCheckpointType ,
1819 FinetuneEventType ,
1920 FinetuneTrainingLimits ,
21+ FullTrainingType ,
22+ LoRATrainingType ,
2023)
2124from together .utils import (
2225 finetune_price_to_dollars ,
2932
3033_CONFIRMATION_MESSAGE = (
3134 "You are about to create a fine-tuning job. "
32- "The cost of your job will be determined by the model size, the number of tokens "
35+ "The estimated price of this job is {price}. "
36+ "The actual cost of your job will be determined by the model size, the number of tokens "
3337 "in the training file, the number of tokens in the validation file, the number of epochs, and "
34- "the number of evaluations. Visit https://www.together.ai/pricing to get a price estimate.\n "
38+ "the number of evaluations. Visit https://www.together.ai/pricing to learn more about fine-tuning pricing.\n "
39+ "{warning}"
3540 "You can pass `-y` or `--confirm` to your command to skip this message.\n \n "
3641 "Do you want to proceed?"
3742)
3843
44+ _WARNING_MESSAGE_INSUFFICIENT_FUNDS = (
45+ "The estimated price of this job is significantly greater than your current credit limit and balance combined. "
46+ "It will likely get cancelled due to insufficient funds. "
47+ "Consider increasing your credit limit at https://api.together.xyz/settings/profile\n "
48+ )
49+
3950
4051class DownloadCheckpointTypeChoice (click .Choice ):
4152 def __init__ (self ) -> None :
@@ -357,12 +368,36 @@ def create(
357368 "You have specified a number of evaluation loops but no validation file."
358369 )
359370
360- if confirm or click .confirm (_CONFIRMATION_MESSAGE , default = True , show_default = True ):
371+ finetune_price_estimation_result = client .fine_tuning .estimate_price (
372+ training_file = training_file ,
373+ validation_file = validation_file ,
374+ model = model ,
375+ n_epochs = n_epochs ,
376+ n_evals = n_evals ,
377+ training_type = "lora" if lora else "full" ,
378+ training_method = training_method ,
379+ )
380+
381+ price = click .style (
382+ f"${ finetune_price_estimation_result .estimated_total_price :.2f} " ,
383+ bold = True ,
384+ )
385+
386+ if not finetune_price_estimation_result .allowed_to_proceed :
387+ warning = click .style (_WARNING_MESSAGE_INSUFFICIENT_FUNDS , fg = "red" , bold = True )
388+ else :
389+ warning = ""
390+
391+ confirmation_message = _CONFIRMATION_MESSAGE .format (
392+ price = price ,
393+ warning = warning ,
394+ )
395+
396+ if confirm or click .confirm (confirmation_message , default = True , show_default = True ):
361397 response = client .fine_tuning .create (
362398 ** training_args ,
363399 verbose = True ,
364400 )
365-
366401 report_string = f"Successfully submitted a fine-tuning job { response .id } "
367402 if response .created_at is not None :
368403 created_time = datetime .strptime (
@@ -401,6 +436,9 @@ def list(ctx: click.Context) -> None:
401436 "Price" : f"""${
402437 finetune_price_to_dollars (float (str (i .total_price )))
403438 } """ , # convert to string for mypy typing
439+ "Progress" : generate_progress_bar (
440+ i , datetime .now ().astimezone (), use_rich = False
441+ ),
404442 }
405443 )
406444 table = tabulate (display_list , headers = "keys" , tablefmt = "grid" , showindex = True )
@@ -420,7 +458,15 @@ def retrieve(ctx: click.Context, fine_tune_id: str) -> None:
420458 # remove events from response for cleaner output
421459 response .events = None
422460
423- click .echo (json .dumps (response .model_dump (exclude_none = True ), indent = 4 ))
461+ rprint (JSON .from_data (response .model_dump (exclude_none = True )))
462+ progress_text = generate_progress_bar (
463+ response , datetime .now ().astimezone (), use_rich = True
464+ )
465+ status = "Unknown"
466+ if response .status is not None :
467+ status = response .status .value
468+ prefix = f"Status: [bold]{ status } [/bold],"
469+ rprint (f"{ prefix } { progress_text } " )
424470
425471
426472@fine_tuning .command ()
0 commit comments