亚马逊AWS官方博客
使用 Amazon SageMaker 微调和部署 Whisper
前言
随着人工智能技术的快速发展,语音识别和处理技术在各个领域的应用日益广泛。本文将介绍如何利用 AWS 的云服务平台,结合开源的 Whisper 模型,构建一个高效、可扩展的语音识别系统。特别地,我们将以《水浒传》这部中国古典文学作品中的人物对话识别为具体案例,展示如何通过模型微调来提高特定领域的识别准确率,并通过 Amazon SageMaker 实现模型的高效部署。
本文主要涉及三个核心技术组件:
- Amazon SageMaker:AWS 的托管式机器学习服务,用于模型训练和部署。
- Amazon Polly:AWS 的文本转语音服务,用于生成训练数据。
- Whisper:OpenAI 开发的开源语音识别模型,作为我们系统的基础模型。
通过这三个组件的结合,我们将展示如何构建一个完整的语音识别解决方案,从数据准备、模型训练到最终的服务部署,为读者提供一个实用的技术参考。
Amazon SageMaker
Amazon SageMaker 是一个托管式机器学习服务,可帮助您构建和训练模型,然后将其部署到生产就绪的托管环境中。通过使用 SageMaker 可以减少运营开销,加速模型的训练与部署。
- SageMaker Training Job:用于训练机器学习模型的托管环境。它可以自动配置和扩展训练基础设施,支持分布式训练,并提供内置算法和自定义算法支持。
- SageMaker Endpoint:用于部署训练好的模型并提供实时推理服务。它可以自动扩展以处理不同的流量负载,支持 A/B 测试和模型版本管理。
Amazon Polly
Amazon Polly 是一项将文本转换为语音的云服务。它使用深度学习技术来合成自然 sounding 的人类语音,让您能够为应用程序创建会说话的产品和全新的语音控制类别的产品。
Polly 的主要特点包括:
- 文本转语音(NTTS)技术,提供更自然、更富有表现力的语音。
- 支持多种语言和口音。
- 自定义词典功能,可以控制特定单词的发音。
- SSML(Speech Synthesis Markup Language)支持,可以更精细地控制语音输出。
Whisper
Whisper 是由 OpenAI 开发的一个开源的自动语音识别(ASR)系统。它具有以下特点:
- 多语言支持:可以识别和翻译多种语言。
- 鲁棒性:对背景噪音和口音有较强的适应能力。
- 灵活性:可以进行语音转文本、翻译和语言识别任务。
- 开源;可以自由使用和修改,适合研究和商业应用。
Whisper 模型微调
任务概述
本节中,我们会聚焦于一个特殊而具有挑战性的语音识别任务:准确识别中文小说中的人物名字。我们选择了《水浒传》作为研究对象,因为它包含了大量独特的中文人名,这些名字往往由两个或三个字组成,在发音和含义上都具有特殊性。
现有的通用语音识别模型在处理这类专有名词时常常表现不佳。例如,“朱武”可能被错误识别为“猪五”,“郁保四”可能被误认为“玉宝寺”。这些错误不仅影响了文本的准确性,也可能导致对原著内容的误解。我们的目标是通过微调现有的 Whsiper 模型,提高其对《水浒传》中人物名字的识别准确率。这项工作不仅对于古典文学作品的数字化和音频处理具有重要意义,也为解决其他领域中类似的专有名词识别问题提供了可行的方法。
具体的完整 Whisper 微调代码详见 finetune-and-deploy-whisper-models 。
数据准备
为了训练和评估我们的模型,我们需要准备一个专门针对《水浒传》人物对话的数据集。数据准备过程分为以下三个主要步骤:
- 人物选择:我们手工选择了 6 个《水浒传》中具有代表性的人物:史进、杨志、吴用、公孙胜、郁保四和朱武。这些人物在小说中扮演了重要角色。
- 文本生成:利用 Claude 3.5 Sonnet 大语言模型,我们为每个选定的人物生成了 10 条符合其特征的文本。生成过程中,我们要求模型围绕每个角色的特点和在小说中的表现来创作对话,以确保生成的文本能够真实反映角色的语言特征。最终,我们得到了总计 60 条高质量的对话文本。
- 音频合成:为了获得与文本对应的音频数据,我们使用了 Amazon Polly 文本转语音服务。我们选择了具有中文发音能力的语音合成引擎,将所有 60 条文本转换成了对应的音频文件。这些音频文件模拟了真实的人物对话,为我们的语音识别任务提供了理想的训练和测试数据。
- 数据集划分:我们将生成的数据集划分为训练集和测试集。每个角色的 10 条数据中,8 条用于训练,2 条用于测试。因此,我们最终得到了 48 条训练数据和 12 条测试数据。这种划分方式确保了每个角色在训练和测试集中都有合适的表示。
通过这种方法,创建了一个专门针对《水浒传》人物对话的数据集,为后续的模型训练和评估奠定了基础。这个数据集不仅包含了丰富的中文人名,还体现了古典小说的语言特点,为后续实验如何提高特定领域语音识别准确率提供了理想的实验材料。
LoRA 微调
为了提高模型在识别《水浒传》人物名字方面的性能,我们采用了 LoRA(Low-Rank Adaptation)微调技术。LoRA 是一种高效的模型适应方法,它允许我们在保持大部分预训练模型参数不变的情况下,通过添加少量可训练参数来适应特定任务。
模型选择
我们选择了 openai/whisper-large-v3 作为基础模型。Whisper 是一个强大的多语言语音识别模型,在各种语音识别任务中表现出色。通过 LoRA 微调,我们旨在进一步提升其在识别中文古典小说人物名字方面的能力。
超参数设置
我们选择了以下超参数来进行 LoRA 微调:
- 语言:中文(–language zh)
- 学习率:1e-3 。注意,对于大批量数据(如 1k+样本),建议使用较小的学习率,如 1e-5
- 批次大小:2(–batch_size 2)
- 梯度累积步数:8(–gradient_accumulation_steps 8)
- 训练轮数:3(–num_epochs 3)
这些超参数的选择旨在平衡训练效率和模型性能。较小的批次大小和梯度累积步数有助于在有限的计算资源下进行有效训练,而适度的学习率和训练轮数则有助于模型充分学习而不过拟合。
训练过程
我们使用准备好的 48 条训练数据对模型进行了微调。下图展示了训练过程中的损失曲线:
图 1. whisper large v3 模型训练 loss 曲线
从损失曲线可以看出,模型在训练过程中稳定收敛,表明LoRA微调有效地适应了我们的特定任务。
微调后效果
通过对 Whisper 模型进行 LoRA 微调,我们在《水浒传》人物对话识别任务上取得了显著的进步。以下是微调前后效果的详细对比:
1. 整体性能提升
- 微调前词错误率(WER):1184 (27/228)
- 微调后词错误率(WER):0263 (6/228)
微调后,模型的词错误率从 11.84% 大幅下降到 2.63%,降低了约 78%。这表明 LoRA 微调极大地提高了模型在识别《水浒传》人物对话方面的准确性。
2. 错误类型分析
微调前,模型主要存在以下错误:
- 人名识别错误:如“朱武”被识别为“猪五”,“郁保四”被识别为“玉宝寺”
- 近音字错误:如“箭”被识别为“剑”,“良久”被识别为“两久”
- 罕见字错误:如“吴用”中的“吴”被错误识别
微调后,大多数这些错误都得到了纠正。剩余的少量错误主要集中在一些特别罕见或容易混淆的名字上,如“时迁”仍被错误识别为“石谦”。
3. 典型案例分析
以下是几个典型案例的对比:
案例 1:
- 原文:朱武性格内向但对兄弟们都很关心
- 微调前:猪五性格内向但对兄弟们都很关心(2/15 = 13.3% WER)
- 微调后:朱武性格内向但对兄弟们都很关心(0/15 = 0% WER)
案例 2:
- 原文:郁保四站在船头目光如炬地注视着前方的水路
- 微调前:玉宝寺站在船头目光如炬地注视着前方的水路(3/20 = 15% WER)
- 微调后:郁保四站在船头目光如炬地注视着前方的水路(0/20 = 0% WER)
案例 3:
- 原文:战斗中朱武箭无虚发是梁山的神射手
- 微调前:战斗中诸武剑无虚发是梁山的神射手(2/16 = 12.5% WER)
- 微调后:战斗中朱武箭无虚发是梁山的神射手(0/16 = 0% WER)
这些案例清楚地展示了微调后模型在识别人名和特定词汇方面的显著进步。
结论
通过 LoRA 微调,我们成功地提高了 Whisper 模型在识别《水浒传》人物对话方面的能力。模型现在能够更准确地识别中文人名和古典文学中的特定表达,大大减少了之前常见的错误。这不仅提高了转录的准确性,也为处理其他类似的特定领域语音识别任务提供了有效的方法。
尽管如此,仍有少量复杂或罕见的名字存在识别困难,这表明在极具挑战性的专有名词识别方面还有进一步改进的空间。未来的工作可以考虑增加训练数据的多样性或探索更先进的微调技术来解决这些剩余的问题。
模型部署
本章节主要介绍如何使用 SageMaker 来部署 Whisper 模型。在 Amazon SageMaker 上部署 Whisper 模型可以让您使用其托管环境进行实时推理,并利用自动扩缩容等功能。在部署中,可以使用不同方法将模型部署在 SageMaker Endpoint,分别是:
- 使用 TorchServe 作为推理服务框架,部署原模型。
- 使用 Triton 作为推理服务框架,部署使用 Tensorrt-llm 编译后的 Whisper 模型,在我们测试中发现,使用这种部署方式较直接使用 TorchServe 部署有 4-5 倍速度的提升。
TorchServe 部署方式
TorchServe 部署的方式,可以参考代码仓库中已有的部署脚本进行部署,本次不再重复赘述其部署方式。
Triton 部署方式
整体架构
图 2. 使用自定义的容器部署编译后的 whisper 模型在 SageMaker Endpoint 的整体架构图
部署步骤
使用 Triton 方式部署的大致分为以下 4 步,完整代码请移步代码库 triton 部署示例代码,同时代码库中也已提供全自动脚本 prepare_and_deploy.sh,直接执行即可。
1. 执行 ./build_and_push.sh
打包 docker 镜像并上传到 ECR,关键的 Dockerfile 文件如下所示:
FROM nvcr.io/nvidia/tritonserver:24.10-py3
RUN apt update && apt-get install -y ffmpeg
WORKDIR /workspace
COPY requirements.txt /workspace/requirements.txt
COPY serve /workspace/serve
RUN pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com tensorrt-llm==tensorrt-llm==0.15.0.dev2024110500 && \
pip install tritonclient[all] && \
pip install -r requirements.txt && \
chmod +x /workspace/serve
# 定义环境变量
ENV PATH="/workspace:${PATH}"
# 运行serve
ENTRYPOINT []
CMD ["serve"]
2. 执行 python merge_lora.py --model-id "$HUGGING_FACE_MODEL_ID" --lora-path "$LORA_PATH" --export-to "$OUTPUT_MODEL_PATH"
命令合并 lora 微调后的模型,再导出为 pt 格式。注意,这里需要保持合并导出后的模型与 OpenAI 官方提供 large-v3.pt 的 state_dict 的名称格式一致,关键的替换函数如下所示:
def hf_to_whisper_states(text):
text = re.sub('.layers.', '.blocks.', text)
text = re.sub('.self_attn.', '.attn.', text)
text = re.sub('.q_proj.', '.query.', text)
text = re.sub('.k_proj.', '.key.', text)
text = re.sub('.v_proj.', '.value.', text)
text = re.sub('.out_proj.', '.out.', text)
text = re.sub('.fc1.', '.mlp.0.', text)
text = re.sub('.fc2.', '.mlp.2.', text)
text = re.sub('.fc3.', '.mlp.3.', text)
text = re.sub('.fc3.', '.mlp.3.', text)
text = re.sub('.encoder_attn.', '.cross_attn.', text)
text = re.sub('.cross_attn.ln.', '.cross_attn_ln.', text)
text = re.sub('.embed_positions.weight', '.positional_embedding', text)
text = re.sub('.embed_tokens.', '.token_embedding.', text)
text = re.sub('model.', '', text)
text = re.sub('attn.layer_norm.', 'attn_ln.', text)
text = re.sub('.final_layer_norm.', '.mlp_ln.', text)
text = re.sub('encoder.layer_norm.', 'encoder.ln_post.', text)
text = re.sub('decoder.layer_norm.', 'decoder.ln.', text)
text = re.sub('proj_out.weight', 'decoder.token_embedding.weight', text)
return text
3. 使用打包的镜像对原始/微调后的模型使用 TensorRT-LLM 进行编译,并将编译好的模型打包成模型文件上传到 S3。
- 执行以下 docker 命令
PROJECT_ROOT=/home/ec2-user/SageMaker/ # 以 SageMaker Notebook 目录为例 docker run --rm -it --net host --shm-size=2g --gpus all \ -v "$PROJECT_ROOT/sagemaker_triton/:/workspace/" \ $DOCKER_IMAGE bash -c "cd /workspace && bash export_model.sh"
- docker 运行的
export_model.sh
文件,定义了从模型转换到 TensorRT 编译的指令,详细参数如下python3 convert_checkpoint.py \ --model_name $model_name \ --output_dir $checkpoint_dir # Build the large-v3 trtllm engines trtllm-build --checkpoint_dir ${checkpoint_dir}/encoder \ --output_dir ${output_dir}/encoder \ --paged_kv_cache disable \ --moe_plugin disable \ --enable_xqa disable \ --max_batch_size ${MAX_BATCH_SIZE} \ --gemm_plugin disable \ --bert_attention_plugin ${INFERENCE_PRECISION} \ --remove_input_padding disable \ --max_input_len 3000 --max_seq_len 3000 trtllm-build --checkpoint_dir ${checkpoint_dir}/decoder \ --output_dir ${output_dir}/decoder \ --paged_kv_cache disable \ --moe_plugin disable \ --enable_xqa disable \ --max_beam_width ${MAX_BEAM_WIDTH} \ --max_batch_size ${MAX_BATCH_SIZE} \ --max_seq_len 114 \ --max_input_len 14 \ --max_encoder_input_len 3000 \ --gemm_plugin ${INFERENCE_PRECISION} \ --bert_attention_plugin ${INFERENCE_PRECISION} \ --gpt_attention_plugin ${INFERENCE_PRECISION} \ --remove_input_padding disable
4. 按照 triton 部署的方式打包编译后的模型上传到 S3
- Triton 部署模型的文件结构,以 Whisper 模型部署为例
triton_model_whisper/ └── whisper ├── 1 │ ├── fbank.py │ ├── mel_filters.npz │ ├── model.py # Triton Inference Server的Python后端实现,用于部署经TensorRT-LLM优化的Whisper模型 │ ├── multilingual.tiktoken # whisper 模型的 tokenizer 文件 │ ├── tokenizer.py # whisper 模型的 tokenizer 文件 │ ├── whisper_large_v3 # TensorRT-LLM 编译后的模型文件 │ │ ├── decoder │ │ │ ├── config.json │ │ │ └── rank0.engine │ │ └── encoder │ │ ├── config.json │ │ └── rank0.engine │ └── whisper_trtllm.py # 定义了TensorRT-LLM优化后的Whisper模型推理类,包括 encoder 和 decoder 两个部分 └── config.pbtxt # 描述模型的输入输出、版本等信息的配置文件
- 打包模型文件夹上传到 S3
5. 调用 SageMaker SDK 进行部署
模型文件上传到 S3 之后,则可以参考 SageMaker 官方文档 Deploy models for real-time inference 中给出的步骤进行模型的部署。本项目的部署代码已有完整实现,可以参考代码仓库中的 deploy_and_test.ipynb 文件,按照 notebook 中的流程,修改对应的配置执行(deploy_config.sh),即可完成部署。
在本博客创建 SageMaker Endpoint 的过程中,使用的是 byoc (bring your own container) 的方式进行模型的部署,需要主要关注以下几个部分:
- 模型文件,已在第四步打包上传到 S3
- 模型部署文件,即 sagemaker-triton/model_data 下的所有文件,会在执行部署过程中,同样会打包上传到 S3(详见ipynb 文件),本文重点解释以下几个文件:
start_triton_and_client.sh
SageMaker Endpoint 创建时启动模型部署的入口文件,首先会启动 triton server 部署编译后的模型。run_server.py
,whisper_api.py
这两个文件主要是用 FastAPI 和 Triton Client 实现在部署节点内部使用 triton gRPC 端口进行模型的推理,以及实现 SageMaker Endpoint 上部署模型必须的 ping 接口和 invocation 接口(SageMaker 官方文档)。deploy_config.sh
定义了模型文件具体的 S3 路径。
注意,如果是全手动执行的整个过程,而不是执行(prepare_an_deploy.sh 完成的整个模型编译上传的过程)需要在部署之前,手动修改 deploy_config.sh 中的 S3 路径为实际上传后的路径。如下所示:
- 部署使用的镜像,执行
build_and_push.sh
这个脚本以后,路径会是以下格式:{account_id}.dkr.ecr.{region}.amazonaws.com/{REPO_NAME}:latest
6. 测试调用,完整代码参考 deploy_and_test.ipynb 文件
import boto3
import json
import base64
import os
import io
from pydub import AudioSegment
endpoint_name = endpoint_name
def encode_audio(audio_file_path):
# 加载音频文件
audio = AudioSegment.from_wav(audio_file_path)
# 检查是否为双通道
if audio.channels == 2:
print("检测到双通道音频,正在转换为单通道...")
# 将双通道转换为单通道
audio = audio.set_channels(1)
# 将音频数据写入内存缓冲区
buffer = io.BytesIO()
audio.export(buffer, format="wav")
buffer.seek(0)
# 将缓冲区的内容编码为 base64
return base64.b64encode(buffer.read()).decode('utf-8')
def invoke_sagemaker_endpoint(runtime_client, endpoint_name, audio_data, whisper_prompt=""):
"""Invoke SageMaker endpoint with audio data"""
payload = {
"whisper_prompt": whisper_prompt,
"audio_data": audio_data
}
response = runtime_client.invoke_endpoint(
EndpointName=endpoint_name,
ContentType='application/json',
Body=json.dumps(payload)
)
result = json.loads(response['Body'].read().decode())
return result
def transcribe_audio(audio_path, endpoint_name, whisper_prompt=""):
# Read and encode the audio file
print("Reading and encoding audio file...")
audio_data = encode_audio(audio_path)
# Create a SageMaker runtime client
runtime_client = boto3.client('sagemaker-runtime')
# Invoke the SageMaker endpoint
print(f"Invoking SageMaker endpoint: {endpoint_name}")
result = invoke_sagemaker_endpoint(
runtime_client,
endpoint_name,
audio_data,
whisper_prompt
)
return result
# Example usage
if __name__ == "__main__":
# Set your parameters here
audio_path = "./your-audio.wav"
whisper_prompt = "" # Optional: add a prompt if needed, the defualt is <|startoftranscript|><|en|><|transcribe|><|notimestamps|>
# Call the function
result = transcribe_audio(audio_path, endpoint_name,whisper_prompt)
# Print the result
print("Transcription result:")
print(result)
增加推理参数
本次的 Whisper 部署中,增加重复惩罚(repetition_penalty)参数,提供在推理过程对重复控制的入口。
如果需要增加其他推理参数,则需修改对应的 triton 模型配置文件 config.pbtxt
, triton 模型文件中的 model.py
,whisper_trtllm.py
,以及 SageMaker 模型部署文件中的 whisper_api.py
。
部署参数调优
- dynamic_batching 参数的设置:Dynamic Batching 是 Triton 提供的一个重要特性,它可以自动将多个请求合并成一个批次进行处理,从而提高 GPU 利用率和整体吞吐量。
- max_queue_delay_microseconds
- 建议值范围: 100 至 infrence_time/2 微秒
- 较小的值(如 100)适合对延迟敏感的场景,较大的值可以提高吞吐量但会增加延迟
- 实践中发现设置为 1000/5000 是一个较好的平衡点
- Uvicorn worker 数量的设置直接影响服务的并发处理能力和资源利用率。设置的数量可以参考公示:(2 x CPU核心数) + 1。一般来讲,IO 密集型任务可以设置更多 worker,而 CPU 密集型任务建议减少 worker 数量,实际配置过程中可以通过压测来寻找最优值。更多详细设置方式可以参考 gunicorn 的说明。
结论
本博客的方法通过 Tensorrt-llm 编译模型,使用 Tritonserver 部署模型的方式可以得到较为理想的推理时延和吞吐,在未来可以使用 Tritonserver C++ 后端以及使用 inflight-batch 的实现进一步减小时延和吞吐。
总结与展望
本文详细介绍了如何利用 AWS 云服务和开源模型构建一个专门用于识别《水浒传》人物对话的语音识别系统。通过实践,我们取得了以下主要成果:
1. 技术成果
- 成功将 Whisper 模型的词错误率(WER)从 84% 降低到了 2.63%,提升了约 78%。
- 实现了基于 TorchServe 和 Triton 的两种部署方案,为不同场景提供灵活选择。
- 通过参数优化和配置调整,确保了模型在生产环境中的稳定性能。
2. 方法论贡献
- 提供了一个完整的工作流程,从数据准备到模型部署。
- 展示了如何利用 LoRA 技术进行高效模型微调。
- 提出了针对中文古典文学作品的特定优化方案。
3. 未来应用扩展
本项目的方法论和技术框架不仅适用于《水浒传》,还可以扩展到其他领域:
- 其他古典文学作品的语音识别
- 专业领域术语的准确识别
- 多方言环境下的语音处理
参考链接
https://docs.thinkwithwp.com/sagemaker/latest/dg/your-algorithms-inference-code.html
https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/whisper
https://github.com/triton-inference-server