|
1 | 1 | import json |
| 2 | +import tempfile |
| 3 | + |
2 | 4 | import click |
3 | 5 | import boto3 |
4 | 6 | import ipify |
@@ -473,76 +475,99 @@ def impl_get_task_ip(cluster_name, task_arn, region): |
473 | 475 | return nics[0]["Association"]["PublicIp"] |
474 | 476 |
|
475 | 477 |
|
476 | | -def impl_upload_model_file(local_file_path, bucket_name, region): |
| 478 | +def impl_upload_model_file(model_name, bucket_name, region): |
477 | 479 | """ |
478 | 480 | Uploads a model to S3. |
479 | 481 |
|
480 | | - :param local_file_path: The path to the model file on disk. |
| 482 | + :param model_name: The filename of the model to upload (must be in the current directory). |
481 | 483 | :param bucket_name: The S3 bucket name. |
482 | 484 | :param region: The region, or `None` to pull the region from the environment. |
483 | 485 | """ |
484 | 486 | client = make_client("s3", region) |
485 | | - remote_path = "axon-uploaded-trained-models/" + os.path.basename(local_file_path) |
486 | | - client.upload_file(local_file_path, bucket_name, remote_path) |
| 487 | + remote_path = "axon-uploaded-trained-models/" + os.path.basename(model_name) |
| 488 | + client.upload_file(model_name, bucket_name, remote_path) |
487 | 489 | print("Uploaded to: {}\n".format(remote_path)) |
488 | 490 |
|
489 | 491 |
|
490 | | -def impl_download_model_file(local_file_path, bucket_name, region): |
| 492 | +def impl_download_model_file(model_name, bucket_name, region): |
491 | 493 | """ |
492 | 494 | Downloads a model from S3. |
493 | 495 |
|
494 | | - :param local_file_path: The path to the model file on disk. |
| 496 | + :param model_name: The filename of the model to download (must be in the current directory). |
495 | 497 | :param bucket_name: The S3 bucket name. |
496 | 498 | :param region: The region, or `None` to pull the region from the environment. |
497 | 499 | """ |
498 | 500 | client = make_client("s3", region) |
499 | | - remote_path = "axon-uploaded-trained-models/" + os.path.basename(local_file_path) |
500 | | - client.download_file(bucket_name, remote_path, local_file_path) |
| 501 | + remote_path = "axon-uploaded-trained-models/" + os.path.basename(model_name) |
| 502 | + client.download_file(bucket_name, remote_path, model_name) |
501 | 503 | print("Downloaded from: {}\n".format(remote_path)) |
502 | 504 |
|
503 | 505 |
|
504 | | -def impl_download_training_script(local_script_path, bucket_name, region): |
| 506 | +def impl_download_training_script(script_name, bucket_name, region): |
505 | 507 | """ |
506 | 508 | Downloads a training script from S3. |
507 | 509 |
|
508 | | - :param local_script_path: The path to the training script on disk. |
| 510 | + :param script_name: The filename of the script to download (must be in the current directory). |
509 | 511 | :param bucket_name: The S3 bucket name. |
510 | 512 | :param region: The region, or `None` to pull the region from the environment. |
511 | 513 | """ |
512 | 514 | client = make_client("s3", region) |
513 | | - remote_path = "axon-uploaded-training-scripts/" + os.path.basename(local_script_path) |
514 | | - client.download_file(bucket_name, remote_path, local_script_path) |
| 515 | + remote_path = "axon-uploaded-training-scripts/" + os.path.basename(script_name) |
| 516 | + client.download_file(bucket_name, remote_path, script_name) |
515 | 517 | print("Downloaded from: {}\n".format(remote_path)) |
516 | 518 |
|
517 | 519 |
|
518 | | -def impl_upload_dataset(local_dataset_path, bucket_name, region): |
| 520 | +def impl_upload_dataset(dataset_name, bucket_name, region): |
519 | 521 | """ |
520 | 522 | Uploads a dataset to S3. |
521 | 523 |
|
522 | | - :param local_dataset_path: The path to the dataset on disk. |
| 524 | + :param dataset_name: The filename of the dataset to upload (must be in the current directory). |
523 | 525 | :param bucket_name: The S3 bucket name. |
524 | 526 | :param region: The region, or `None` to pull the region from the environment. |
525 | 527 | """ |
526 | 528 | client = make_client("s3", region) |
527 | | - remote_path = "axon-uploaded-datasets/" + os.path.basename(local_dataset_path) |
528 | | - client.upload_file(local_dataset_path, bucket_name, remote_path) |
| 529 | + remote_path = "axon-uploaded-datasets/" + os.path.basename(dataset_name) |
| 530 | + client.upload_file(dataset_name, bucket_name, remote_path) |
529 | 531 | print("Uploaded to: {}\n".format(remote_path)) |
530 | 532 |
|
531 | 533 |
|
532 | | -def impl_download_dataset(local_dataset_path, bucket_name, region): |
| 534 | +def impl_download_dataset(dataset_name, bucket_name, region): |
533 | 535 | """ |
534 | 536 | Downloads a dataset from S3. |
535 | 537 |
|
536 | | - :param local_dataset_path: The path to the dataset on disk. |
| 538 | + :param dataset_name: The filename of the dataset to download (must be in the current directory). |
537 | 539 | :param bucket_name: The S3 bucket name. |
538 | 540 | :param region: The region, or `None` to pull the region from the environment. |
539 | 541 | """ |
540 | 542 | client = make_client("s3", region) |
541 | | - remote_path = "axon-uploaded-datasets/" + os.path.basename(local_dataset_path) |
542 | | - client.download_file(bucket_name, remote_path, local_dataset_path) |
| 543 | + remote_path = "axon-uploaded-datasets/" + os.path.basename(dataset_name) |
| 544 | + client.download_file(bucket_name, remote_path, dataset_name) |
543 | 545 | print("Downloaded from: {}\n".format(remote_path)) |
544 | 546 |
|
545 | 547 |
|
| 548 | +def impl_update_training_progress(model_name, dataset_name, progress_text, bucket_name, region): |
| 549 | + """ |
| 550 | + Updates the training progress in S3 for a model specified by its name. |
| 551 | +
|
| 552 | + :param model_name: The filename of the model. |
| 553 | + :param dataset_name: The filename of the dataset. |
| 554 | + :param progress_text: The text to write into the progress file. |
| 555 | + :param bucket_name: The S3 bucket name. |
| 556 | + :param region: The region, or `None` to pull the region from the environment. |
| 557 | + """ |
| 558 | + local_file, path = tempfile.mkstemp() |
| 559 | + try: |
| 560 | + with open(local_file, "w") as f: |
| 561 | + f.write(progress_text) |
| 562 | + client = make_client("s3", region) |
| 563 | + remote_path = "axon-training-progress/" + os.path.basename(model_name) + "/" + \ |
| 564 | + os.path.basename(dataset_name) + "/progress.txt" |
| 565 | + client.upload_file(path, bucket_name, remote_path) |
| 566 | + print("Updated progress in: {}\n".format(remote_path)) |
| 567 | + finally: |
| 568 | + os.remove(path) |
| 569 | + |
| 570 | + |
546 | 571 | @click.group() |
547 | 572 | def cli(): |
548 | 573 | return |
@@ -613,40 +638,50 @@ def get_container_ip(cluster_name, task, region): |
613 | 638 |
|
614 | 639 |
|
615 | 640 | @cli.command(name="upload-model-file") |
616 | | -@click.argument("local-file-path") |
| 641 | +@click.argument("model-name") |
617 | 642 | @click.argument("bucket-name") |
618 | 643 | @click.option("--region", default="us-east-1", help="The region to connect to.") |
619 | | -def upload_model_file(local_file_path, bucket_name, region): |
620 | | - impl_upload_model_file(local_file_path, bucket_name, region) |
| 644 | +def upload_model_file(model_name, bucket_name, region): |
| 645 | + impl_upload_model_file(model_name, bucket_name, region) |
621 | 646 |
|
622 | 647 |
|
623 | 648 | @cli.command(name="download-model-file") |
624 | | -@click.argument("local-file-path") |
| 649 | +@click.argument("model-name") |
625 | 650 | @click.argument("bucket-name") |
626 | 651 | @click.option("--region", default="us-east-1", help="The region to connect to.") |
627 | | -def download_model_file(local_file_path, bucket_name, region): |
628 | | - impl_download_model_file(local_file_path, bucket_name, region) |
| 652 | +def download_model_file(model_name, bucket_name, region): |
| 653 | + impl_download_model_file(model_name, bucket_name, region) |
629 | 654 |
|
630 | 655 |
|
631 | 656 | @cli.command(name="download-training-script") |
632 | | -@click.argument("local-script-path") |
| 657 | +@click.argument("script-name") |
633 | 658 | @click.argument("bucket-name") |
634 | 659 | @click.option("--region", default="us-east-1", help="The region to connect to.") |
635 | | -def download_training_script(local_script_path, bucket_name, region): |
636 | | - impl_download_training_script(local_script_path, bucket_name, region) |
| 660 | +def download_training_script(script_name, bucket_name, region): |
| 661 | + impl_download_training_script(script_name, bucket_name, region) |
637 | 662 |
|
638 | 663 |
|
639 | 664 | @cli.command(name="download-dataset") |
640 | | -@click.argument("local-dataset-path") |
| 665 | +@click.argument("dataset-name") |
641 | 666 | @click.argument("bucket-name") |
642 | 667 | @click.option("--region", default="us-east-1", help="The region to connect to.") |
643 | | -def download_dataset(local_dataset_path, bucket_name, region): |
644 | | - impl_download_dataset(local_dataset_path, bucket_name, region) |
| 668 | +def download_dataset(dataset_name, bucket_name, region): |
| 669 | + impl_download_dataset(dataset_name, bucket_name, region) |
645 | 670 |
|
646 | 671 |
|
647 | 672 | @cli.command(name="upload-dataset") |
648 | | -@click.argument("local-dataset-path") |
| 673 | +@click.argument("dataset-name") |
| 674 | +@click.argument("bucket-name") |
| 675 | +@click.option("--region", default="us-east-1", help="The region to connect to.") |
| 676 | +def upload_dataset(dataset_name, bucket_name, region): |
| 677 | + impl_upload_dataset(dataset_name, bucket_name, region) |
| 678 | + |
| 679 | + |
| 680 | +@cli.command(name="update-training-progress") |
| 681 | +@click.argument("model-name") |
| 682 | +@click.argument("dataset-name") |
| 683 | +@click.argument("progress-text") |
649 | 684 | @click.argument("bucket-name") |
650 | 685 | @click.option("--region", default="us-east-1", help="The region to connect to.") |
651 | | -def upload_dataset(local_dataset_path, bucket_name, region): |
652 | | - impl_upload_dataset(local_dataset_path, bucket_name, region) |
| 686 | +def update_training_progress(model_name, dataset_name, progress_text, bucket_name, region): |
| 687 | + impl_update_training_progress(model_name, dataset_name, progress_text, bucket_name, region) |
0 commit comments