ElasticDL Training Checkpoints Design
A checkpoint stores the state of a training process. It includes not only the variables but also certain other training states. In ElasticDL, we need to save all parameters of the model to checkpoints during training for the following reasons:
- From checkpoints, we can export the model that had the best evaluation performance during training. Training for too many epochs or steps can lead to over-fitting the training dataset. Early stopping is widely used to avoid over-fitting: it stops training when the evaluation performance stops improving or even gets worse. So the final parameters may not be the best ones. With checkpoints, we can choose the checkpoint with the best evaluation metrics and export it for serving.
- We can restore the model parameters from a checkpoint to continue training or to predict. Fault tolerance in ElasticDL cannot recover from all unexpected errors; fatal errors such as a power outage or an OS fault still stop the job. If such an error happens before any checkpoint has been saved, all training progress is lost. In other cases, even without an unforeseen error, we might want to resume training from a particular state for a new experiment, or try different things from a given state.
In the following sections, we will describe the design of how to save checkpoints and restore model parameters from a checkpoint in ElasticDL.
Design Components
This design contains two key parts: saving checkpoints and restoring model parameters from a checkpoint.
When to save a checkpoint during training
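With ParameterServerStrategy in ElasticDL, communication with the PS is asynchronous, so we cannot guarantee that all PS instances are at the same iteration step at any given moment. If each PS instance saved its shard every N seconds, the shards saved at the same time could belong to different model versions and would not form a consistent model. So, we do not support saving checkpoints every N seconds.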
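A consequence is that a checkpoint has to be tied to a model version rather than to wall-clock time. The sketch below assumes a hypothetical `checkpoint_steps` setting (not named in this design) and shows that PS instances advancing at different speeds still save their shards for the same set of versions, so each version directory eventually fills up.

```python
def should_save_checkpoint(version: int, checkpoint_steps: int) -> bool:
    """Return True if a PS instance should save its shard at this version.

    `checkpoint_steps` is a hypothetical setting used only for illustration.
    """
    return checkpoint_steps > 0 and version % checkpoint_steps == 0


def saved_versions(final_version: int, checkpoint_steps: int) -> list:
    """Versions at which a PS that has reached `final_version` saved shards."""
    return [
        v for v in range(1, final_version + 1)
        if should_save_checkpoint(v, checkpoint_steps)
    ]


# Two PS instances progress asynchronously: ps-0 has reached version 250,
# ps-1 only version 180. Both saved shards for version 100, so that version
# is complete; version 200 only has ps-0's shard so far.
ps0 = saved_versions(250, checkpoint_steps=100)
ps1 = saved_versions(180, checkpoint_steps=100)
print(ps0)  # [100, 200]
print(ps1)  # [100]
print(sorted(set(ps0) & set(ps1)))  # [100]
```

Because the decision depends only on the version, a slow PS instance simply writes its shard later; it never writes a shard for a version that the other instances skip.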
Where to save a checkpoint
To save checkpoints, we need to set checkpoint_dir in the args. Each checkpoint is saved to a directory built from checkpoint_dir and the iteration step, which we call the model version in ElasticDL, e.g. /{checkpoint_dir}/version_{version}/.
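A minimal sketch of how the per-version directory could be derived. The helper name `checkpoint_version_dir` and the example path are hypothetical; only the `version_{version}` layout comes from this design.

```python
import os


def checkpoint_version_dir(checkpoint_dir: str, version: int) -> str:
    """Directory holding all PS shards for one model version (hypothetical helper)."""
    return os.path.join(checkpoint_dir, "version_{}".format(version))


print(checkpoint_version_dir("/mnt/checkpoints", 300))
```

On a POSIX file system this prints `/mnt/checkpoints/version_300`.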
ElasticDL is a Kubernetes-native framework, and the local directory in a pod is lost when the pod exits. So checkpoint_dir should be a persistent directory that survives after the pod exits. Currently, we can use PersistentVolumes to get a persistent directory for saving checkpoints in ElasticDL.
How to save model parameters to a checkpoint directory
The parameters on each PS instance may contain non-embedding variables and a part of the embedding vectors of each embedding table. We create a Tensor protobuf with values for each variable, or with values and indices for embedding vectors. To save an embedding table, we also need to create an EmbeddingTableInfo to save its meta information, such as initializer and dim. Then, we create a Model protobuf that includes those Tensors and EmbeddingTableInfos. Finally, we save the Model protobuf to a file named variables-{ps_id}-of-{ps_num}.chkpt, where ps_id is the index of the PS instance and ps_num is the total number of PS instances. The checkpoint file is located in the subdirectory version_{version} of its checkpoint version. Once the iteration versions of all PS instances have exceeded the checkpoint version, every PS instance has written its shard, and the version_{version} directory contains files with the entire model parameters.
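The sketch below mirrors these steps with plain Python dataclasses standing in for the Tensor, EmbeddingTableInfo and Model protobuf messages, and it uses JSON instead of protobuf serialization, so the field names and the file encoding are assumptions. Each PS instance writes one shard, and `is_version_complete` (a hypothetical helper) checks whether a version_{version} directory already holds all ps_num shards.

```python
import json
import os
import tempfile
from dataclasses import asdict, dataclass, field
from typing import List, Optional


@dataclass
class Tensor:
    """Stand-in for the Tensor protobuf (field names are assumptions)."""
    name: str
    values: list
    indices: Optional[List[int]] = None  # set only for embedding vectors


@dataclass
class EmbeddingTableInfo:
    """Stand-in for the EmbeddingTableInfo protobuf."""
    name: str
    dim: int
    initializer: str


@dataclass
class Model:
    """Stand-in for the Model protobuf holding one PS instance's shard."""
    version: int
    params: List[Tensor] = field(default_factory=list)
    embedding_table_infos: List[EmbeddingTableInfo] = field(default_factory=list)


def shard_file_name(ps_id: int, ps_num: int) -> str:
    return "variables-{}-of-{}.chkpt".format(ps_id, ps_num)


def save_shard(checkpoint_dir: str, model: Model, ps_id: int, ps_num: int) -> str:
    """Write this PS instance's shard into version_{version}/ and return its path."""
    version_dir = os.path.join(checkpoint_dir, "version_{}".format(model.version))
    os.makedirs(version_dir, exist_ok=True)
    path = os.path.join(version_dir, shard_file_name(ps_id, ps_num))
    with open(path, "w") as f:
        json.dump(asdict(model), f)  # the design serializes a protobuf instead
    return path


def is_version_complete(version_dir: str, ps_num: int) -> bool:
    """True once every PS instance has written its shard for this version."""
    return all(
        os.path.exists(os.path.join(version_dir, shard_file_name(i, ps_num)))
        for i in range(ps_num)
    )


with tempfile.TemporaryDirectory() as ckpt:
    shard0 = Model(
        version=100,
        params=[
            Tensor("dense/kernel", [0.1, 0.2]),
            Tensor("user_emb", [[0.5, 0.5], [0.3, 0.1]], indices=[0, 2]),
        ],
        embedding_table_infos=[EmbeddingTableInfo("user_emb", 2, "uniform")],
    )
    save_shard(ckpt, shard0, ps_id=0, ps_num=2)
    version_dir = os.path.join(ckpt, "version_100")
    print(is_version_complete(version_dir, 2))  # False: ps-1 has not saved yet
    save_shard(ckpt, Model(version=100), ps_id=1, ps_num=2)
    print(is_version_complete(version_dir, 2))  # True
```

Encoding ps_num in the file name lets a reader of the directory tell how many shards to expect without any extra metadata file.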
How many recent checkpoints to keep
We can set keep_checkpoint_max in ElasticDL to limit the number of recent checkpoints to keep. After saving a checkpoint, each PS instance counts the checkpoints it has saved. If the count exceeds keep_checkpoint_max, it removes its own shard file from the version_{version} subdirectory with the smallest version. Then, it removes that version_{version} subdirectory if it is empty.
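A sketch of this pruning step as run by one PS instance after it saves a checkpoint, assuming the directory layout and file names described above. The function name `prune_own_shards` is hypothetical, and treating a non-positive keep_checkpoint_max as "keep everything" is an assumption. Each PS instance deletes only its own shard, so a version directory disappears once the last PS instance has pruned it.

```python
import os
import re
import tempfile

VERSION_DIR_RE = re.compile(r"^version_(\d+)$")


def prune_own_shards(checkpoint_dir: str, ps_id: int, ps_num: int,
                     keep_checkpoint_max: int) -> list:
    """Delete this PS's oldest shards beyond keep_checkpoint_max.

    Returns the versions whose shard was removed.
    """
    shard_name = "variables-{}-of-{}.chkpt".format(ps_id, ps_num)
    # Versions for which this PS instance has saved a shard.
    versions = sorted(
        int(m.group(1))
        for m in (VERSION_DIR_RE.match(d) for d in os.listdir(checkpoint_dir))
        if m and os.path.exists(os.path.join(checkpoint_dir, m.group(0), shard_name))
    )
    removed = []
    # Assumption: a non-positive keep_checkpoint_max keeps every checkpoint.
    while keep_checkpoint_max > 0 and len(versions) > keep_checkpoint_max:
        version = versions.pop(0)  # smallest version first
        version_dir = os.path.join(checkpoint_dir, "version_{}".format(version))
        os.remove(os.path.join(version_dir, shard_name))
        if not os.listdir(version_dir):  # other PS shards are already gone
            os.rmdir(version_dir)
        removed.append(version)
    return removed


with tempfile.TemporaryDirectory() as ckpt:
    # Two PS instances have both saved versions 100, 200 and 300.
    for version in (100, 200, 300):
        vdir = os.path.join(ckpt, "version_{}".format(version))
        os.makedirs(vdir)
        for ps_id in range(2):
            open(os.path.join(vdir, "variables-{}-of-2.chkpt".format(ps_id)), "w").close()
    print(prune_own_shards(ckpt, ps_id=0, ps_num=2, keep_checkpoint_max=2))  # [100]
    print(os.path.isdir(os.path.join(ckpt, "version_100")))  # True: ps-1's shard remains
    print(prune_own_shards(ckpt, ps_id=1, ps_num=2, keep_checkpoint_max=2))  # [100]
    print(os.path.isdir(os.path.join(ckpt, "version_100")))  # False
```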
Restore Model Parameters from a Checkpoint
The number of PS instances may change when we restore model parameters to train or predict, so we need to repartition the variables and embedding vectors in the checkpoint files across the new PS instances. Each new PS instance traverses all protobuf files in the checkpoint directory and keeps a variable if hash_utils.string_to_id(var.name) equals its index (ps_id), and an embedding vector if hash_utils.int_to_id(embedding_id) equals ps_id. The pseudo-code is:
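import os

# load_model_file and hash_utils are ElasticDL helpers. ps_id is this PS
# instance's index; ps_num (assumed hash bucket count) is the new PS count.
embedding_table_vectors = {}  # table name -> (embedding ids, vectors)
non_embedding_vars = {}  # variable name -> Tensor protobuf
for pb_file in os.listdir(checkpoint_dir):
    model_pb = load_model_file(os.path.join(checkpoint_dir, pb_file))
    for param in model_pb.params:
        if param.indices:
            # Embedding vectors: keep only the ids that hash to this PS.
            ids, vectors = embedding_table_vectors.setdefault(param.name, ([], []))
            for embedding_id, vector in zip(param.indices, param.values):
                if hash_utils.int_to_id(embedding_id, ps_num) == ps_id:
                    ids.append(embedding_id)
                    vectors.append(vector)
        elif hash_utils.string_to_id(param.name, ps_num) == ps_id:
            # Non-embedding variable owned by this PS.
            non_embedding_vars[param.name] = param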
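To see why this repartitioning works when the number of PS instances changes, the self-contained sketch below replaces protobuf files with in-memory dicts and uses stand-in hash functions: an id modulo ps_num for embedding ids, and a SHA-256 digest modulo ps_num for variable names. ElasticDL's actual hash_utils may differ. It takes a checkpoint written by 2 PS instances and repartitions it across 3.

```python
import hashlib


def int_to_id(embedding_id: int, ps_num: int) -> int:
    """Stand-in for hash_utils.int_to_id."""
    return embedding_id % ps_num


def string_to_id(name: str, ps_num: int) -> int:
    """Stand-in for hash_utils.string_to_id; stable across processes."""
    return int(hashlib.sha256(name.encode("utf-8")).hexdigest(), 16) % ps_num


def repartition(shards, ps_id, ps_num):
    """Collect what new PS `ps_id` owns from all old shards."""
    non_embedding_vars = {}
    embedding_table_vectors = {}
    for shard in shards:  # one dict per old checkpoint file
        for name, param in shard.items():
            if "indices" in param:
                ids, vectors = embedding_table_vectors.setdefault(name, ([], []))
                for embedding_id, vector in zip(param["indices"], param["values"]):
                    if int_to_id(embedding_id, ps_num) == ps_id:
                        ids.append(embedding_id)
                        vectors.append(vector)
            elif string_to_id(name, ps_num) == ps_id:
                non_embedding_vars[name] = param["values"]
    return non_embedding_vars, embedding_table_vectors


# A checkpoint saved by 2 old PS instances (embedding ids split by id % 2).
old_shards = [
    {"dense/kernel": {"values": [0.1, 0.2]},
     "user_emb": {"indices": [0, 2, 4], "values": [[0.0], [0.2], [0.4]]}},
    {"dense/bias": {"values": [0.5]},
     "user_emb": {"indices": [1, 3, 5], "values": [[0.1], [0.3], [0.5]]}},
]

new_ps_num = 3
results = [repartition(old_shards, i, new_ps_num) for i in range(new_ps_num)]
# Embedding ids: ps-0 gets [0, 3], ps-1 gets [4, 1], ps-2 gets [2, 5];
# each dense variable lands on exactly one new PS.
for i, (variables, tables) in enumerate(results):
    print(i, sorted(variables), tables["user_emb"][0])
```

A digest such as SHA-256 is used for names rather than Python's built-in `hash()`, because string hashing in Python is randomized per process and every PS instance must compute the same owner for a given name.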