亚马逊AWS官方博客

基于 Claude 3 和 WhisperX 构建 ASR 方案(二)

1. 前言

在《基于 Claude 3 和 WhisperX 实现 ASR 方案(一)》中我们介绍了 WhisperX 模型的实现原理,以及如何在 AWS 上快速部署和使用 WhisperX 模型,实现语音转文字,视频字幕生成与对齐,识别不同的说话人等功能。通过 StreamlitUI 的方式,我们可以快速对 YouTube 视频、本地音频文件实现 ASR,方便我们做技术调研和 Demo 演示。但是在实际的生产环境应用中,我们会有大量的视频或音频文件生成,我们需要对这些离线音视频做批量的转录,除此之外,我们也需要对转录之后的文本内容进行总结,以方便我们快速提取重要信息,实现内容审核等。本篇我们将继续介绍在生产环境中,我们如何利用 WhisperX 实现无侵入性、自动化、可扩展的 ASR 方案,并使用 Amazon Bedrock 对 ASR 转录之后的内容实现自动总结和审计。

2. 架构介绍

整体架构图:

2.1 架构介绍

本方案采用多层架构,可以平滑接入现有业务应用系统,上传音视频文件到 S3 后自动实现转录,结合 ASG 和 SQS 实现自动扩容和批量处理,配合 API 接口查询转录结果,并同时提供 StreamlitUI 的交互式界面上传文件。下面是各个组件的说明。

自动转录流程:应用端将音视频文件上传到 S3 之后,触发 Lambda 获取上传事件,Lambda 判断文件是否符合音视频文件格式,并构造消息,写入到 SQS 队列中。AutoScaling Group 中启动 GPU 服务器,部署 WhisperX 的推理服务,从 SQS 队列中接收消息实现 ASR 转录,将转录结果回传到 S3,并调用 Amazon Bedrock 对转录后的文本进行总结和审核,将结果同样传回到 S3 中。当大量文件上传到 S3 时,SQS 队列阻塞,Auto Scaling 根据消息长度自动扩容集群,实现快速批量处理。

API 查询转录结果:由于转录为异步处理流程,API 服务将提供 HTTP 接口用于实时查询转录结果。AutoScaling Group 中的 EC2 同时部署了 FastAPI 服务,并通过 Application LoadBalancer 实现负载均衡。应用服务器上传音视频文件后通过 API 接口可以随时获取转录结果。

StreamUI 交互式界面:AutoScaling Group 中的 EC2  同时也通过 Streamlit 部署了交互式的 UI 服务,为了优化交互体验,通过 CloudFront 进行网络加速。对于某些需要临时处理的少量文件,可以通过 UI 界面进行交互式的转录和总结。

3. 部署项目

本项目使用 CloudFormation 实现一键部署,自动创建所需的 VPC 网络环境、S3 存储桶、Lambda 函数、SQS 队列、Auto Scaling Group、EC2、Application LoadBalancer、CloudFront 等资源。本文源码已经发布在 https://github.com/superyhee/whisper-on-aws-jumpstart

安装指南:

  • 通过控制台创建 EC2 Key Pair,用于登录服务器
  • 在控制台进入 CloudFormation 服务,点击创建堆栈,选择上传模板文件,从代码目录中选择 whisper-prod.yaml,点击下一步
  • 设置参数,在参数列表中我们需要填写以下堆栈名称、存储桶名称、huggingface token、密钥对名称,其他参数为可选项,机型默认使用 xlarge。参数填完后直接下一步到创建堆栈即可
  • 堆栈创建完成后,在输出 tab 页可以看到用于访问 Streamlit 的 CloudFront 地址和用于 API 结果检查的 ALB 地址

4. 项目使用指南

4.1 UI 使用指南

在浏览器中打开 CloudFormation 输出的 CloudFrontDomainName,可以访问到 Streamlit 的 UI 界面

在《基于 Claude 3 和 WhisperX 实现 ASR 方案(一)》中我们重点介绍了 UI 的使用方式,这里不再赘述,我们将重点介绍基于 S3 事件的自动处理流程。

4.2  S3 自动处理流程使用指南

在 AWS 控制台进入 S3 服务,从列表中找到我们设置的 S3 存储桶,点击上传音视频文件

在 Properties 中设置标签,标签名称 summary 表示是否生成总结,audit 表示是否进行内容审核

文件上传后将自动实现转录,总结和审核,处理完成之后会在 S3 存储桶中生成结果文件,转录文件为 .txt 文件,总结文件为 .json 文件,审核文件为 -audit.json 文件,如下图

在该流程中,我们为了尽量减少对现有业务逻辑的侵入,通过 S3 文件标签的方式将参数传入到处理逻辑中。在应用中,我们可以通过 SDK 在上传文件的时候指定标签参数,对于 Python SDK,示例如下:

import boto3
def upload_file_to_s3(file_path, bucket_name, object_key):

    s3_client = boto3.client('s3')
    try:
        s3_client.upload_file(file_path, bucket_name, object_key)
        tags = {
            'TagSet': [
                {
                    'Key': 'summary',
                    'Value': 'true'
                },
                {
                    'Key': 'audit',
                    'Value': 'true'
                }
            ]
        }
        s3_client.put_object_tagging(Bucket=bucket_name, Key=object_key, Tagging=tags)
        
        return f"文件 {file_path} 已成功上传到 S3 存储桶 {bucket_name}/{object_key}"

4.3 API 使用指南

文件上传完成之后,我们可以通过以下 API 的方式查询文件是否转录成功,其中域名为 CloudFormation 的输出中负载均衡器的地址:

curl http://whispe-xxxx.elb.us-west-2.amazonaws.com:9000/check_files?file_path=s3://whisper-bucket-xxx/test.mp4

5. 项目代码说明

Lambda 主要用于接收 S3 上传事件,并将视频和音频文件的上传消息传递到 SQS 中,目前支持的文件格式为 mp3,mp4,m4a,如果需要支持更多格式,可以进行扩展,代码如下:

import json
import boto3
import urllib.parse
import os

s3 = boto3.client('s3')
sqs = boto3.client('sqs')

def lambda_handler(event, context):
    # 获取S3事件记录
    bucket_name = event['Records'][0]['s3']['bucket']['name']
    file_key = urllib.parse.unquote_plus(event['Records'][0]['s3']['object']['key'], encoding='utf-8')

    queue_url = os.environ['QUEUE_URL']

    try:
        # 判断文件后缀,如果文件不是以 .mp3 .mp4 .m4a 结尾的,则直接return
        file_extension = file_key.split('.')[-1].lower()
        allowed_extensions = ['mp3', 'mp4', 'm4a']
        if file_extension not in allowed_extensions:
            return {
            'statusCode': 200,
            'body': json.dumps('File {file_key} is not an audio or video file. Skipping...')
        }

        # 获取文件标签
        response = s3.get_object_tagging(
            Bucket=bucket_name,
            Key=file_key
        )
        tags = {tag['Key']: tag['Value'] for tag in response['TagSet']}

        # 构建SQS消息体
        message_body = {
            'bucket': bucket_name,
            'key': file_key,
            'tags': tags
        }

        # 发送SQS消息
        sqs.send_message(
            QueueUrl=queue_url,
            MessageBody=json.dumps(message_body)
        )
    except Exception as e:
        print(e)
        raise e

whisper_sqs_message_processor.py 是该项目核心服务,主要从 SQS 中接收消息,调用 WhisperX 进行 ASR 转录,并根据文件上传时候携带的标签参数调用 Amazon Bedrock 进行总结和审计,SQSMessageProcessor 封装了 SQS 消息接收和处理逻辑的抽象类,具体代码如下:

import json
import boto3
import os
import logging
import whisperx_transcribe
from subprocess import call
from sqs_message_processor import SQSMessageProcessor
from bedrock_handler.summary_bedrock_handler import SummaryBedrockHandler
from bedrock_handler.audit_bedrock_handler import AuditBedrockHandler
from dotenv import load_dotenv
load_dotenv()

class WhisperSQSMessageProcessor(SQSMessageProcessor): 
    def __init__(self, queue_url, max_number_of_messages=20, wait_time_seconds=10):
        super().__init__(queue_url, max_number_of_messages, wait_time_seconds)
        self.s3 = boto3.client('s3',region_name=self.region)
        self.bedrock_runtime = boto3.client(service_name='bedrock-runtime',region_name=self.region)
        self.audio_extensions = ['.wav', '.mp3', '.m4a']
        self.video_extensions = ['.mp4', '.avi', '.mkv', '.mov']
        
    def download_file(self, bucket_name, object_key):
        self.s3.download_file(bucket_name, object_key, f'/tmp/{object_key}')
        self.logger.info("Downloaded file from S3: s3://%s/%s",bucket_name,object_key)

    def convert_to_audio(self, video_file):
        self.logger.info("Converting the video_file %s to audio_file",video_file)
        audio_file = os.path.splitext(video_file)[0] + '.wav'
        call(['ffmpeg', '-i', video_file, '-vn', '-ar', '16000', '-ac', '1', '-f', 'wav', audio_file])
        self.logger.info(f"The audio file is : %s",audio_file)
    
    def get_tag_value(self, tags, key):
        value = tags.get(key)
        if value:
            return value
        else:
            return ""

    def llm_summary(self,transcription_text):
        llm = SummaryBedrockHandler(region=self.region,content=transcription_text)
        response_body = llm.invoke()
        self.logger.info("The summary info is : %s",response_body)
        return response_body

    def llm_audit(self,transcription_text):
        llm = AuditBedrockHandler(region=self.region,content=transcription_text)
        response_body = llm.invoke()
        self.logger.info("The audit info is : %s",response_body)
        return response_body

    def transcribe(self, audio_file, message_body,message_receipt_handle):
        self.logger.info(f"Transcribing the audio file : %s",audio_file)
        file_size_mb = os.path.getsize(audio_file) / (1024 * 1024)
        visibility_timeout = int(file_size_mb * 5 * 10)  # 每MB文件10秒
        self.change_message_visibility(message_receipt_handle, visibility_timeout)

        # 执行transcribe操作
        model_size = self.get_tag_value(message_body.get('tags', {}),"model_size")
        if model_size == "":
            model_size = "medium"
        self.logger.info("Use the model size:{0}".format(model_size))
        transcription_text = whisperx_transcribe.transcribe(audio_file, model_size)
        self.logger.info("The transcription text:{0}".format(transcription_text))
        # 判断transcription是否为空,如果不为空,则将transcription内容以json文件的形式上传到S3中
        if transcription_text:
            bucket_name = message_body['bucket']
            object_key = message_body['key']
            transcription_key = f"{os.path.splitext(object_key)[0]}.json"
            self.s3.put_object(Body=json.dumps(transcription_text,ensure_ascii=False), Bucket=bucket_name, Key=transcription_key)
            self.logger.info("Uploaded transcription text to s3://%s/%s",bucket_name,transcription_key)

        # 判断message中的tags,如果tags中存在summary的tag,则调用llm_summary方法对transcription内容进行总结
        if 'summary' in message_body.get('tags', []):
            summary = self.llm_summary(transcription_text)
            summary_key = f"{os.path.splitext(object_key)[0]}.txt"
            self.s3.put_object(Body=summary.encode('utf-8'), Bucket=bucket_name, Key=summary_key)
            self.logger.info("Uploaded summary to s3://%s/%s",bucket_name,summary_key)
        # 判断message中的tags,如果tags中存在audit的tag,则调用llm_audit方法对transcription内容进行总结
        if 'audit' in message_body.get('tags', []):
            audit = self.llm_audit(transcription_text)
            audit_key = f"{os.path.splitext(object_key)[0]}-audit.txt"
            self.s3.put_object(Body=summary.encode('utf-8'), Bucket=bucket_name, Key=audit_key)
            self.logger.info("Uploaded audit file to s3://%s/%s",bucket_name,audit_key)
    
    # 实现抽象方法,处理业务逻辑
    def process_message(self, message):
        #1. get info from the message
        self.logger.info("Handle the message begin ...")
        message_body = json.loads(message['Body'])
        bucket_name = message_body['bucket']
        object_key = message_body['key']
        message_receipt_handle = message['ReceiptHandle']

        #2.download the file from s3
        self.download_file(bucket_name, object_key)

        #3.检查文件类型和文件大小
        file_extension = os.path.splitext(object_key)[1].lower()
        if file_extension in self.audio_extensions:
            self.transcribe(f'/tmp/{object_key}',message_body,message_receipt_handle)
        elif file_extension in self.video_extensions:
            self.convert_to_audio(f'/tmp/{object_key}')
            self.transcribe(f'/tmp/{os.path.splitext(object_key)[0]}.wav',message_body,message_receipt_handle)
        else:
            self.logger.info("Unsupported file type: %s",file_extension)
if __name__ == '__main__':
    # queue_url = sys.argv[1]
    queue_url = os.environ['SQS_QUEUE_URL']
    processor = WhisperSQSMessageProcessor(queue_url, max_number_of_messages=1, wait_time_seconds=20)
    processor.process()

bedrock_handler.py 封装了对 Bedrock 调用的抽象方法,并在 summary_bedrock_handler.py 和 audit_bedrock_handler.py 中分别定义了 PE,利用大模型对文本内容进行总结和审核:

import json
from bedrock_handler.bedrock_handler import BedrockHandler

class SummaryBedrockHandler(BedrockHandler):
    def __init__(self,region,model_id="anthropic.claude-3-haiku-20240307-v1:0",max_tokens = 1000,content=""):
        super().__init__(region,model_id=model_id,max_tokens = 1000)
        self.content = content

    def prompt(self):
        system_prompt = """
你是一个文案专员,请认真阅读其中的内容<transcription_text>标签中包含的上下文内容,并按照以下要求进行总结
- 识别 <transcription_text> 中的语言种类,用相同语言进行总结和返回
- 理解 <transcription_text> 中的主要情节和场景,用精简的语言总结内容
- 如果 <transcription_text> 中有多个speaker,请分别总结每个人的情感情绪和想要表达的中心思想

以下是上下文:
<transcription_text>
{speak_context}
</transcription_text>
"""
        system_prompt = system_prompt.replace("speak_context",json.dumps(self.content))
        return system_prompt
import json
from bedrock_handler.bedrock_handler import BedrockHandler


class AuditBedrockHandler(BedrockHandler):
    def __init__(self,region,model_id="anthropic.claude-3-haiku-20240307-v1:0",max_tokens = 1000,content=""):
        super().__init__(region, model_id=model_id, max_tokens = 1000)
        self.content = content

    def prompt(self):
        system_prompt = """ 你是一个文本内容审核专家,你的任务是审查<audit_content>中给定的文本,并从以下<tags>标签列表中选择最合适的标签.
在审核过程中,请保持客观、中立的态度,不带任何个人偏见和主观判断。如果选择了2-6任一标签,请在标签后简要说明是谁的内容以及为什么不合规.
如果无法确定内容属性,请选择"无法判断"。请在审核结果前注明"审核结果:"
<tags>
1. 合规
2. 色情内容
3. 暴力内容 
4. 仇恨言论
5. 政治敏感
6. 争议观点
7. 无法判断
</tags>

以下是需要审核的内容:
<audit_content>
{speak_context}
</audit_content>

api.py 文件使用 Fast API 定义了简单的转录结果查询的接口:

from fastapi import FastAPI, HTTPException
import boto3

app = FastAPI()

# Create an S3 client
s3 = boto3.client('s3',region_name='us-east-1')

@app.get("/check_files")
async def check_files(file_path: str):
    # Parse the input file path
    bucket_name, key = file_path.replace("s3://", "").split("/", 1)

    # Define the file extensions to check
    extensions = [".json", ".txt"]

    # Create a list to store the existing file URLs
    result = {
        "status":200,
        "origin_file":file_path,
        "transcribe_file":"None",
        "transcribe_content":"None",
        "summary_file":"None",
        "summary_content":"None"
    }
    # Check if each file exists
    for ext in extensions:
        key_path,_ = key.split(".", -1)
        file_key = key_path+ext
        try:
            s3_object = s3.get_object(Bucket=bucket_name, Key=file_key)
            content = s3_object['Body'].read().decode('utf-8')
            if ext == ".json":
                result["transcribe_file"] = f"s3://{bucket_name}/{file_key}"
                result["transcribe_content"] = content
            elif ext == ".txt":
                result["summary_file"] = f"s3://{bucket_name}/{file_key}"
                result["summary_content"] = content
        except s3.exceptions.ClientError as e:
            if e.response['Error']['Code'] == "404":
                # The object does not exist, continue to the next extension
                print("file not found")
            else:
                result["status"]=500
                result["error"]=str(e)
                return result

    # Return the list of existing file URLs
    return result

在 Auto Scaling 启动 EC2 之后,我们需要通过 Userdata 的方式实现服务自启动,为此需要定义启动脚本 launch.sh,在启动脚本中我们对 S3 添加文件上传事件触发 Lambda 执行,顺序启动 API 服务、Streamlit UI 和 SQS 消费者服务,代码如下:

#!/bin/bash

# Switch to the project directory
cd /home/ubuntu/whisper/ || exit

add_notification_for_s3(){
   local bucket_name=$(read_env_var "BUCKET_NAME")
   local function_name=$(read_env_var "LAMBDA_NAME")
   local region=$(read_env_var "SQS_QUEUE_URL" | awk -F'.' '{print $2}')

   # 获取 Lambda 函数的 ARN
    function_arn=$(aws lambda get-function --function-name "$function_name" --query 'Configuration.FunctionArn' --region $region --output text)

    # 创建事件通知配置
    notification_config=$(cat <<EOF
{
  "LambdaFunctionConfigurations": [
    {
      "LambdaFunctionArn": "$function_arn",
      "Events": [
        "s3:ObjectCreated:*"
      ]
    }
  ]
}
EOF
)
    # 添加事件通知
    aws s3api put-bucket-notification-configuration \
        --bucket "$bucket_name" \
        --notification-configuration "$notification_config"
        
    echo "事件通知已添加到存储桶 $bucket_name,当创建新对象时将触发 Lambda 函数 $function_name"
}

# Function to read a variable from the .env file
read_env_var() {
    local var_name=$1
    local var_value=$(cat .env | grep "^${var_name}=" | awk -F'=' '{print $2}')
    echo "$var_value"
}

# Start the API server if enabled
start_api_server() {
    local start_api=$(read_env_var "LAUNCH_API")
    if [ "$start_api" = "True" ]; then
        echo "Starting the API server"
        nohup uvicorn api:app --host 0.0.0.0 --port 9000 --reload &>/dev/null &
    fi
}

# Start the Streamlit demo if enabled
start_demo() {
    local start_demo=$(read_env_var "LAUNCH_DEMO")
    if [ "$start_demo" = "True" ]; then
        echo "Starting the Streamlit demo"
        nohup streamlit run ui.py --server.headless=true --server.port=8501 --browser.serverAddress=0.0.0.0 &>/dev/null &
    fi
}

# Start the event workflow if enabled
start_event_workflow() {
    local start_event=$(read_env_var "LAUNCH_EVENT")
    if [ "$start_event" = "True" ]; then
        echo "Starting the event workflow"
        nohup python3 whisper_sqs_message_processor.py &>/dev/null &
    fi
}

# Start the required services
add_notification_for_s3
start_api_server
start_demo
start_event_workflow

wait

其余完整代码请参考 :https://github.com/superyhee/whisper-on-aws-jumpstart

6. 总结

本文介绍了一种基于 AWS 云服务、WhisperX 开源语音识别模型和 Claude 3 大型语言模型的自动语音转录(ASR)方案。该方案为语音数据处理提供了完整的端到端解决方案和参考实现。WhisperX 的语音视频与字幕对齐技术为多媒体内容处理带来了革命性变化。它不仅提高了视频内容的可理解性和编辑效率,还为视频制作、教育和娱乐等领域开创了创新可能。随着技术不断进步,未来的语音视频与字幕对齐技术有望变得更加精准、高效和智能。


*前述特定亚马逊云科技生成式人工智能相关的服务仅在亚马逊云科技海外区域可用,亚马逊云科技中国仅为帮助您了解行业前沿技术和发展海外业务选择推介该服务。

本篇作者

严军

亚马逊云科技解决方案架构师,目前主要负责帮客户进行云架构设计和技术咨询,对容器化等技术方向有深入的了解,在云迁移方案设计和实施方面有丰富的经验。

贺杨

亚马逊云科技解决方案架构师,具备 17 年 IT 专业服务经验,工作中担任过研发、开发经理、解决方案架构师等多种角色。在加入亚马逊云科技前,拥有多年外企研发和售前架构经验,在传统企业架构和中间件解决方案有深入的理解和丰富的实践经验。

粟伟

亚马逊云科技资深解决方案架构师,专注游戏行业,开源项目爱好者,致力于云原生应用推广、落地。具有 15 年以上的信息技术行业专业经验,担任过高级软件工程师,系统架构师等职位,在加入 AWS 之前曾就职于 Bea,Oracle,IBM 等公司。