
Summary 

PyTorch Distributed Checkpointing (DCP) is investing in addressing interoperability blockers to ensure that popular formats, like HuggingFace safetensors, work well with PyTorch’s ecosystem. Since HuggingFace safetensors has become a leading format for inference and fine-tuning, DCP is beginning to support it natively. The first customer of these changes is torchtune, which has seen an improved user experience now that it can cleanly read and write directly to HuggingFace with DCP APIs.

Problem

Since HuggingFace is widely used, with over 5 million users, many ML engineers would like to save and load their checkpoints in safetensors format so they can work easily with the HuggingFace ecosystem. Supporting the safetensors format natively in DCP simplifies checkpointing for our users in the following ways:

  • DCP has its own custom format, so users who wanted to work with HuggingFace models while leveraging DCP’s performance wins and features had to build custom converters and components to bridge the two systems.
  • Instead of having to download and upload their checkpoints to local storage every time, users can now save and load HuggingFace models directly in the fsspec-supported storage of their choosing.

How to Use

From a user’s perspective, the only change needed to use safetensors is to call load with the new load planner and storage reader, and similarly save with the new save planner and storage writer.

The load and save APIs are called as follows:


# Import paths shown here are an assumption; in recent PyTorch releases these
# components are exposed under torch.distributed.checkpoint.
from torch.distributed.checkpoint import load, save
from torch.distributed.checkpoint import HuggingFaceStorageReader, HuggingFaceStorageWriter

# Load a state dict from safetensors files at `path` (any fsspec-supported location)
load(
	state_dict=state_dict,
	storage_reader=HuggingFaceStorageReader(path=path),
)

# Save a state dict as safetensors files at `path`; `mapping` assigns each fully
# qualified tensor name (FQN) to the index of the safetensors file it is written to
save(
	state_dict=state_dict,
	storage_writer=HuggingFaceStorageWriter(
		path=path,
		fqn_to_index_mapping=mapping,
	),
)
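
For illustration, here is one way the mapping used above might be built. The sharding scheme is up to the user; this sketch assumes the state_dict from the example above and splits it across two safetensors files, with file indices assumed to start at 1.

# Hypothetical sharding scheme: distribute tensors round-robin across two
# safetensors output files. Assumes the state_dict from the example above;
# file indices are assumed to start at 1.
num_files = 2
mapping = {
	fqn: (i % num_files) + 1
	for i, fqn in enumerate(state_dict.keys())
}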

The HuggingFaceStorageReader and HuggingFaceStorageWriter accept any fsspec-based path, so they can read and write HF safetensors checkpoints on any fsspec-supported back-end, including local storage and HF storage. Since HuggingFace safetensors metadata doesn’t natively provide the same level of information as DCP metadata, distributed checkpoints are not yet well supported by these APIs, but DCP plans to support this natively in the future.
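
As a concrete illustration of that flexibility, the sketch below reads weights from a local directory of safetensors files and writes them back out to a remote fsspec back-end. The model, both paths, and the single-file mapping are hypothetical, and the remote write assumes the matching fsspec implementation (for example s3fs for s3:// paths) is installed.

import torch.nn as nn
# Assumed import location for the HuggingFace storage components
from torch.distributed.checkpoint import load, save
from torch.distributed.checkpoint import HuggingFaceStorageReader, HuggingFaceStorageWriter

model = nn.Linear(16, 16)  # stand-in for a real HuggingFace model
state_dict = model.state_dict()

# Read safetensors weights from a local snapshot directory (hypothetical path);
# tensors are loaded in place into state_dict, and thus into the model
load(
	state_dict=state_dict,
	storage_reader=HuggingFaceStorageReader(path="/tmp/my-model-snapshot"),
)

# Write the weights back out in safetensors format to a remote back-end
# (hypothetical bucket; any fsspec-supported path works the same way)
save(
	state_dict=state_dict,
	storage_writer=HuggingFaceStorageWriter(
		path="s3://my-bucket/my-model-finetuned",
		fqn_to_index_mapping={fqn: 1 for fqn in state_dict},  # single output file
	),
)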

 

torchtune

Our first customer of HuggingFace DCP support is torchtune – a post-training library written in native PyTorch. The primary way torchtune users retrieve model weights is from the Hugging Face Hub. Previously, users had to download model weights and upload trained checkpoints via extra CLI commands; the new DCP APIs allow them to read and write to HuggingFace directly, resulting in a much better user experience.

In addition, safetensors serialization support in DCP greatly simplifies the checkpointing code in torchtune. Format-specific checkpointing solutions are no longer needed, which increases developer efficiency in the project.

Future Work

DCP plans to handle distributed loading and saving of HuggingFace safetensors checkpoints with resharding. DCP also plans to support producing a consolidated final checkpoint as a single file for publishing.