AWS Machine Learning Blog
Revolutionizing large language model training with Arcee and AWS Trainium
This is a guest post by Mark McQuade, Malikeh Ehghaghi, and Shamane Siri from Arcee.
In recent years, large language models (LLMs) have gained attention for their effectiveness. Many industries now adapt general-purpose LLMs to their own data for better results, which makes efficient training and hardware availability crucial. At Arcee, we focus primarily on enhancing the domain adaptation of LLMs in a client-centric manner. Arcee’s innovative continual pre-training (CPT) and model merging techniques have made LLM training significantly more efficient, with particularly strong evaluations in the medical, legal, and financial verticals. Close collaboration with the AWS Trainium team has also played a major role in making the Arcee platform highly performant. It has accelerated model training, reduced overall costs, and enforced compliance and data integrity in the secure AWS environment. In this post, we show how we use Trainium chips to make our continual pre-training efficient.
Understanding continual pre-training
Arcee recognizes the critical importance of CPT [1] in tailoring models to specific domains, as evidenced by previous studies such as PMC-LLaMA [2] and ChipNeMo [3]. These projects showcase the power of domain-adaptive pre-training in enhancing model performance across diverse fields, from medical applications to industrial chip design. Inspired by these efforts, our approach to CPT extends the training of base models like Llama 2 on domain-specific datasets, which lets us adapt models to the nuances of specialized fields. To make our CPT process even more efficient, we collaborated with the Trainium team and used Trainium to continually pre-train a Llama 2 [4] model on a PubMed dataset [2] comprising 88 billion tokens. This collaboration represents a significant milestone in our quest for innovation, and through this post, we’re excited to share the insights we’ve gained. Join us as we unveil the future of domain-specific model adaptation and the potential of CPT with Trainium in optimizing model performance for real-world applications.
Dataset collection
We followed the methodology outlined in the PMC-LLaMA paper [6] to assemble our dataset. It includes PubMed papers sourced from the Semantic Scholar API and various medical texts cited within the paper, for a total of 88 billion tokens. For further details on the dataset, see the original paper.
To prepare this dataset for training, we used the Llama 2 tokenizer within an AWS Glue pipeline for efficient processing. We then organized the data so that each row contained 4,096 tokens, adhering to recommendations from the Neuron Distributed tutorials.
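To make the packing step concrete, here is a minimal Python sketch of turning tokenized documents into fixed 4,096-token rows. It assumes the documents have already been tokenized into lists of token IDs (in our case, by the Llama 2 tokenizer inside the AWS Glue job), that an end-of-sequence token separates documents, and that a trailing remainder shorter than one row is dropped. These choices and the function name pack_into_rows are illustrative assumptions, not our exact pipeline code.

```python
from typing import Iterable, Iterator, List

SEQ_LEN = 4096  # tokens per training row, as used in this post


def pack_into_rows(
    tokenized_docs: Iterable[List[int]],
    eos_id: int,
    seq_len: int = SEQ_LEN,
) -> Iterator[List[int]]:
    """Concatenate tokenized documents and yield full rows of seq_len tokens.

    Each document is followed by eos_id so that document boundaries stay
    visible to the model. Tokens left over at the end (fewer than seq_len)
    are dropped rather than padded.
    """
    buffer: List[int] = []
    for doc in tokenized_docs:
        buffer.extend(doc)
        buffer.append(eos_id)  # mark the end of this document
        # Emit as many complete rows as the buffer currently holds.
        while len(buffer) >= seq_len:
            yield buffer[:seq_len]
            buffer = buffer[seq_len:]


if __name__ == "__main__":
    toy_docs = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
    for row in pack_into_rows(toy_docs, eos_id=0, seq_len=4):
        print(row)
```

With a toy row length of 4, the example yields [1, 2, 3, 0], [4, 5, 0, 6], and [7, 8, 9, 0]. Packing across document boundaries keeps every row full, so no compute is spent on padding tokens.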
Why Trainium?
Continual pre-training techniques like the ones described in this post require access to high-performance compute instances, which have become more difficult to obtain as more developers use generative artificial intelligence (AI) and LLMs in their applications. Traditionally, these workloads have been deployed to GPUs; however, in recent years, the cost and limited availability of GPUs have stifled model-building innovation. With the introduction of Trainium, we can keep innovating on our models while building them more efficiently and, most importantly, at lower cost. Trainium is the second-generation machine learning (ML) accelerator that AWS purpose built for high-performance model training, and it can lower training costs by up to 50% over comparable Amazon Elastic Compute Cloud (Amazon EC2) instances. With Trainium available in AWS Regions worldwide, developers don’t have to take expensive, long-term compute reservations just to get access to clusters of GPUs to build their models. Trainium instances offer developers the performance they need with the elasticity they want to improve training efficiency and lower model-building costs.
Setting up the Trainium cluster
We used AWS ParallelCluster to build a High Performance Computing (HPC) compute environment that uses Trn1 compute nodes to run our distributed ML training job (see the GitHub tutorial). You can also use developer flows like Amazon SageMaker, Amazon Elastic Kubernetes Service (Amazon EKS), Ray, or others (to learn more, see Developer Flows). We created the cluster with the AWS ParallelCluster command line tool, pcluster, from a YAML configuration file. After the nodes were launched, we ran a training task to confirm that the nodes were working, and used Slurm commands to check the job status. Our cluster consisted of 16 trn1n.32xlarge nodes; each node has 16 Trainium devices with 32 GB of device memory apiece.
We set up our ParallelCluster infrastructure as shown in the following diagram (source).
As shown in the preceding figure, inside a VPC, there are two subnets, a public one and a private one. The head node resides in the public subnet, and the compute fleet (in this case, Trn1 instances) is in the private subnet. A NAT gateway is also needed so that nodes in the private subnet can connect to endpoints outside the VPC. In the following section, we describe how to set up the necessary infrastructure for Trn1 ParallelCluster.
Set up the environment
To set up your environment, complete the following steps:
- Install the VPC and necessary components for ParallelCluster. For instructions, see VPC setup for ParallelCluster with Trn1.
- Create and launch ParallelCluster in the VPC. For instructions, see Create ParallelCluster.
Now you can launch a training job by submitting a model training script as a Slurm job.
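As an illustration of that submission step, the following Python sketch renders a minimal Slurm batch script and submits it with sbatch from the head node. The job name, the train.sh launcher, and the train.sbatch file name are hypothetical placeholders; the actual launch script comes from the tutorial repository, and real jobs usually need more directives and environment setup than shown here.

```python
import subprocess
from pathlib import Path


def render_sbatch_script(
    job_name: str,
    num_nodes: int,
    train_cmd: str,
    log_path: str = "slurm-%j.out",
) -> str:
    """Return the text of a minimal Slurm batch script for a multi-node job."""
    lines = [
        "#!/bin/bash",
        f"#SBATCH --job-name={job_name}",
        f"#SBATCH --nodes={num_nodes}",
        "#SBATCH --exclusive",           # do not share the Trn1 nodes with other jobs
        f"#SBATCH --output={log_path}",   # Slurm replaces %j with the job ID
        f"srun {train_cmd}",              # run the training launcher on the allocated nodes
    ]
    return "\n".join(lines) + "\n"


def submit(script_text: str, path: str = "train.sbatch") -> str:
    """Write the script to disk and submit it with sbatch.

    Must run where Slurm is available (the ParallelCluster head node).
    Returns sbatch's output, e.g. "Submitted batch job <id>".
    """
    Path(path).write_text(script_text)
    result = subprocess.run(
        ["sbatch", path], check=True, capture_output=True, text=True
    )
    return result.stdout.strip()


if __name__ == "__main__":
    # Hypothetical job: 16 nodes running a placeholder launcher script.
    print(render_sbatch_script("llama2-7b-cpt", 16, "./train.sh"))
```

After submission, squeue shows the job’s state, and the output file (slurm-<job ID>.out with the default log_path here) collects the training logs.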
Deploy to Trainium
Trainium-based EC2 Trn1 instances use the AWS Neuron SDK and support common ML frameworks like PyTorch and TensorFlow. Neuron simplifies distributed training and has integrations with NeMo Megatron and Neuron Distributed.
When engaging with Trainium, it’s crucial to understand several key parameters:
- Tensor parallel size – This determines the level of tensor parallelization, particularly in self-attention computations within transformers, and is crucial for optimizing memory usage (not computational time efficiency) during model loading
- NeuronCores – Each Trainium device has two NeuronCores, and each trn1.32xlarge or trn1n.32xlarge node has 16 Trainium devices, so an eight-node setup provides 8 × 16 × 2 = 256 NeuronCores
- Mini batch – This reflects the number of examples processed in each batch as determined by the data loader
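- World size – This is the total number of worker processes participating in training; with one process per NeuronCore, as in typical Neuron Distributed setups, it equals the total NeuronCore count (256 for the eight-node setup)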
A deep understanding of these parameters is vital for anyone looking to harness the power of Trainium devices effectively.
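To see how these parameters relate, the following Python sketch computes world size, data-parallel size, and global batch size for a given cluster. It assumes one worker process per NeuronCore, 16 Trainium devices per trn1.32xlarge or trn1n.32xlarge node, tensor parallelism as the only form of model parallelism, and an optional gradient accumulation factor. The tensor parallel size of 8 in the example is illustrative, not a value reported in this post, and the names layout and ParallelLayout are hypothetical.

```python
from dataclasses import dataclass

DEVICES_PER_NODE = 16  # Trainium devices per trn1.32xlarge / trn1n.32xlarge node
CORES_PER_DEVICE = 2   # NeuronCores per Trainium device


@dataclass(frozen=True)
class ParallelLayout:
    world_size: int          # total worker processes, one per NeuronCore
    data_parallel_size: int  # number of tensor-parallel model replicas
    global_batch_size: int   # sequences consumed per optimizer step


def layout(
    num_nodes: int,
    tensor_parallel_size: int,
    mini_batch: int,
    grad_accum_steps: int = 1,
) -> ParallelLayout:
    """Derive the parallel layout, assuming tensor parallelism only."""
    world_size = num_nodes * DEVICES_PER_NODE * CORES_PER_DEVICE
    if world_size % tensor_parallel_size != 0:
        raise ValueError("tensor_parallel_size must evenly divide world_size")
    # Each group of tensor_parallel_size cores holds one sharded model copy.
    data_parallel_size = world_size // tensor_parallel_size
    # Every replica processes mini_batch sequences per micro-step.
    global_batch_size = mini_batch * data_parallel_size * grad_accum_steps
    return ParallelLayout(world_size, data_parallel_size, global_batch_size)


if __name__ == "__main__":
    print(layout(num_nodes=8, tensor_parallel_size=8, mini_batch=1))
    print(layout(num_nodes=16, tensor_parallel_size=8, mini_batch=1))
```

With 8 nodes and a tensor parallel size of 8, the layout has a world size of 256 and 32 data-parallel replicas, so each optimizer step sees 32 sequences with a mini batch of 1. Moving to the 16-node cluster doubles the world size to 512 and the data-parallel size to 64.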
Train the model
For this post, we train a Llama 2 7B model with tensor parallelism. For a streamlined and effective training process, we adhered to the following steps:
- Download the Llama 2 full checkpoints (model weights and tokenizer) from Hugging Face.
- Convert these checkpoints to a format compatible with the Neuron Distributed setup, so they can be efficiently utilized in our training infrastructure.
- Determine the number of steps required per epoch, incorporating the effective batch size and dataset size to tailor the training process to our specific needs.
- Launch the training job, carefully monitoring its progress and performance.
- Periodically save training checkpoints. Initially, this process may be slow due to its synchronous nature, but improvements are anticipated as the NeuronX team works on enhancements.
- Finally, convert the saved checkpoints back to a standard format for subsequent use, employing scripts for seamless conversion.
For more details, you can find the full implementation of the training steps in the following GitHub repository.
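The steps-per-epoch calculation from the training steps is simple arithmetic, sketched here in Python. The 88 billion tokens and 4,096-token rows come from this post; the global batch size of 1,024 sequences is a hypothetical value chosen only for illustration, as is the function name steps_per_epoch.

```python
def steps_per_epoch(
    total_tokens: int,
    seq_len: int,
    global_batch_size: int,
    drop_last: bool = True,
) -> int:
    """Number of optimizer steps needed to see every packed row once."""
    rows = total_tokens // seq_len  # packed rows of seq_len tokens each
    if drop_last:
        return rows // global_batch_size  # ignore the final partial batch
    return -(-rows // global_batch_size)  # ceiling division keeps it


if __name__ == "__main__":
    # 88B tokens (from this post), 4,096-token rows, hypothetical batch of 1,024.
    print(steps_per_epoch(88_000_000_000, 4096, 1024))
```

With these numbers, the dataset packs into 21,484,375 rows, which gives 20,980 full steps per epoch, or 20,981 if the final partial batch is kept.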
Clean up
Don’t forget to tear down any resources you set up in this post.
Results
Our study focused on evaluating the quality of the CPT-enhanced checkpoints. We monitored the perplexity of a held-out PubMed dataset [6] across various checkpoints obtained during training, which provided valuable insights into the model’s performance improvements over time.
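For reference, perplexity is the exponential of the average per-token negative log-likelihood on the evaluation text, so lower is better. The following minimal Python sketch assumes you already have per-token negative log-likelihoods (natural log) from the model for every token in the held-out set, and weights every token equally; the post does not specify how losses were aggregated across sequences.

```python
import math
from typing import Iterable


def perplexity(token_nlls: Iterable[float]) -> float:
    """Token-level perplexity: exp(mean negative log-likelihood per token).

    token_nlls holds -log p(token | context), in natural log units, for every
    evaluated token across the held-out set.
    """
    total = 0.0
    count = 0
    for nll in token_nlls:
        total += nll
        count += 1
    if count == 0:
        raise ValueError("perplexity needs at least one token")
    return math.exp(total / count)


if __name__ == "__main__":
    # A model assigning probability 1/4 to each of three tokens.
    print(perplexity([math.log(4)] * 3))
```

A model that assigns every token probability 1/4 has a perplexity of 4 (up to floating-point rounding), which is why a drop in perplexity across checkpoints indicates that the model finds the held-out text more predictable.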
Through this journey, we’ve advanced our model’s capabilities, and hope to contribute to the broader community’s understanding of effective model adaptation strategies.
The following figure shows the perplexity of the baseline Llama 2 7B checkpoint vs. its CPT-enhanced checkpoint on the PMC test dataset. Based on these findings, continual pre-training on domain-specific raw data, specifically PubMed papers in our study, improved on the Llama 2 7B checkpoint, giving the model lower (better) perplexity on the PMC test set.
The following figure shows the perplexity of the CPT-enhanced checkpoints of the Llama 2 7B model across varying numbers of trained tokens. As the number of trained tokens increased, perplexity decreased, indicating improved model performance.
The following figure compares the perplexity of the baseline Llama 2 7B model and its CPT-enhanced checkpoints, with and without data mixing. For data mixing, we added 1% general-domain tokens to the domain-specific dataset. The CPT-enhanced checkpoint trained with data mixing performed better than both the baseline Llama 2 7B model and the CPT-enhanced checkpoint trained solely on PubMed data, which underscores the significance of data mixing.
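One way to implement this kind of data mixing is sketched below in Python. It interprets the 1% as general-domain rows equal to 1% of the domain-specific rows, assumes all rows are packed to the same 4,096-token length (so row counts are proportional to token counts), and uses a seeded shuffle. The post does not say which general corpus was used or exactly how it was sampled, so treat mix_datasets and its parameters as hypothetical.

```python
import random
from typing import List, Sequence, TypeVar

Row = TypeVar("Row")


def mix_datasets(
    domain_rows: Sequence[Row],
    general_rows: Sequence[Row],
    general_ratio: float = 0.01,
    seed: int = 0,
) -> List[Row]:
    """Add general-domain rows equal to general_ratio of the domain rows, then shuffle.

    Assumes every row is packed to the same token length, so row counts are
    proportional to token counts.
    """
    rng = random.Random(seed)  # seeded so the mix is reproducible
    n_general = round(len(domain_rows) * general_ratio)
    if n_general > len(general_rows):
        raise ValueError("not enough general rows for the requested ratio")
    # Sample general rows without replacement, then interleave by shuffling.
    mixed = list(domain_rows) + rng.sample(list(general_rows), n_general)
    rng.shuffle(mixed)
    return mixed


if __name__ == "__main__":
    domain = [("pubmed", i) for i in range(1000)]
    general = [("general", i) for i in range(500)]
    mixed = mix_datasets(domain, general)
    print(len(mixed), sum(1 for source, _ in mixed if source == "general"))
```

In this example, mixing 1,000 PubMed rows with a pool of general rows produces 1,010 shuffled rows, 10 of them general-domain. Shuffling spreads the general rows across training rather than concentrating them at one point in the run.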
Conclusion
Arcee’s innovative approach to CPT and model merging, as demonstrated through our collaboration with Trainium, signifies a transformative advancement in the training of LLMs, particularly in specialized domains such as medical research. By using the extensive capabilities of Trainium, we have not only accelerated the model training process, but also significantly reduced costs, with an emphasis on security and compliance that provides data integrity within a secure AWS environment.
The results from our training experiments, as seen in the improved perplexity scores of domain-specific models, underscore the effectiveness of our method in enhancing the performance and applicability of LLMs across various fields. This is particularly evident from the direct comparisons of time-to-train metrics between Trainium and traditional GPU setups, where Trainium’s efficiency and cost-effectiveness shine.
Furthermore, our case study using PubMed data for domain-specific training highlights the potential of Arcee’s CPT strategies to fine-tune models to the nuances of highly specialized datasets, thereby creating more accurate and reliable tools for professionals in those fields.
As we continue to push the boundaries of what’s possible in LLM training, we encourage researchers, developers, and enterprises to take advantage of the scalability, efficiency, and enhanced security features of Trainium and Arcee’s methodologies. These technologies not only facilitate more effective model training, but also open up new avenues for innovation and practical application in AI-driven industries.
The integration of Trainium’s advanced ML capabilities with Arcee’s pioneering strategies in model training and adaptation is poised to revolutionize the landscape of LLM development, making it more accessible, economical, and tailored to meet the evolving demands of diverse industries.
To learn more about Arcee.ai, visit Arcee.ai or reach out to our team.
Additional resources
- Arcee’s whitepaper: Case Study on How Arcee is Innovating Domain Adaptation, through Continual Pre-Training and Model Merging
- Arcee’s paper on arXiv on model merging: Arcee’s MergeKit: A Toolkit for Merging Large Language Models
- Arcee’s Mergekit repository on GitHub
References
[1] Gupta, Kshitij, et al. “Continual Pre-Training of Large Language Models: How to (re)warm your model?” arXiv preprint arXiv:2308.04014 (2023).
[2] Wu, Chaoyi, et al. “PMC-LLaMA: Towards building open-source language models for medicine.” arXiv preprint arXiv:2305.10415 (2023).
[3] Liu, Mingjie, et al. “ChipNeMo: Domain-adapted LLMs for chip design.” arXiv preprint arXiv:2311.00176 (2023).
[4] Touvron, Hugo, et al. “Llama 2: Open foundation and fine-tuned chat models.” arXiv preprint arXiv:2307.09288 (2023).
[5] Amazon EC2 Trn1 instances. https://thinkwithwp.com/ec2/instance-types/trn1/
[6] Wu, Chaoyi, et al. “PMC-LLaMA: Further fine tuning LLaMA on medical papers.” arXiv preprint arXiv:2304.14454 (2023).