流式推理 vs 训练模式详细对比

2026-01-14 21:18:52 生存指南

文章目录

一、概述

什么是训练模式?

什么是流式推理模式?

二、核心区别总览

快速对比表

工作流程对比图

三、详细对比分析

[1. 状态管理机制](#1. 状态管理机制)

[1.1 训练模式:无状态处理](#1.1 训练模式:无状态处理)

[1.2 流式推理:有状态处理](#1.2 流式推理:有状态处理)

[2. 网络行为差异](#2. 网络行为差异)

[2.1 Layer Dropout机制](#2.1 Layer Dropout机制)

[2.2 RandomCombine机制](#2.2 RandomCombine机制)

[3. 数据处理流程](#3. 数据处理流程)

[3.1 训练模式的完整数据流](#3.1 训练模式的完整数据流)

[3.2 流式推理的完整数据流](#3.2 流式推理的完整数据流)

[4. 适用场景详解](#4. 适用场景详解)

[4.1 训练模式的应用场景](#4.1 训练模式的应用场景)

[4.2 流式推理的应用场景](#4.2 流式推理的应用场景)

四、代码示例

训练模式完整示例

流式推理完整示例

五、性能对比

[1. 吞吐量对比](#1. 吞吐量对比)

训练模式

流式推理

[2. 延迟对比](#2. 延迟对比)

训练模式延迟

流式推理延迟

[3. 内存占用对比](#3. 内存占用对比)

训练模式内存

流式推理内存

[4. 计算效率对比](#4. 计算效率对比)

批量处理效率

Chunk大小影响

六、最佳实践

训练模式最佳实践

[1. Warmup调度](#1. Warmup调度)

[2. Batch Size选择](#2. Batch Size选择)

[3. 梯度累积](#3. 梯度累积)

[4. 混合精度训练](#4. 混合精度训练)

流式推理最佳实践

[1. 状态管理](#1. 状态管理)

[2. Chunk大小选择](#2. Chunk大小选择)

[3. 内存优化](#3. 内存优化)

[4. 多线程/多进程](#4. 多线程/多进程)

七、常见问题

[Q1: 流式推理的结果和训练时不一致?](#Q1: 流式推理的结果和训练时不一致?)

[Q2: 流式推理时chunk边界有断裂感?](#Q2: 流式推理时chunk边界有断裂感?)

[Q3: 多会话时显存不足?](#Q3: 多会话时显存不足?)

[Q4: 如何加速流式推理?](#Q4: 如何加速流式推理?)

八、总结

核心差异

选择建议

关键要点

一、概述

在LSTM-based RNN编码器中,训练模式(Training Mode) 和流式推理模式(Streaming Inference Mode) 是两种完全不同的工作方式。理解它们的区别对于正确使用模型至关重要。

什么是训练模式?

训练模式用于学习模型参数,处理完整的音频序列,通过反向传播优化网络权重。

特点:

批量处理多个完整样本

需要计算梯度

使用随机性(Dropout、RandomCombine)提高泛化能力

高吞吐量,高延迟

什么是流式推理模式?

流式推理模式用于实时应用,将音频流分成小chunk逐段处理,通过维护LSTM状态保持连续性。

特点:

单样本分chunk处理

不需要梯度

无随机性,结果确定

低延迟,适合实时场景

二、核心区别总览

快速对比表

维度

训练模式

流式推理模式

主要目标

学习模型参数

实时输出结果

数据形式

完整序列一次性处理

音频流分chunk逐段处理

批次大小

较大 (5-32+)

通常为1

序列长度

长 (数百到数千帧)

短 (16-32帧/chunk)

状态管理

❌ 不需要

✅ 必须维护LSTM状态

模式标志

model.train()

model.eval()

梯度计算

✅ 需要 (requires_grad=True)

❌ 不需要 (torch.no_grad())

RandomCombine

✅ 启用,随机组合层输出

❌ 禁用,只用最后一层

Layer Dropout

✅ 启用 (alpha可能<1)

❌ 禁用 (alpha=1.0)

Warmup参数

✅ 使用,控制层bypass

❌ 固定为1.0

内存占用

高 (~65MB/batch)

低 (~125KB/chunk)

延迟

高 (秒级)

低 (毫秒级)

吞吐量

高 (50,000帧/秒)

中等 (16,000帧/秒)

GPU利用率

高 (批量并行)

低 (单样本)

确定性

❌ 非确定性

✅ 确定性

工作流程对比图

复制代码

训练模式流程:

┌─────────────────────────────────────────┐

│ 输入: Batch样本 (N, T_long, F) │

│ 例如: (32, 1000, 80) │

└─────────────────────────────────────────┘

┌─────────────────────────────────────────┐

│ 卷积下采样 (4倍) │

│ (32, 1000, 80) → (32, 247, 512) │

└─────────────────────────────────────────┘

┌─────────────────────────────────────────┐

│ 12层LSTM编码器 │

│ - 从零状态开始 │

│ - Layer Dropout: 随机bypass一些层 │

│ - RandomCombine: 随机组合多层输出 │

└─────────────────────────────────────────┘

┌─────────────────────────────────────────┐

│ 输出: (32, 247, 512) │

│ 计算损失 → 反向传播 → 更新参数 │

└─────────────────────────────────────────┘

流式推理流程:

┌─────────────────────────────────────────┐

│ 初始化状态: states = get_init_states() │

└─────────────────────────────────────────┘

┌─────────────────┐

│ 音频流循环 │

└─────────────────┘

┌─────────────────────────────────────────┐

│ 输入: Chunk (1, T_short, F) │

│ 例如: (1, 16, 80) │

└─────────────────────────────────────────┘

┌─────────────────────────────────────────┐

│ 卷积下采样 (4倍) │

│ (1, 16, 80) → (1, 1, 512) │

└─────────────────────────────────────────┘

┌─────────────────────────────────────────┐

│ 12层LSTM编码器 │

│ - 使用前一chunk的状态 │

│ - 无Layer Dropout │

│ - 无RandomCombine,只用最后一层 │

│ - 输出新状态 │

└─────────────────────────────────────────┘

┌─────────────────────────────────────────┐

│ 输出: (1, 1, 512) + new_states │

│ states ← new_states (用于下一chunk) │

└─────────────────────────────────────────┘

(回到音频流循环)

三、详细对比分析

1. 状态管理机制

1.1 训练模式:无状态处理

python

复制代码

# RNN.forward() - 训练模式代码片段

if states is None: # 训练时states为None

# 每个样本从零状态开始,样本间完全独立

x = self.encoder(x, warmup=warmup)[0]

# 返回空状态(仅为满足接口要求)

new_states = (torch.empty(0), torch.empty(0))

原理说明:

LSTM的hidden和cell状态初始化为零向量

每个训练样本是完整的独立utterance

样本之间没有时序关系,可以shuffle

不需要记忆之前的信息

适用场景:

离线训练:每个样本是完整录音

批量评估:处理录音文件集合

不关心样本间的连续性

1.2 流式推理:有状态处理

python

复制代码

# RNN.forward() - 流式推理代码片段

if states is not None: # 流式时必须提供states

# 确保在评估模式

assert not self.training

# 验证状态的形状

assert len(states) == 2

assert states[0].shape == (num_layers, batch_size, d_model)

assert states[1].shape == (num_layers, batch_size, rnn_hidden_size)

# 使用之前的状态处理当前chunk

x, new_states = self.encoder(x, states)

状态内容:

python

复制代码

states = (hidden_states, cell_states)

# hidden_states: (12, 1, 512)

# - 12层,每层的隐藏状态

# - 用于LSTM的输出

#

# cell_states: (12, 1, 1024)

# - 12层,每层的细胞状态

# - LSTM的内部记忆

状态初始化:

python

复制代码

# 第一个chunk开始前

states = model.get_init_states(batch_size=1, device=device)

# 内部实现

def get_init_states(self, batch_size=1, device=torch.device("cpu")):

hidden_states = torch.zeros(

(self.num_encoder_layers, batch_size, self.d_model),

device=device

)

cell_states = torch.zeros(

(self.num_encoder_layers, batch_size, self.rnn_hidden_size),

device=device

)

return (hidden_states, cell_states)

状态传递流程:

python

复制代码

# 流式推理主循环

states = model.get_init_states(batch_size=1, device=device)

for chunk in audio_stream:

# 1. 使用当前states处理chunk

embeddings, lengths, new_states = model(

chunk, chunk_lens, states=states

)

# 2. 更新states,传递给下一个chunk

states = new_states

# 3. 使用embeddings做后续处理

process(embeddings)

为什么需要状态?

保持连续性:音频流是连续的,LSTM需要记住之前的信息

上下文依赖:当前chunk的理解依赖之前的context

避免边界效应:chunk边界不会导致信息丢失

状态管理注意事项:

⚠️ 必须正确传递states,否则每个chunk独立处理

⚠️ 新对话/新音频流需要重置states

⚠️ 多线程场景需要为每个流维护独立的states

2. 网络行为差异

2.1 Layer Dropout机制

训练模式 - 有Layer Dropout:

python

复制代码

# RNNEncoderLayer.forward()

def forward(self, src, states=None, warmup=1.0):

src_orig = src # 保存原始输入

# 计算warmup缩放

warmup_scale = min(0.1 + warmup, 1.0)

if self.training:

# 训练时:随机决定是否bypass该层

if torch.rand(()).item() <= (1.0 - self.layer_dropout):

alpha = warmup_scale # 使用该层

else:

alpha = 0.1 # bypass该层

else:

alpha = 1.0 # 推理时完全使用

# ... LSTM和FeedForward处理 ...

# 应用layer dropout

if alpha != 1.0:

# 混合原始输入和处理后的输出

src = alpha * src + (1 - alpha) * src_orig

return src, new_states

Alpha值的含义:

alpha = 1.0: 完全使用该层的输出

alpha = 0.1: 基本bypass该层(90%使用原始输入)

0.1 < alpha < 1.0: 部分使用该层

Layer Dropout的作用:

渐进式训练:训练初期(warmup小)更频繁bypass层,减少训练难度

正则化:随机bypass增强模型鲁棒性

加速收敛:避免深层网络训练初期梯度问题

Warmup调度示例:

python

复制代码

# 训练循环

total_steps = 100000

warmup_steps = 10000

for step in range(total_steps):

# 前10000步warmup从0增长到1

warmup = min(1.0, step / warmup_steps)

# warmup对layer dropout的影响:

# step=0: warmup=0, warmup_scale=0.1

# step=5000: warmup=0.5, warmup_scale=0.6

# step>=10000: warmup=1.0, warmup_scale=1.0

output = model(x, x_lens, warmup=warmup)

流式推理 - 无Layer Dropout:

python

复制代码

# 推理模式

if self.training:

# 训练逻辑(上面的代码)

else:

alpha = 1.0 # 始终完全使用每一层

# 结果:

# src = 1.0 * src + (1-1.0) * src_orig = src

# 不会混合原始输入,完全使用处理后的输出

为什么推理不用Layer Dropout?

确定性:推理结果需要可复现

最优性能:使用全部层获得最佳效果

无正则化需求:推理不需要防止过拟合

2.2 RandomCombine机制

训练模式 - 启用RandomCombine:

python

复制代码

# RNNEncoder.forward()

def forward(self, src, states=None, warmup=1.0):

output = src

outputs = [] # 存储辅助层输出

# 逐层处理

for i, layer in enumerate(self.layers):

output = layer(output, warmup=warmup)[0]

# 收集辅助层输出

if self.combiner is not None and i in self.aux_layers:

outputs.append(output)

# 训练时:随机组合多层输出

if self.combiner is not None:

output = self.combiner(outputs)

return output, new_states

RandomCombine的实现:

python

复制代码

# RandomCombine.forward()

def forward(self, inputs): # inputs是多层的输出列表

# 推理时:直接返回最后一层

if not self.training:

return inputs[-1]

# 训练时:随机组合

# 例如:inputs = [layer4_out, layer7_out, layer10_out, layer11_out]

# 生成随机权重

weights = self._get_random_weights(...)

# weights: (num_frames, 4),每帧的权重不同

# 加权组合

output = weighted_sum(inputs, weights)

return output

随机权重生成策略:

python

复制代码

# 以pure_prob=0.333的概率:选择单一层(one-hot)

if rand() < 0.333:

# 以final_weight=0.5的概率选择最后一层

if rand() < 0.5:

weights = [0, 0, 0, 1] # 最后一层

else:

weights = [1, 0, 0, 0] # 随机非最后层

# 或 [0, 1, 0, 0], [0, 0, 1, 0]

# 以(1-pure_prob)=0.667的概率:加权组合

else:

# 生成连续权重,给最后一层更高权重

log_weights = randn(4) * stddev

log_weights[3] += final_log_weight

weights = softmax(log_weights)

# 例如: [0.1, 0.2, 0.15, 0.55]

RandomCombine的作用:

类似Iterated Loss:让中间层也参与最终输出

改善梯度流:中间层获得更直接的监督信号

提高鲁棒性:测试时只用最后一层也能工作

辅助层配置示例:

python

复制代码

# 12层网络,aux_layer_period=3

aux_layers = list(range(12//3, 12-1, 3))

# aux_layers = [4, 7, 10]

# 加上最后一层: [4, 7, 10, 11]

# RandomCombine会随机组合这4层的输出

流式推理 - 禁用RandomCombine:

python

复制代码

# RandomCombine.forward()

def forward(self, inputs):

if not self.training:

# 推理时:只返回最后一层

return inputs[-1]

# (训练逻辑被跳过)

为什么推理只用最后一层?

效率:不需要计算随机权重

最优性能:最后一层通常表现最好

确定性:避免随机性

3. 数据处理流程

3.1 训练模式的完整数据流

数据准备:

python

复制代码

# DataLoader批次

batch = {

'features': torch.randn(32, 1000, 80), # 32个样本,最长1000帧

'feature_lens': torch.tensor([1000, 980, 950, ..., 600]), # 实际长度

'targets': ..., # 目标标签

}

# 特点:

# 1. 批量处理:32个样本并行

# 2. 变长序列:使用padding统一长度

# 3. 完整utterance:每个样本是完整的录音

前向传播:

python

复制代码

# 设置训练模式

model.train()

# 准备数据

x = batch['features'] # (32, 1000, 80)

x_lens = batch['feature_lens'] # (32,)

targets = batch['targets']

# 前向传播

with torch.enable_grad(): # 需要梯度

embeddings, lengths, _ = model(

x,

x_lens,

states=None, # 不传递状态

warmup=current_warmup # 当前warmup值

)

# embeddings: (32, 247, 512)

# lengths: (32,) - [247, 242, 234, ..., 147]

# 计算损失(例如CTC Loss或Transducer Loss)

loss = criterion(embeddings, targets, lengths)

反向传播:

python

复制代码

# 梯度清零

optimizer.zero_grad()

# 反向传播

loss.backward()

# 梯度裁剪(可选)

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)

# 参数更新

optimizer.step()

# 学习率调度(可选)

scheduler.step()

内存占用分析:

python

复制代码

# 前向传播需要保存的张量:

# 1. 输入: (32, 1000, 80) = 32 * 1000 * 80 * 4bytes ≈ 10MB

# 2. Conv输出: (32, 247, 512) ≈ 16MB

# 3. 每层LSTM输出: (247, 32, 512) ≈ 16MB * 12层 = 192MB

# 4. 梯度: 约等于参数量(96.5M参数 * 4bytes ≈ 386MB)

#

# 总计:约 600MB (单个batch)

# 实际GPU显存占用:1-2GB(包括优化器状态等)

3.2 流式推理的完整数据流

初始化:

python

复制代码

# 设置评估模式

model.eval()

# 移动到设备

device = torch.device('cuda')

model = model.to(device)

# 初始化状态

states = model.get_init_states(batch_size=1, device=device)

# states[0]: (12, 1, 512) - hidden states

# states[1]: (12, 1, 1024) - cell states

音频流处理:

python

复制代码

# 模拟音频流(实际应用中从麦克风/网络获取)

def audio_stream_generator(audio_file, chunk_size=16):

"""

从音频文件生成chunk流

Args:

audio_file: 音频文件路径

chunk_size: 每个chunk的帧数

Yields:

chunk: (1, chunk_size, 80)

"""

# 加载音频

features = load_audio_features(audio_file) # (T, 80)

# 分chunk

for i in range(0, len(features), chunk_size):

chunk = features[i:i+chunk_size]

# 填充到chunk_size(最后一个chunk可能不足)

if len(chunk) < chunk_size:

chunk = F.pad(chunk, (0, 0, 0, chunk_size - len(chunk)))

# 添加batch维度

chunk = chunk.unsqueeze(0) # (1, chunk_size, 80)

yield chunk, min(chunk_size, len(features) - i)

# 主处理循环

all_embeddings = []

with torch.no_grad(): # 推理不需要梯度

for chunk, chunk_len in audio_stream_generator(audio_file):

# 移动到设备

chunk = chunk.to(device)

chunk_lens = torch.tensor([chunk_len], device=device)

# 处理当前chunk

embeddings, lengths, new_states = model(

chunk, # (1, 16, 80)

chunk_lens, # (1,)

states=states, # 使用上一chunk的状态

warmup=1.0 # 推理不使用warmup

)

# 保存结果

all_embeddings.append(embeddings)

# 更新状态

states = new_states

# 实时处理(例如关键词检测)

if detect_keyword(embeddings):

print("检测到关键词!")

# 拼接所有输出

final_embeddings = torch.cat(all_embeddings, dim=1)

内存占用分析:

python

复制代码

# 流式推理需要保存的张量:

# 1. 当前chunk: (1, 16, 80) ≈ 5KB

# 2. Conv输出: (1, 1, 512) ≈ 2KB

# 3. 每层输出: (1, 1, 512) ≈ 2KB * 12层 = 24KB

# 4. LSTM状态: (12, 1, 1024) * 2 ≈ 96KB

#

# 总计:约 127KB (单个chunk)

# 实际GPU显存占用:模型参数(~386MB) + 运行时(~1MB) ≈ 400MB

延迟分析:

python

复制代码

# 假设音频采样率16kHz,帧率100Hz(10ms per frame)

chunk_size = 16 # 帧

# 音频延迟

audio_latency = chunk_size * 10ms = 160ms

# 计算延迟(GPU推理)

compute_latency ≈ 5-10ms

# 总延迟

total_latency = 160ms + 10ms = 170ms

# 实时因子 (RTF)

RTF = compute_latency / audio_latency = 10ms / 160ms ≈ 0.06

# 结论:可以实时处理(RTF < 1)

4. 适用场景详解

4.1 训练模式的应用场景

✅ 场景1:模型训练

python

复制代码

# 离线训练脚本

import torch

from torch.utils.data import DataLoader

from lstm import RNN

# 数据集

train_dataset = AudioDataset(data_dir='train')

train_loader = DataLoader(

train_dataset,

batch_size=32,

shuffle=True, # 打乱样本顺序

num_workers=4,

collate_fn=collate_fn # 处理变长序列

)

# 模型

model = RNN(num_features=80, d_model=512, num_encoder_layers=12)

model.train()

# 训练循环

for epoch in range(num_epochs):

for batch in train_loader:

x, x_lens, targets = batch

# 前向传播

embeddings, lengths, _ = model(x, x_lens, warmup=epoch/100)

# 计算损失

loss = criterion(embeddings, targets, lengths)

# 反向传播

optimizer.zero_grad()

loss.backward()

optimizer.step()

适用条件:

✅ 有大量标注数据

✅ 有GPU资源

✅ 可以批量处理

✅ 无实时要求

✅ 场景2:离线批量评估

python

复制代码

# 批量评估脚本

model.eval()

test_loader = DataLoader(test_dataset, batch_size=16)

all_predictions = []

with torch.no_grad():

for batch in test_loader:

x, x_lens = batch

embeddings, lengths, _ = model(x, x_lens)

# 后续处理(如解码)

predictions = decoder(embeddings, lengths)

all_predictions.extend(predictions)

# 计算指标

accuracy = compute_accuracy(all_predictions, ground_truth)

适用条件:

✅ 处理录音文件集合

✅ 无实时要求

✅ 可以批量处理提高效率

✅ 场景3:研究实验

python

复制代码

# 对比不同配置

configs = [

{'num_layers': 6, 'd_model': 256},

{'num_layers': 12, 'd_model': 512},

{'num_layers': 18, 'd_model': 768},

]

for config in configs:

model = RNN(**config)

train_and_evaluate(model)

适用条件:

✅ 需要快速迭代实验

✅ 对比不同超参数

✅ 分析模型行为

4.2 流式推理的应用场景

✅ 场景1:语音助手

python

复制代码

# 智能音箱/手机语音助手

class VoiceAssistant:

def __init__(self):

self.model = load_model()

self.model.eval()

self.states = self.model.get_init_states(1, device)

def process_audio_stream(self):

"""处理实时音频流"""

mic = Microphone()

while True:

# 从麦克风获取chunk(例如160ms音频)

chunk = mic.read_chunk()

# 特征提取

features = extract_features(chunk) # (1, 16, 80)

# 模型推理

with torch.no_grad():

embeddings, _, new_states = self.model(

features,

torch.tensor([16]),

states=self.states

)

# 更新状态

self.states = new_states

# 关键词检测

if keyword_detector(embeddings) == "小爱同学":

self.wake_up()

self.reset_states() # 唤醒后重置

关键要求:

⚡ 低延迟 (< 200ms)

📱 边缘设备(手机、音箱)

🔄 连续处理音频流

💾 内存受限

✅ 场景2:实时字幕系统

python

复制代码

# 视频会议/直播实时字幕

class RealtimeTranscriber:

def __init__(self):

self.encoder = RNN(...)

self.decoder = Decoder(...)

self.states = self.encoder.get_init_states(1, device)

def transcribe_stream(self, audio_stream):

"""实时转录音频流"""

for chunk in audio_stream:

# 编码

embeddings, _, new_states = self.encoder(

chunk, chunk_lens, states=self.states

)

self.states = new_states

# 解码

text = self.decoder(embeddings)

# 实时显示

display_subtitle(text)

yield text

关键要求:

⚡ 实时响应

📺 流媒体场景

🔄 连续输出文本

✅ 场景3:电话客服系统

python

复制代码

# 智能客服语音识别

class CallCenterASR:

def __init__(self):

self.model = RNN(...)

self.sessions = {} # 每个通话维护独立状态

def handle_call(self, call_id, audio_stream):

"""处理电话音频流"""

# 为新通话初始化状态

if call_id not in self.sessions:

self.sessions[call_id] = {

'states': self.model.get_init_states(1, device),

'transcript': []

}

session = self.sessions[call_id]

for chunk in audio_stream:

# 处理音频chunk

embeddings, _, new_states = self.model(

chunk, chunk_lens, states=session['states']

)

# 更新状态

session['states'] = new_states

# 识别文本

text = recognize(embeddings)

session['transcript'].append(text)

# 意图理解

intent = understand_intent(text)

response = generate_response(intent)

yield response

def end_call(self, call_id):

"""通话结束,清理状态"""

del self.sessions[call_id]

关键要求:

📞 多路并发(多个通话同时进行)

💾 每个通话独立状态

⚡ 低延迟响应

✅ 场景4:边缘设备部署

python

复制代码

# 嵌入式设备(如树莓派)

class EdgeKWS:

"""边缘设备关键词识别"""

def __init__(self, model_path):

# 加载量化/压缩的模型

self.model = load_quantized_model(model_path)

self.model.eval()

self.states = self.model.get_init_states(1, 'cpu')

def detect_keyword(self, audio_stream):

"""在边缘设备上运行"""

for chunk in audio_stream:

# CPU推理

with torch.no_grad():

embeddings, _, new_states = self.model(

chunk, chunk_lens, states=self.states

)

self.states = new_states

# 关键词检测

if is_keyword(embeddings):

return True

return False

关键要求:

💾 内存极度受限 (< 100MB)

🔋 功耗受限

🚫 无网络连接(离线工作)

📱 CPU推理

四、代码示例

训练模式完整示例

python

复制代码

"""

完整的训练脚本示例

包含数据加载、训练循环、验证、保存模型等

"""

import torch

import torch.nn as nn

from torch.utils.data import DataLoader

from lstm import RNN

# ============================================================================

# 1. 数据准备

# ============================================================================

class AudioDataset(torch.utils.data.Dataset):

"""音频数据集"""

def __init__(self, data_dir, manifest_file):

self.data = self.load_manifest(manifest_file)

def __len__(self):

return len(self.data)

def __getitem__(self, idx):

# 加载音频特征

features = load_features(self.data[idx]['audio_path']) # (T, 80)

targets = self.data[idx]['targets']

return features, targets

def collate_fn(batch):

"""处理变长序列"""

features_list, targets_list = zip(*batch)

# 获取最大长度

max_len = max(f.size(0) for f in features_list)

batch_size = len(features_list)

# Padding

features_padded = torch.zeros(batch_size, max_len, 80)

feature_lens = torch.zeros(batch_size, dtype=torch.long)

for i, feat in enumerate(features_list):

length = feat.size(0)

features_padded[i, :length] = feat

feature_lens[i] = length

return features_padded, feature_lens, targets_list

# 创建数据加载器

train_dataset = AudioDataset('data/train', 'train.json')

train_loader = DataLoader(

train_dataset,

batch_size=32,

shuffle=True,

num_workers=4,

collate_fn=collate_fn,

pin_memory=True # 加速GPU传输

)

val_dataset = AudioDataset('data/val', 'val.json')

val_loader = DataLoader(

val_dataset,

batch_size=16,

shuffle=False,

collate_fn=collate_fn

)

# ============================================================================

# 2. 模型创建

# ============================================================================

model = RNN(

num_features=80,

subsampling_factor=4,

d_model=512,

dim_feedforward=2048,

rnn_hidden_size=1024,

num_encoder_layers=12,

dropout=0.1,

layer_dropout=0.075,

aux_layer_period=3, # 启用RandomCombine

)

# 移动到GPU

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = model.to(device)

print(f"模型参数量: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")

# ============================================================================

# 3. 优化器和损失函数

# ============================================================================

# 优化器

optimizer = torch.optim.Adam(

model.parameters(),

lr=1e-3,

betas=(0.9, 0.98),

eps=1e-9

)

# 学习率调度器

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(

optimizer,

mode='min',

factor=0.5,

patience=3,

verbose=True

)

# 损失函数(示例:CTC Loss)

criterion = nn.CTCLoss(blank=0, reduction='mean')

# ============================================================================

# 4. 训练函数

# ============================================================================

def train_epoch(model, data_loader, optimizer, criterion, epoch, total_epochs):

"""训练一个epoch"""

model.train()

total_loss = 0

num_batches = len(data_loader)

for batch_idx, (features, feature_lens, targets) in enumerate(data_loader):

# 移动到设备

features = features.to(device)

feature_lens = feature_lens.to(device)

# 计算warmup

# 前10个epoch从0增长到1

warmup = min(1.0, epoch / 10.0)

# 前向传播

embeddings, lengths, _ = model(

features,

feature_lens,

states=None, # 训练不需要状态

warmup=warmup

)

# 准备CTC Loss的输入

# embeddings: (N, T, d_model) -> (T, N, d_model)

log_probs = torch.log_softmax(embeddings.transpose(0, 1), dim=-1)

# 计算损失

loss = criterion(log_probs, targets, lengths, target_lengths)

# 反向传播

optimizer.zero_grad()

loss.backward()

# 梯度裁剪

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)

# 更新参数

optimizer.step()

# 统计

total_loss += loss.item()

# 打印进度

if (batch_idx + 1) % 10 == 0:

avg_loss = total_loss / (batch_idx + 1)

print(f"Epoch [{epoch}/{total_epochs}] "

f"Batch [{batch_idx+1}/{num_batches}] "

f"Loss: {loss.item():.4f} "

f"Avg Loss: {avg_loss:.4f} "

f"Warmup: {warmup:.2f}")

return total_loss / num_batches

# ============================================================================

# 5. 验证函数

# ============================================================================

def validate(model, data_loader, criterion):

"""验证模型"""

model.eval()

total_loss = 0

num_batches = len(data_loader)

with torch.no_grad():

for features, feature_lens, targets in data_loader:

features = features.to(device)

feature_lens = feature_lens.to(device)

# 前向传播(推理模式)

embeddings, lengths, _ = model(

features,

feature_lens,

states=None,

warmup=1.0 # 验证时warmup=1.0

)

# 计算损失

log_probs = torch.log_softmax(embeddings.transpose(0, 1), dim=-1)

loss = criterion(log_probs, targets, lengths, target_lengths)

total_loss += loss.item()

avg_loss = total_loss / num_batches

return avg_loss

# ============================================================================

# 6. 主训练循环

# ============================================================================

def main():

num_epochs = 100

best_val_loss = float('inf')

for epoch in range(1, num_epochs + 1):

print(f"\n{'='*60}")

print(f"Epoch {epoch}/{num_epochs}")

print(f"{'='*60}")

# 训练

train_loss = train_epoch(

model, train_loader, optimizer, criterion, epoch, num_epochs

)

# 验证

val_loss = validate(model, val_loader, criterion)

print(f"\nEpoch {epoch} Summary:")

print(f" Train Loss: {train_loss:.4f}")

print(f" Val Loss: {val_loss:.4f}")

# 学习率调度

scheduler.step(val_loss)

# 保存最佳模型

if val_loss < best_val_loss:

best_val_loss = val_loss

torch.save({

'epoch': epoch,

'model_state_dict': model.state_dict(),

'optimizer_state_dict': optimizer.state_dict(),

'val_loss': val_loss,

}, 'best_model.pt')

print(f" ✓ 保存最佳模型 (val_loss={val_loss:.4f})")

# 定期保存checkpoint

if epoch % 10 == 0:

torch.save({

'epoch': epoch,

'model_state_dict': model.state_dict(),

'optimizer_state_dict': optimizer.state_dict(),

}, f'checkpoint_epoch_{epoch}.pt')

if __name__ == '__main__':

main()

流式推理完整示例

python

复制代码

"""

完整的流式推理脚本示例

包含音频流处理、状态管理、实时关键词检测等

"""

import torch

import numpy as np

from lstm import RNN

# ============================================================================

# 1. 模型加载

# ============================================================================

def load_model(checkpoint_path, device):

"""加载训练好的模型"""

# 创建模型

model = RNN(

num_features=80,

d_model=512,

rnn_hidden_size=1024,

num_encoder_layers=12,

)

# 加载权重

checkpoint = torch.load(checkpoint_path, map_location=device)

model.load_state_dict(checkpoint['model_state_dict'])

# 设置评估模式

model.eval()

model = model.to(device)

print(f"✓ 模型加载成功")

return model

# ============================================================================

# 2. 音频流处理器

# ============================================================================

class AudioStreamProcessor:

"""音频流处理器"""

def __init__(self, model, device, chunk_size=16):

"""

Args:

model: RNN模型

device: 设备(CPU或GPU)

chunk_size: 每个chunk的帧数

"""

self.model = model

self.device = device

self.chunk_size = chunk_size

# 初始化状态

self.reset_states()

# 统计信息

self.total_chunks = 0

self.total_time = 0

def reset_states(self):

"""重置LSTM状态(新对话/新音频流时调用)"""

self.states = self.model.get_init_states(

batch_size=1,

device=self.device

)

print("✓ 状态已重置")

def process_chunk(self, chunk):

"""

处理单个音频chunk

Args:

chunk: 音频特征,形状 (chunk_size, 80) 或 (1, chunk_size, 80)

Returns:

embeddings: 编码后的特征 (1, T', 512)

lengths: 输出长度

"""

# 确保形状正确

if chunk.dim() == 2:

chunk = chunk.unsqueeze(0) # (chunk_size, 80) -> (1, chunk_size, 80)

# 获取实际长度

chunk_len = chunk.size(1)

chunk_lens = torch.tensor([chunk_len], device=self.device)

# 移动到设备

chunk = chunk.to(self.device)

# 推理

import time

start_time = time.time()

with torch.no_grad():

embeddings, lengths, new_states = self.model(

chunk,

chunk_lens,

states=self.states,

warmup=1.0

)

# 更新状态

self.states = new_states

# 统计

elapsed = time.time() - start_time

self.total_chunks += 1

self.total_time += elapsed

return embeddings, lengths

def get_stats(self):

"""获取统计信息"""

avg_time = self.total_time / self.total_chunks if self.total_chunks > 0 else 0

# 计算实时因子

# chunk_size帧 @ 100fps = chunk_size * 10ms

audio_duration = self.chunk_size * 0.01 # 秒

rtf = avg_time / audio_duration if audio_duration > 0 else 0

return {

'total_chunks': self.total_chunks,

'total_time': self.total_time,

'avg_time_per_chunk': avg_time,

'rtf': rtf

}

# ============================================================================

# 3. 音频流生成器

# ============================================================================

def audio_stream_from_file(audio_file, chunk_size=16):

"""

从音频文件生成chunk流(模拟实时流)

Args:

audio_file: 音频文件路径

chunk_size: chunk大小(帧数)

Yields:

chunk: (chunk_size, 80)

"""

# 加载音频特征(假设已经提取好)

# 实际应用中需要实时提取特征

features = np.load(audio_file) # (T, 80)

print(f"音频总长度: {len(features)} 帧 ({len(features)*0.01:.2f} 秒)")

print(f"Chunk大小: {chunk_size} 帧 ({chunk_size*0.01:.2f} 秒)")

print(f"总chunk数: {len(features) // chunk_size}")

print()

# 分chunk

for i in range(0, len(features), chunk_size):

chunk = features[i:i+chunk_size]

# 最后一个chunk可能不足,需要padding

if len(chunk) < chunk_size:

chunk = np.pad(

chunk,

((0, chunk_size - len(chunk)), (0, 0)),

mode='constant'

)

# 转换为tensor

chunk = torch.from_numpy(chunk).float()

yield chunk

# 模拟实时延迟(可选)

# import time

# time.sleep(chunk_size * 0.01)

# ============================================================================

# 4. 关键词检测器(示例)

# ============================================================================

class KeywordDetector:

"""简单的关键词检测器"""

def __init__(self, keywords, threshold=0.5):

self.keywords = keywords

self.threshold = threshold

self.keyword_classifier = self.load_classifier()

def load_classifier(self):

"""加载关键词分类器(示例)"""

# 实际应用中这里是一个分类器网络

# 这里简化为随机检测

return lambda x: np.random.rand() > 0.95

def detect(self, embeddings):

"""

检测关键词

Args:

embeddings: 编码特征 (1, T', 512)

Returns:

detected: 是否检测到关键词

keyword: 检测到的关键词(如果有)

"""

# 简化的检测逻辑

score = self.keyword_classifier(embeddings)

if score:

return True, "小爱同学"

return False, None

# ============================================================================

# 5. 主流式推理流程

# ============================================================================

def main_streaming():

"""主流式推理函数"""

# 设备

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f"使用设备: {device}\n")

# 加载模型

model = load_model('best_model.pt', device)

# 创建处理器

processor = AudioStreamProcessor(

model=model,

device=device,

chunk_size=16

)

# 创建关键词检测器

detector = KeywordDetector(keywords=["小爱同学", "你好"])

# 处理音频流

print("开始处理音频流...")

print("="*60)

audio_file = 'test_audio_features.npy'

all_embeddings = []

for chunk_idx, chunk in enumerate(audio_stream_from_file(audio_file)):

# 处理chunk

embeddings, lengths = processor.process_chunk(chunk)

# 保存结果

all_embeddings.append(embeddings)

# 关键词检测

detected, keyword = detector.detect(embeddings)

# 打印信息

if detected:

print(f"Chunk {chunk_idx:3d}: ✓ 检测到关键词 [{keyword}]")

else:

print(f"Chunk {chunk_idx:3d}: - 处理完成", end='\r')

print("\n" + "="*60)

print("处理完成!\n")

# 打印统计信息

stats = processor.get_stats()

print("统计信息:")

print(f" 总chunk数: {stats['total_chunks']}")

print(f" 总耗时: {stats['total_time']:.3f} 秒")

print(f" 平均每chunk: {stats['avg_time_per_chunk']*1000:.2f} ms")

print(f" 实时因子 (RTF): {stats['rtf']:.3f}")

if stats['rtf'] < 1.0:

print(f" ✓ 可以实时处理 (RTF < 1.0)")

else:

print(f" ✗ 无法实时处理 (RTF >= 1.0)")

# 拼接所有输出

final_embeddings = torch.cat(all_embeddings, dim=1)

print(f"\n最终输出形状: {final_embeddings.shape}")

# ============================================================================

# 6. 多会话管理示例(电话客服场景)

# ============================================================================

class MultiSessionManager:

"""多会话管理器(用于电话客服等场景)"""

def __init__(self, model, device):

self.model = model

self.device = device

self.sessions = {}

def create_session(self, session_id):

"""创建新会话"""

if session_id in self.sessions:

print(f"警告: 会话 {session_id} 已存在")

return

self.sessions[session_id] = {

'states': self.model.get_init_states(1, self.device),

'created_at': time.time(),

'chunk_count': 0

}

print(f"✓ 创建会话: {session_id}")

def process_chunk(self, session_id, chunk):

"""处理指定会话的chunk"""

if session_id not in self.sessions:

raise ValueError(f"会话 {session_id} 不存在")

session = self.sessions[session_id]

# 处理

chunk = chunk.to(self.device)

chunk_lens = torch.tensor([chunk.size(1)], device=self.device)

with torch.no_grad():

embeddings, lengths, new_states = self.model(

chunk,

chunk_lens,

states=session['states'],

warmup=1.0

)

# 更新会话状态

session['states'] = new_states

session['chunk_count'] += 1

return embeddings, lengths

def close_session(self, session_id):

"""关闭会话,释放资源"""

if session_id in self.sessions:

del self.sessions[session_id]

print(f"✓ 关闭会话: {session_id}")

def get_active_sessions(self):

"""获取活跃会话列表"""

return list(self.sessions.keys())

# 使用示例

def demo_multi_session():

device = torch.device('cuda')

model = load_model('best_model.pt', device)

manager = MultiSessionManager(model, device)

# 模拟3个并发通话

call_ids = ['call_001', 'call_002', 'call_003']

# 创建会话

for call_id in call_ids:

manager.create_session(call_id)

# 交替处理各个会话的音频

for i in range(100): # 模拟100个chunk

# 轮流处理各个会话

call_id = call_ids[i % 3]

chunk = torch.randn(1, 16, 80) # 模拟音频chunk

embeddings, lengths = manager.process_chunk(call_id, chunk)

# 后续处理...

# 关闭会话

for call_id in call_ids:

manager.close_session(call_id)

# ============================================================================

# 7. 主入口

# ============================================================================

if __name__ == '__main__':

# 单会话流式推理

main_streaming()

# 多会话示例

# demo_multi_session()

五、性能对比

1. 吞吐量对比

训练模式

配置:

Batch size: 32

Sequence length: 1000帧

GPU: NVIDIA V100

性能指标:

复制代码

处理速度: ~50 utterances/second

吞吐量: 50 * 1000 = 50,000 帧/秒

GPU利用率: 85-95%

显存占用: ~4GB

优势:

✅ 批量并行处理,GPU利用率高

✅ 吞吐量大,适合大规模数据处理

劣势:

❌ 延迟高(必须等待完整序列)

❌ 显存占用大

流式推理

配置:

Batch size: 1

Chunk size: 16帧

GPU: NVIDIA V100

性能指标:

复制代码

处理速度: ~1000 chunks/second

吞吐量: 1000 * 16 = 16,000 帧/秒

GPU利用率: 15-25%

显存占用: ~500MB

优势:

✅ 低延迟(实时处理)

✅ 显存占用小

劣势:

❌ GPU利用率低(单样本)

❌ 吞吐量较小

💡 结论:

训练模式适合离线批量处理

流式推理适合实时单样本处理

2. 延迟对比

训练模式延迟

复制代码

假设音频帧率 = 100 fps (10ms/frame)

序列长度: 1000帧

音频时长: 1000 / 100 = 10秒

处理时间: ~10秒 (取决于GPU性能)

端到端延迟 = 10秒 (必须等待完整序列)

实时因子 (RTF) = 10秒 / 10秒 = 1.0

特点:

延迟 = 整个序列的时长

不适合实时应用

适合离线处理

流式推理延迟

复制代码

Chunk大小: 16帧

音频时长: 16 / 100 = 0.16秒 = 160ms

处理时间: ~10ms (GPU推理)

端到端延迟 = 160ms + 10ms = 170ms

实时因子 (RTF) = 10ms / 160ms = 0.0625

延迟分解:

复制代码

1. 音频采集延迟: 160ms (chunk时长)

2. 特征提取延迟: ~5ms

3. 模型推理延迟: ~10ms

4. 后处理延迟: ~5ms

总延迟: 180ms

💡 结论:

流式推理延迟低(< 200ms)

RTF << 1,可以实时处理

适合实时应用

3. 内存占用对比

训练模式内存

复制代码

GPU显存占用:

1. 模型参数:

- 96.5M参数 × 4 bytes = 386 MB

2. 单个batch:

- 输入: (32, 1000, 80) × 4 bytes ≈ 10 MB

- Conv输出: (32, 247, 512) × 4 bytes ≈ 16 MB

- 12层LSTM输出: (247, 32, 512) × 4 bytes × 12 ≈ 192 MB

3. 梯度:

- 约等于参数量 ≈ 386 MB

4. 优化器状态 (Adam):

- 2倍参数量 ≈ 772 MB

总计: 386 + 218 + 386 + 772 ≈ 1762 MB ≈ 1.7 GB

实际显存占用: 2-4 GB (包括PyTorch overhead)

流式推理内存

复制代码

GPU显存占用:

1. 模型参数:

- 96.5M参数 × 4 bytes = 386 MB

2. 单个chunk:

- 输入: (1, 16, 80) × 4 bytes ≈ 5 KB

- Conv输出: (1, 1, 512) × 4 bytes ≈ 2 KB

- 12层输出: (1, 1, 512) × 4 bytes × 12 ≈ 24 KB

3. LSTM状态:

- Hidden: (12, 1, 512) × 4 bytes ≈ 24 KB

- Cell: (12, 1, 1024) × 4 bytes ≈ 48 KB

总计: 386 + 0.1 ≈ 386 MB

实际显存占用: 400-500 MB

💡 结论:

训练模式显存占用大(~3GB)

流式推理显存占用小(~400MB)

流式推理可以在低端GPU甚至CPU上运行

4. 计算效率对比

批量处理效率

Batch Size

吞吐量 (帧/秒)

GPU利用率

单样本延迟

1

1,000

15%

1s

8

7,500

45%

8s

16

14,000

70%

16s

32

25,000

90%

32s

64

38,000

95%

64s

观察:

Batch size越大,吞吐量越高

但延迟也线性增加

GPU利用率饱和点约在batch_size=32

Chunk大小影响

Chunk Size

延迟

RTF

下采样输出

8

80ms

0.125

可能为0 ⚠️

16

160ms

0.0625

1-2帧 ✓

32

320ms

0.031

3-4帧 ✓

64

640ms

0.016

7-8帧 ✓

建议:

推荐chunk_size=16-32

太小: 下采样后可能为0

太大: 延迟增加

六、最佳实践

训练模式最佳实践

1. Warmup调度

python

复制代码

# ✅ 推荐: 线性warmup

def get_warmup(step, warmup_steps=10000):

return min(1.0, step / warmup_steps)

# 使用

for step in range(total_steps):

warmup = get_warmup(step)

output = model(x, x_lens, warmup=warmup)

# ❌ 不推荐: 固定warmup

warmup = 0.5 # 不随训练变化

2. Batch Size选择

python

复制代码

# ✅ 推荐: 根据GPU显存动态调整

def find_optimal_batch_size(model, device):

batch_size = 64

while batch_size > 1:

try:

x = torch.randn(batch_size, 1000, 80, device=device)

_ = model(x, torch.full((batch_size,), 1000))

return batch_size

except RuntimeError: # OOM

batch_size //= 2

return 1

# ❌ 不推荐: 固定batch size可能OOM或浪费显存

batch_size = 128 # 可能OOM

3. 梯度累积

python

复制代码

# ✅ 推荐: 显存不足时使用梯度累积

accumulation_steps = 4

optimizer.zero_grad()

for i, batch in enumerate(dataloader):

loss = compute_loss(model, batch)

loss = loss / accumulation_steps

loss.backward()

if (i + 1) % accumulation_steps == 0:

optimizer.step()

optimizer.zero_grad()

4. 混合精度训练

python

复制代码

# ✅ 推荐: 使用混合精度加速训练

from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for batch in dataloader:

optimizer.zero_grad()

with autocast():

output = model(x, x_lens)

loss = criterion(output, target)

scaler.scale(loss).backward()

scaler.step(optimizer)

scaler.update()

流式推理最佳实践

1. 状态管理

python

复制代码

# ✅ 推荐: 正确管理状态

class StreamingASR:

def __init__(self, model):

self.model = model

self.states = None

def start_utterance(self):

"""开始新utterance时重置状态"""

self.states = self.model.get_init_states(1, device)

def process_chunk(self, chunk):

if self.states is None:

self.start_utterance()

output, _, new_states = self.model(

chunk, chunk_lens, states=self.states

)

self.states = new_states

return output

def end_utterance(self):

"""结束utterance时清理状态"""

self.states = None

# ❌ 不推荐: 忘记管理状态

def process_stream(chunks):

# 错误: 每个chunk都从零状态开始

for chunk in chunks:

output = model(chunk, chunk_lens, states=None)

2. Chunk大小选择

python

复制代码

# ✅ 推荐: 根据延迟要求选择chunk大小

def choose_chunk_size(latency_requirement_ms, frame_rate_fps=100):

"""

根据延迟要求选择chunk大小

Args:

latency_requirement_ms: 延迟要求(毫秒)

frame_rate_fps: 帧率

Returns:

chunk_size: chunk大小(帧数)

"""

# 考虑下采样因子=4,需要至少9帧输入

min_chunk_size = 16 # 确保下采样后有输出

# 根据延迟计算最大chunk大小

max_chunk_size = int(latency_requirement_ms / (1000 / frame_rate_fps))

# 选择16的倍数(方便硬件优化)

chunk_size = min(max_chunk_size, 64)

chunk_size = max(chunk_size, min_chunk_size)

chunk_size = (chunk_size // 16) * 16

return chunk_size

# 示例

chunk_size = choose_chunk_size(latency_requirement_ms=200)

print(f"Chunk大小: {chunk_size} 帧")

3. 内存优化

python

复制代码

# ✅ 推荐: 流式推理时禁用梯度

model.eval()

for param in model.parameters():

param.requires_grad = False

with torch.no_grad():

for chunk in stream:

output = model.process(chunk)

# ✅ 推荐: 使用inplace操作

torch.backends.cudnn.benchmark = True

# ✅ 推荐: 量化模型(如果精度允许)

quantized_model = torch.quantization.quantize_dynamic(

model, {nn.Linear, nn.LSTM}, dtype=torch.qint8

)

4. 多线程/多进程

python

复制代码

# ✅ 推荐: 音频处理和模型推理分离

import queue

from threading import Thread

def audio_capture_thread(audio_queue):

"""音频采集线程"""

while True:

chunk = capture_audio()

audio_queue.put(chunk)

def inference_thread(audio_queue, result_queue):

"""推理线程"""

states = model.get_init_states(1, device)

while True:

chunk = audio_queue.get()

output, _, new_states = model(chunk, states=states)

states = new_states

result_queue.put(output)

# 启动

audio_q = queue.Queue(maxsize=10)

result_q = queue.Queue(maxsize=10)

Thread(target=audio_capture_thread, args=(audio_q,)).start()

Thread(target=inference_thread, args=(audio_q, result_q)).start()

七、常见问题

Q1: 流式推理的结果和训练时不一致?

原因:

RandomCombine在训练和推理时行为不同

Layer Dropout在训练时有随机性

Dropout层的影响

解决:

python

复制代码

# 确保设置为评估模式

model.eval()

# 或者在训练时也测试流式推理

model.eval()

with torch.no_grad():

# 流式推理测试

...

model.train()

Q2: 流式推理时chunk边界有断裂感?

原因 :

卷积下采样在chunk边界可能损失信息

解决:使用重叠chunk

python

复制代码

# ✅ 使用重叠

chunk_size = 16

overlap = 4 # 重叠4帧

for i in range(0, len(audio), chunk_size - overlap):

chunk = audio[i:i+chunk_size]

output = process(chunk)

# 只使用中间部分,丢弃边界

valid_output = output[:, overlap//2:-overlap//2, :]

Q3: 多会话时显存不足?

解决:

python

复制代码

# 1. 限制并发会话数

MAX_SESSIONS = 100

# 2. 自动清理长时间未活动的会话

def cleanup_inactive_sessions(sessions, timeout=300):

now = time.time()

for sid, session in list(sessions.items()):

if now - session['last_active'] > timeout:

del sessions[sid]

# 3. 使用CPU推理

model = model.cpu()

Q4: 如何加速流式推理?

方法:

模型量化

python

复制代码

quantized_model = torch.quantization.quantize_dynamic(

model, {nn.LSTM, nn.Linear}, dtype=torch.qint8

)

模型剪枝

python

复制代码

# 减少层数

model_small = RNN(num_encoder_layers=6) # 从12减少到6

使用TorchScript

python

复制代码

traced = torch.jit.trace(model, (example_input, example_lens))

traced.save('model_traced.pt')

ONNX导出

python

复制代码

torch.onnx.export(model, (x, x_lens), 'model.onnx')

八、总结

核心差异

特性

训练模式

流式推理

目标

学习参数

实时输出

状态

不需要

必须维护

随机性

延迟

吞吐量

内存

选择建议

使用训练模式:

✅ 模型训练

✅ 离线批量评估

✅ 研究实验

✅ 数据分析

使用流式推理:

✅ 实时应用(语音助手)

✅ 边缘设备

✅ 低延迟要求

✅ 内存受限场景

关键要点

状态管理是流式推理的核心

必须正确维护和传递LSTM状态

新对话需要重置状态

训练和推理的网络行为不同

Layer Dropout只在训练时有效

RandomCombine只在训练时启用

性能权衡

训练模式: 高吞吐、高延迟、高内存

流式推理: 低延迟、低内存、中等吞吐

正确设置模式

训练: model.train()

推理: model.eval() + torch.no_grad()

最新发表
友情链接