Because the performance of the learned policy is unstable and falls short of human performance, it's a good idea to scale training up to multiple attempts across different hyperparameters. SageMaker makes this easy. First, you will send a training script to a single training job, and then you will search across hyperparameter values with the hyperparameter tuning feature.
To train a job on SageMaker, you need a single script that serves as an entrypoint for the training container. SageMaker will pass in hyperparameters as command-line arguments. For example, if your script is called main.py and you want to pass in a target_update value of 5 as a hyperparameter, SageMaker will run the command python main.py --target_update 5 in a training container.
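Inside the entrypoint, those command-line arguments are typically read with argparse. This is a minimal sketch, not the actual main.py; the real script's argument names and defaults may differ:

```python
import argparse

# SageMaker passes each hyperparameter as a --name value flag,
# so the entrypoint declares a matching argument for each one.
parser = argparse.ArgumentParser()
parser.add_argument('--target_update', type=int, default=10)
parser.add_argument('--epsilon_start', type=float, default=0.5)

# Simulating `python main.py --target_update 5`:
args = parser.parse_args(['--target_update', '5'])
print(args.target_update)  # → 5
```

Any hyperparameter not passed on the command line falls back to the default declared in the parser.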
The script needs to train the model and save it to disk in a specific location. SageMaker will then upload the model to S3 where it can be used locally or in hosted endpoints.
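The "specific location" is the directory that SageMaker exposes through the SM_MODEL_DIR environment variable (inside the training container it points at /opt/ml/model); everything written there is packaged into model.tar.gz and uploaded to S3. A minimal sketch of the save step, using pickle and a stand-in dict in place of torch.save and a real state dict:

```python
import os
import pickle
import tempfile

# In the training container SM_MODEL_DIR is set to /opt/ml/model;
# outside SageMaker it is unset, so fall back to a temp dir.
model_dir = os.environ.get('SM_MODEL_DIR', tempfile.mkdtemp())

weights = {'layer1.weight': [0.1, 0.2]}  # stand-in for a state dict
path = os.path.join(model_dir, 'policy.pth')
with open(path, 'wb') as f:
    pickle.dump(weights, f)
```

In the real script you would call torch.save(policy.state_dict(), path) instead of pickle.dump.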
For this tutorial, the main.py script has already been written and is in the top-level frozen-lake folder with setup.py. To run it as a training job on SageMaker:
Create a SageMaker Estimator object in your notebook cell with the following:

estimator = PyTorch(
    entry_point='main.py',
    source_dir='/home/ec2-user/SageMaker/frozen-lake/',
    framework_version='1.2.0',
    train_instance_type='ml.m5.large',
    train_instance_count=1,
    role=get_execution_role(),
)
The source_dir argument specifies a directory with helper code. The directory structure will remain the same on the training instance. SageMaker will install dependencies defined in setup.py.
Run the training code by calling the estimator's fit method:

estimator.fit()

This runs synchronously for about four or five minutes. After it finishes training, you can get the S3 output location from the model_data attribute. SageMaker creates an S3 bucket for you automatically if it doesn't already exist.
Download the tarball locally and decompress it by running the following in a new cell:
!aws s3 cp $estimator.model_data ./
!tar xvzf model.tar.gz
!rm model.tar.gz
The policy network weights will now be located at policy.pth.
In a new cell, create a new policy network and load the weights into memory:
sagemaker_policy = DeepQNetwork(config.n_state_features, config.n_actions)
sagemaker_policy.load_state_dict(torch.load('policy.pth'))
Evaluate performance on the test level:
np.random.seed(1)
n_attempts = 10000
test_level = get_test_level()
rewards = [
    play_level(test_level, sagemaker_policy.learned_action)
    for _ in range(n_attempts)
]
sum(rewards) / n_attempts
Performance should be the same as the local_policy network because all random seeds are fixed.
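The reproducibility claim is easy to verify in isolation: reseeding NumPy's global generator replays the identical random sequence, so two evaluations that start from the same seed make exactly the same random moves.

```python
import numpy as np

np.random.seed(1)
first = np.random.rand(5)

np.random.seed(1)  # reseed with the same value...
second = np.random.rand(5)

# ...and the draws are bit-for-bit identical.
assert (first == second).all()
```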
With the PyTorch estimator still in memory, create a HyperparameterTuner object to run multiple training jobs with different config settings:

tuner = HyperparameterTuner(
    estimator,
    objective_metric_name='MaxReward',
    metric_definitions=[
        dict(
            Name='MaxReward',
            Regex='MaxReward=([0-9\\.]+)',
        )
    ],
    hyperparameter_ranges=dict(
        target_update=IntegerParameter(10, 500),
        epsilon_start=ContinuousParameter(0.25, 0.75),
    ),
    max_jobs=20,
    max_parallel_jobs=5,
)
The metric_definitions argument is a list of metrics to parse from training job logs, along with the regular expression pattern used to parse them. The hyperparameter_ranges argument specifies which config values to search over.
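SageMaker scrapes these metrics from the training job's log output, so the script only needs to print a line that matches the pattern. A quick sanity check of the regex used above (the log line here is made up for illustration):

```python
import re

# The same pattern passed to metric_definitions above.
pattern = re.compile(r'MaxReward=([0-9\.]+)')

log_line = 'epoch 12 MaxReward=0.85 loss=0.02'
match = pattern.search(log_line)
print(match.group(1))  # → 0.85
```

The captured group is what the tuner compares across jobs when picking the best one.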
Run the hyperparameter optimization job by calling the tuner's fit method:

tuner.fit()

The job takes about 20-25 minutes.
When the tuning job finishes, construct the model_path of the best training job. Then download and decompress the tarball:
model_path = (
    estimator.output_path
    + tuner.best_training_job()
    + '/output/model.tar.gz'
)
!aws s3 cp $model_path ./
!tar xvzf model.tar.gz
!rm model.tar.gz
Load the weights into memory and evaluate the policy:
tuned_policy = DeepQNetwork(config.n_state_features, config.n_actions)
tuned_policy.load_state_dict(torch.load('policy.pth'))

n_attempts = 10000
rewards = [
    play_level(test_level, tuned_policy.learned_action)
    for _ in range(n_attempts)
]
sum(rewards) / n_attempts
This is a good policy network that can be used to create an agent that mimics human performance.
The win percentage will differ for you because AWS controls the random seed for the hyperparameter tuning job, but it should be higher than that of the local policy you trained in the previous section.