@@ -674,12 +674,11 @@ def impl_download_dataset(dataset_path, bucket_name, region):
674674 print ("Downloaded from: {}\n " .format (key ))
675675
676676
677- def impl_update_training_progress (model_name , dataset_name , progress_text , bucket_name , region ):
677+ def impl_update_training_progress (job_id , progress_text , bucket_name , region ):
678678 """
679679 Updates the training progress in S3 for a model specified by its name.
680680
681- :param model_name: The filename of the model.
682- :param dataset_name: The filename of the dataset.
681+ :param job_id: The unique Job ID.
683682 :param progress_text: The text to write into the progress file.
684683 :param bucket_name: The S3 bucket name.
685684 :param region: The region, or `None` to pull the region from the environment.
@@ -689,46 +688,43 @@ def impl_update_training_progress(model_name, dataset_name, progress_text, bucke
689688 with open (local_file , "w" ) as f :
690689 f .write (progress_text )
691690 client = make_client ("s3" , region )
692- remote_path = create_progress_prefix (model_name , dataset_name ) + "/progress.txt"
691+ remote_path = create_progress_prefix (job_id ) + "/progress.txt"
693692 client .upload_file (path , bucket_name , remote_path )
694693 print ("Updated progress in: {}\n " .format (remote_path ))
695694 finally :
696695 os .remove (path )
697696
698697
699- def impl_create_heartbeat (model_name , dataset_name , bucket_name , region ):
698+ def impl_create_heartbeat (job_id , bucket_name , region ):
700699 """
701700 Creates a heartbeat that Axon uses to check if the training script is running properly.
702701
703- :param model_name: The filename of the model.
704- :param dataset_name: The filename of the dataset.
702+ :param job_id: The unique Job ID.
705703 :param bucket_name: The S3 bucket name.
706704 :param region: The region, or `None` to pull the region from the environment.
707705 """
708706 client = make_client ("s3" , region )
709- remote_path = create_progress_prefix (model_name , dataset_name ) + "/heartbeat.txt"
707+ remote_path = create_progress_prefix (job_id ) + "/heartbeat.txt"
710708 client .put_object (Body = "1" , Bucket = bucket_name , Key = remote_path )
711709 print ("Created heartbeat file in: {}\n " .format (remote_path ))
712710
713711
714- def impl_remove_heartbeat (model_name , dataset_name , bucket_name , region ):
712+ def impl_remove_heartbeat (job_id , bucket_name , region ):
715713 """
716714 Removes a heartbeat that Axon uses to check if the training script is running properly.
717715
718- :param model_name: The filename of the model.
719- :param dataset_name: The filename of the dataset.
716+ :param job_id: The unique Job ID.
720717 :param bucket_name: The S3 bucket name.
721718 :param region: The region, or `None` to pull the region from the environment.
722719 """
723720 client = make_client ("s3" , region )
724- remote_path = create_progress_prefix (model_name , dataset_name ) + "/heartbeat.txt"
721+ remote_path = create_progress_prefix (job_id ) + "/heartbeat.txt"
725722 client .put_object (Body = "0" , Bucket = bucket_name , Key = remote_path )
726723 print ("Removed heartbeat file in: {}\n " .format (remote_path ))
727724
728725
729- def create_progress_prefix (model_name , dataset_name ):
730- return "axon-training-progress/" + os .path .basename (model_name ) + "/" + \
731- os .path .basename (dataset_name )
726+ def create_progress_prefix (job_id ):
727+ return "axon-training-progress/" + job_id
732728
733729
734730@click .group ()
@@ -918,53 +914,44 @@ def download_dataset(dataset_path, region):
918914
919915
920916@cli .command (name = "update-training-progress" )
921- @click .argument ("model-name" )
922- @click .argument ("dataset-name" )
917+ @click .argument ("job-id" )
923918@click .argument ("progress-text" )
924919@click .option ("--region" , help = "The region to connect to." ,
925920 type = click .Choice (region_choices ))
926- def update_training_progress (model_name , dataset_name , progress_text , region ):
921+ def update_training_progress (job_id , progress_text , region ):
927922 """
928923 Updates the training progress. Meant to be used while a training script is running to provide
929924 progress updates to Axon.
930925
931- MODEL_NAME The filename of the model currently being trained.
932-
933- DATASET_NAME The name of the dataset currently being trained on.
926+ JOB_ID The unique Job ID.
934927
935928 PROGRESS_TEXT The text to write to the progress file.
936929 """
937- impl_update_training_progress (model_name , dataset_name , progress_text , ensure_s3_bucket (region ),
930+ impl_update_training_progress (job_id , progress_text , ensure_s3_bucket (region ),
938931 region )
939932
940933
941934@cli .command (name = "create-heartbeat" )
942- @click .argument ("model-name" )
943- @click .argument ("dataset-name" )
935+ @click .argument ("job-id" )
944936@click .option ("--region" , help = "The region to connect to." ,
945937 type = click .Choice (region_choices ))
946- def create_heartbeat (model_name , dataset_name , region ):
938+ def create_heartbeat (job_id , region ):
947939 """
948940 Creates a heartbeat that Axon uses to check if the training script is running properly.
949941
950- MODEL_NAME The filename of the model currently being trained.
951-
952- DATASET_NAME The name of the dataset currently being trained on.
942+ JOB_ID The unique Job ID.
953943 """
954- impl_create_heartbeat (model_name , dataset_name , ensure_s3_bucket (region ), region )
944+ impl_create_heartbeat (job_id , ensure_s3_bucket (region ), region )
955945
956946
957947@cli .command (name = "remove-heartbeat" )
958- @click .argument ("model-name" )
959- @click .argument ("dataset-name" )
948+ @click .argument ("job-id" )
960949@click .option ("--region" , help = "The region to connect to." ,
961950 type = click .Choice (region_choices ))
962- def remove_heartbeat (model_name , dataset_name , region ):
951+ def remove_heartbeat (job_id , region ):
963952 """
964953 Removes a heartbeat that Axon uses to check if the training script is running properly.
965954
966- MODEL_NAME The filename of the model currently being trained.
967-
968- DATASET_NAME The name of the dataset currently being trained on.
955+ JOB_ID The unique Job ID.
969956 """
970- impl_remove_heartbeat (model_name , dataset_name , ensure_s3_bucket (region ), region )
957+ impl_remove_heartbeat (job_id , ensure_s3_bucket (region ), region )
0 commit comments