AWS 기술 블로그

Amazon SageMaker 모델 병렬 라이브러리를 이용한 신규 성능 향상 방법

이 글은 AWS ML Blog의 New performance improvements in Amazon SageMaker model parallel library by Arjun Balasubramanian, Can Karakus, Fei Wu, Rahul Huilgol, Suhit Kodgule, and Zhaoqi Zhu의 한국어 번역 및 편집본입니다.

파운데이션(Foundation ) 모델은 대량의 데이터로 학습된 대규모 딥 러닝 모델을 말합니다. 이 모델들은 파인 튜닝(fine-tuned)을 추가로 진행하여, 다양한 다운스트림 (downstream) 작업을 수행하고 여러 AI 애플리케이션들을 가능하게 하는 핵심 백본 (backbone)을 생성합니다. 가장 눈에 띄는 범주는 Large Language Models (LLM)로서, 자연어 텍스트를 완성하도록 학습된 다양하게 변형된 GPT 알고리즘과 같은 자가 회귀 모델들 (auto-regressive models)을 포함합니다. LLM은 일반적으로 수십억 개의 모델 파라미터를 포함하므로 하나의 가속기 (accelerator)에서 학습하는 것이 거의 어려우므로, 모델 병렬화 (model parallelism) 기법이 필요합니다. 또 다른 범주는 AI 기반 Diffusion 모델들이며, 특히 Stable Diffusion 이 유명한데, 이 모델들은 간단한 텍스트 설명에서 놀라운 시각적 효과를 생성할 수 있도록 AI 기반 이미지 생성 작업에 대해 전례 없는 이정표를 만들어가고 있습니다. Diffusion 모델은 일반적으로 LLM보다 훨씬 작으며, 개발을 촉진하는 데 분산 학습이 중요한 역할을 수행하고 있습니다.

SageMaker 모델 병렬(model parallel) (SMP) 라이브러리Amazon SageMaker 플랫폼에서 이용할 수 있는 대규모 모델 학습 방법입니다. PyTorch 모델들과 통합되어 최신 (state-of-the-art)의 다양한 대규모 모델에 대해 분산 학습 기법을 쉽게 적용하여 규모에 맞게 학습할 수 있습니다. SMP는 아마존에서 개발한 MiCS 기술을 기반으로 하는 분산형 학습 기법인 Sharded data parallel를 제공하고 있습니다. Sharded data parallel는 모델 파라미터, 그레디언트와 옵티마이저 상태를 데이터 병렬 워커(data-parallel worker)에 따라 분할합니다. MiCS는 스케일 인식 (scale-aware) 파티셔닝을 포함한 여러 최적화를 수행하여 선형에 가까운 확장성을 제공합니다. Train gigantic models with near-linear scaling using sharded data parallelism 블로그에서, 시퀀스 길이가 2048인 300억 개 (30B) 모델 파라마터를 가진 GPT-2 모델에서 SMP의 sharded data parallel은 DeepSpeed ZeRO-3에 비해 39.7% 학습 속도를 향상시키는 결과를 얻었습니다.

고객 분들이 모델 학습 비용을 최소화하고 출시 기간을 단축할 수 있도록 SageMaker 모델 병렬 라이브러리에서는 SMDDP 통신 패턴 (SMDDP Collectives)와 FlashAttention이라는 두 개의 신규 성능 향상을 위한 기능을 추가하게 되었습니다. SMDDP 통신 패턴은 SageMaker 분산 데이터 병렬 라이브러리에서 제공하는 대규모 모델 학습을 위한 AWS 인프라에서 가장 성능이 뛰어난 통신 패턴 라이브러리입니다. Dao et al. 논문에서 소개된 FlashAttention은 IO (다른 속도의 메모리들에 대한 읽기/쓰기) 인식 방식으로 어텐션 (attention) 메커니즘을 재구현하여 메모리 대역폭의 요구 사항을 줄이고 어텐션 속도와 메모리 공간을 절약합니다. 이 두 개의 기능들은 32개의 p4d.24xlarge 인스턴스에서1,000억개 모델 파라미터를 가진 GPT-NeoX 모델을 학습할 때 sharded data parallel 기법을 사용하여 전체적으로 30.58% 더욱 빠르게 만듭니다. 지원되는 모델에서 sharded data parallel를 이미 사용한 고객의 경우, 이런 최신 기능들이 제공하는 성능 향상의 이점을 얻기 위해 별도 코드 변경이 필요 없습니다. 비교할 수 없이 이미지 생성 능력을 보여준 Stable Diffusion 모델의 제품군을 만들고 있는 Stability AI 는 Foundation 모델들을 생성하기 위해 SMP를 사용하고 있습니다. SMP를 통해 Stability AI는 32개 p4d.24xlarge 인스턴스에서 130억 개의 모델 파라미터를 가진 GPT-NeoX에 대해 GPU당 163 TFLOPS을 달성했으며, DeepSpeed에 비해 58% 의 속도가 향상되었습니다. AWS re:Invent 2022또는 다음 블로그에서 Stability AI CEO의 미션과 AWS와의 파트너십에 대한 자세한 이야기를 확인할 수 있습니다.

“Stability AI의 미션은 AI를 통해 인류의 잠재력을 활성화할 수 있는 기반을 구축하는 것입니다. 이 미션을 달성하려면 수백 개의 가속 컴퓨팅 인스턴스에서 오픈 소스 기반 모델을 효율적으로 학습해야 합니다. SageMaker와 분산 학습 라이브러리를 사용하면 성능을 최적화하고 최신 (state-of-the-art) 전략을 구현하여 학습 클러스터를 따라 모델과 데이터를 분류 (shard) 합니다. 이러한 최적화를 통해 학습 비용을 절감하고 고객 요구 사항을 빠르게 충족하며 신규 모델의 개발 속도를 높일 수 있습니다.”

— Emad Mostaque, Founder and CEO of Stability AI.

이 블로그 게시물에서는 먼저 SageMaker 모델 병렬 라이브러리의 최신 성능 개선 사항을 소개합니다. 그런 다음 sharded data parallel을 사용하여 Foundation 모델을 학습하는 방법에 대해 추가로 살펴보겠습니다. 마지막으로 130억, 500억, 그리고 1,000억 모델 파라미터를 가진 자가 회귀 모델의 성능을 벤치마킹하고 향후 작업으로 마무리하겠습니다.

SageMaker 모델 병렬 라이브러리의 신규 성능 향상

AWS Deep Learning Containers(DLC) PyTorch 1.12.1부터 SageMaker 모델 병렬 라이브러리 v1.13을 통해 학습 성능을 개선하는 데 중요한 다음 두 가지 신규 구성 요소를 함께 제공합니다. 현재 Elastic Fabric Adapter (EFA)가 활성화된 ml.p4d.24xlarge 인스턴스에서 사용할 수 있습니다.

1. SMDDP 통신패턴 내 AWS에 최적화된 AllGather

sharded data parallel에서, 개별 GPU에 분리 (shard)된 모델 상태가 존재하므로 정방향 (forward) 또는 역방향 (backward) 패스 계산 중에 분리된 (sharding) 그룹 내의 모든 GPU에서 전체 모델의 파라미터 세트를 수집하려면 AllGather 통신 패턴 (collective)이 필요합니다. 이전 버전의 SageMaker 모델 병렬에서는 이러한 통신 패턴들에 대해 NVIDIA Collective Communications Library (NCCL)을 사용했습니다. 그러나, NCCL은 AWS 인프라용으로 설계되지 않은 범용 통신 패턴 라이브러리이므로, EFA가 활성화되더라도 성능이 최적화되지 않았습니다.

이전에는 데이터 병렬 학습의 성능을 가속화하기 위해 AWS에 최적화된 All-Reduce 통신 패턴을 구현하여 제공하는 SMDDP 라이브러리를 개발했습니다. 이번에 추가로 sharded data parallelism로 대규모 모델 학습의 성능을 개선하기 위해 SMDDP 통신 패턴 라이브러리를 확장하면서, AllGather 통신 패턴을 최적화하여 구현하였습니다. SMDDP에서 제공하는 AllGather의 주요 장점은 노드 간 통신을 위해 모든 유형의 통신 패턴을 채택하여, AWS 통신 패턴이 높은 처리량을 가지고 지연 (latency)을 줄일 수 있도록 하는 것입니다. 또한, AWS의 AllGather은 통신과 관련 처리에 대해 CPU로 내리는 반면, 대규모 모델에서 상당한 성능 향상을 가져올 수 있는 그레디언트  계산과 관련된 중요한 처리는 GPU 사이클에서 확보할 수 있도록 하였습니다.

2. FlashAttention

최신 트랜스포머 아키텍처에서 가장 큰 메모리 소비 이유 중 하나는 셀프 어텐션 계층 (self-attention layer)의 활성화 영역입니다. 이는 각 어텐션 헤드가 각 입력에 대해 S×S 어텐션 행렬을 계산하기 때문입니다. 여기서, S는 시퀀스 길이이고, 이 행렬은 드롭아웃, 소프트맥스 및 행렬 곱셈과 같은 여러 연산을 거치게 되는데, 각 중간 출력을 역전파에 사용하기 위해 메모리 공간이 필요하게 됩니다.

FlashAttention(Dao et al.)은 I/O 인식 방식으로 셀프 어텐션 메커니즘을 다시 구현하는 스탠포드의 HazyResearch의 최신 연구 결과입니다. FlashAttention의 주요 관찰 결과는 셀프 어텐션 메커니즘이 GPU 고대역폭 메모리(HBM)과 GPU의 SRAM 메모리 대역폭에서 병목 현상이 발생한다는 점입니다. FlashAttention을 활용하면 한 번에 전체 셀프 어텐션 파이프라인을 통과하는 청크 (chunk)와 함께, 시퀀스 차원을 따라 청크들로 셀프 어텐션 계층를 계산할 수 있습니다. 모든 반복 (iteration) 동안 HBM으로의 비싼 왕복 (round-trip)을 피하고자 청크에 대한 중간 결과를 고대역폭 SRAM에 저장합니다. 단순 구현에서는 소프트맥스 (softmax) 계층에서 교차 청크 종속성의 문제에 부딪힐 수 있지만, FlashAttention은 이러한 종속성을 회피하는 영리한 구현 기법을 도입합니다. 역방향 패스의 재계산과 결합된 FlashAttention은 HBM의 왕복과 S×S행렬의 저장을 방지하여 상당한 메모리 절감과 성능 향상 (16개 p4d.24xlarge 인스턴스에서 130억 개 모델 파라미터를 가진 GPT-NeoX에 대해 25% 더 빠른 학습)을 얻게 됩니다. Hazy Research의 FlashAttention 저장소에서 다양한 시각 자료와 더 많은 설명을 찾을 수 있습니다.

SageMaker 모델 병렬로 대규모로 Foundation 모델 학습

SMDDP 통신 패턴 (Collectives)에서 제공하는 SMP로 Foundation 모델을 학습하려면, sharded data parallel의 학습 작업에 추가 변경이 필요하지 않습니다. sharded data parallel을 처음 사용하는 경우 데이터 처리, 학습 작업의 정의와 제출에서 학습 로그 모니터링에 이르기까지 전체 프로세스를 이 튜토리얼 노트북블로그 게시물을 참조하십시오. GPT-2 모델에 대한 즉시 학습이 가능한 훈련 스크립트는 train_gpt_simple.py을 사용할 수 있습니다. 다른 모델 유형을 학습하려면 API 문서를 따라 SMP API를 적용하는 방법을 알아볼 수 있습니다.

아래와 같이 sharded data parallel 학습 작업과 관련하여 PyTorch Estimator에서의 주요 하이퍼파라미터를 강조하여 표시합니다. smp_options의 하이퍼파라미터 ddp_dist_backend에는 이제 새 옵션으로 “auto”가 기본값으로 포함됩니다. sharded data parallel 처리 작업에 대해 “auto”를 사용하면 SMP는AWS에 최적화된 AllGather를 사용하고, 그렇지 않으면 NCCL로 대체합니다. 지원되는 구성에 대해서는 이 문서를 참조할 수 있습니다. 특히 NCCL을 선택한 통신 백엔드로 사용하여 SMP에서 sharded data parallel를 실행하려면 smp_options에서 “ddp_dist_backend”를 “nccl”로 설정할 수 있습니다.

import sagemaker
from sagemaker.pytorch import PyTorch

smp_options = {
    "enabled": True,
    "parameters": {
        "ddp": True,
        "ddp_dist_backend": "auto", #OR "nccl" to disable SMDDP Collectives
        # To enable sharded data parallelism.
        # Here we shard model states across 128 GPUs.
        "sharded_data_parallel_degree": 128,  
    }
}

smp_estimator = PyTorch(
    entry_point="train_gpt_simple.py",
    role=sagemaker.get_execution_role(),
    instance_type='ml.p4d.24xlarge',
    instance_count=32,
    distribution={
        "smdistributed": {"modelparallel": smp_options},
        ...
    },
    ...
)

smp_estimator.fit(inputs=data_channels)

최신 SMPv1.13 릴리스에서 sharded data parallel 학습 기법은 BERT, RoBERTa, GPT-2, GPT-J, GPT-Neo 및 GPT-NeoX를 포함하여 인기 있는 모델에 대해 FlashAttention을 지원합니다. 이것은 tensor_parallel_degree를 설정하지 않고 모델 생성 중에 tensor_parallelism=True를 전달하여 활성화됩니다. 동일한 학습 스크립트인 train_gpt_simple.py 에서 예제를 찾을 수 있습니다.

성능 벤치마킹

우리는 SageMaker 모델 병렬 라이브러리의 sharded data parallelism를 세 가지 다른 모델 메트릭으로 벤치마킹하여 두 가지 신규 기능인 FlashAttention과 AWS에 최적화된 AllGather가 성능 향상에 어떻게 기여하는지 알아 보았습니다. SageMaker에서 이러한 벤치마크를 재현하는 데 배치 그룹 (Placement group)은 필요하지 않습니다.

130억개 파라미터를 가진 GPT-NeoX

이 설정에서는 FlashAttention이 제공하는 성능 향상을 이해하는 데 초점을 맞추기 위해 AWS에 최적화된 AllGather는 제외합니다. FlashAttention을 사용하면 상당한 GPU 메모리가 절약되므로 배치 크기를 늘리거나 분리 (sharding) 정도를 줄일 수 있으므로 성능이 향상됩니다. 아래 결과에서 알 수 있듯이, 16 – 64 p4d.24xlarge 인스턴스의 다양한 구성에 대해 130억개 파라미터를 가진 GPT-NeoX 모델에 FlashAttention을 사용한 SMP의 결과에서 평균 약 20.4%의 속도 향상을 관찰하였습니다. 표준 어텐션 연산 중 메모리 사용량은 시퀀스 길이가 증가함에 따라 2차원적으로 확장되지만, FlashAttention의 메모리 사용량은 시퀀스 길이에 따라 선형적으로 증가합니다. 따라서, FlashAttention은 더 큰 시퀀스 길이를 사용할 수 있으므로 훨씬 더 유용한 결과를 얻게 됩니다. 모델 품질을 희생하지 않고 메모리가 효율적이기 때문에 FlashAttention은 HuggingFace DiffusersMosaic ML과의 통합을 포함하여 지난 몇 달 동안 대규모 모델 학습 커뮤니티에서 빠르게 주목을 받았습니다.

500억개 파라미터를 가진 Bloom

이제 SMDDP 통신 패턴 (Collectives)에서 AWS에 최적화된 AllGather가 SMP를 사용하여 대규모 모델 학습을 가속화하는 방법을 살펴봅니다. 우리는 500억개의 모델 파라미터를 가진 Bloom 모델을 벤치마킹하고 AWS에 최적화된 AllGather 가 있는 경우와 없는 경우의 성능을 비교합니다. SMDDP의 통신 패턴들은 32개 노드에서 64개 노드의 학습 작업에 걸쳐 모델 학습 속도를 최대 40% 향상시키는 것으로 확인되었습니다. SMDDP의 통신 패턴들은 p4d.24xlarge 인스턴스에서 사용할 수 있는 400Gbps 네트워크 대역폭을 더욱 잘 활용하여 성능을 향상시키는 데 도움이 됩니다. 이는 통신 관련 처리를 CPU로 넘기는 설계 선택과 결합하여 컴퓨팅과 네트워크 간의 좋은 동시 처리를 만들면서 성능을 최적화하는 데 도움이 됩니다. 노드 간에 전달되는 데이터 크기가 모델 크기의 증가에 따라 선형적으로 확장되기 때문에, 컴퓨팅과 네트워크 간 동시 처리는 대규모 모델 학습에서 특히 중요해 지고 있습니다.

1,000억개 파라미터를 가진 GPT-NeoX

마지막으로 두 가지 최신 기능을 모두 사용하여 SMP를 벤치마킹 합니다. 1,000억개 파라미터를 가진 GPT-NeoX 모델에서  SMP v1.13은 이전 버전보다 30% 더 빠른 결과를 보여 줍니다.

향후 작업에서 SMDDP 통신 패턴 (Collectives)은 AWS에 최적화된 Reduce-Scatter를 지원하는 작업을 진행할 예정입니다. Reduce-Satter는 역방향 패스에서 연산된 그레디언트의 평균화와 분리 (sharding)에 중요합니다. 이를 통해 향후 릴리즈에서는 SMP 라이브러리의 속도가 더욱 빨라질 것으로 기대합니다.

결론

이 게시물에서는  SageMaker 모델 병렬 라이브러리의 sharded data parallel 기법에 대한 두 가지의 최신 성능 향상에 대해 설명하였고, 이를 통해 LLM이 ML 모델의 품질과 재사용 가능성을 개선하는 데 기여하는 모습을 보여 주었습니다. AWS 팀은 고객과 긴밀하게 협력하여 학습 비용과 출시 기간을 지속적으로 단축하고 있습니다. Amazon SageMaker 예제 GitHub 리포지토리에서 더 많은 SageMaker 모델 병렬 예제를 확인하거나 다음과 같은 분산 학습 워크샵에 참석할 수 있습니다. 대규모 모델 학습의 속도를 높이는 데 관심이 있으시다면, 이런 기능들을 확인하시고 무엇을 만들지 저희에게 알려주시기 바랍니다. 감사합니다.

Youngjoon Choi

Youngjoon Choi

최영준 Principal AI/ML Expert SA는 제조, 하이테크, 금융 등의 다양한 산업에서 엔터프라이즈 IT를 경험하면서 개발자로, 아키텍트로, 데이터 과학자로 다양한 활동을 하였습니다. 기계학습과 딥러닝 연구를 진행하였고, 구체적으로 Hyperparameter optimization과Domain adaptation등을 주제로 알고리즘 연구와 논문 발표를 진행하였습니다. AWS에서는 AI/ML를 전문 분야로 다양한 산업군에서 분산학습/초거대 모델 개발과 ML 파이프라인 구축 등에 대해 AWS 서비스를 이용한 기술 검증 지원, 아키텍처 제안 및 검토 등을 수행하고 있으며, AI/ML 생태계가 더욱 확장할 수 있도록 다양한 기여를 하고자 합니다.