量化友好的SR-STE稀疏化训练

N:M稀疏化简介

概念

N:M sparsity是目前在深度网络内常用的一种优化策略,用于削减参数量减轻显存消耗并提高推理速度。其中,N:M代表在M个连续参数内只有有N个参数不为0。目前稀疏可以采用在各种layer内, 包括切不限于GEMM/Liner/Conv等.

与剪枝的区别

稀疏化也可以算作prune的一种,与一般的结构化剪枝例如channel prune不同的是,稀疏化可以看作是一种基于模式而不是基于结构的剪枝(稀疏),其并不会直接裁剪整个通道,而是对完整的参数矩阵进行mask. 以conv层为例, channel prune是沿着卷积层weight(chw)中的C维度(output channel)进行剪裁, 假设原conv kernel weight有128个channel, 剪裁50%后
便仅剩下64个channel. 再不进行微调/重训练的前提下, 重要的channel内的参数被全部保留, 不重要的channel内的参数被全部剔除. 而对于4:2稀疏来说, chww维度的weight会以4为大小分组, 每4个连续的单参数内至多只有两个参数不为0. 这种pattern是均匀地应用在所有channel上的. 借用MIT6.5940内的一张图来表示稀疏和剪枝的关系:

各类不同剪枝方式示意图

事实上, 在视觉模型的工业界, channel prune一般会在训练阶段做. 在训练阶段发现模型规模太大便可以考虑做channel prune. 但prune后需要大规模重训练, 精度掉点也会比较严重, 但channel prune的结果是直接体现在模型结构上的, 不需要硬件特别支持也可以得到性能提升. 可以考虑用NAS(子网搜索)的方式去寻找最合适prune pattern. 稀疏化同样也需要微调, 但一般来说精度损失相对更低, 并且稀疏模式是可以学习的, 相比channel prune其需要硬件的支持, 但一样也能带来很可观的收益.

其实关于稀疏和剪枝的区别, 网上有很多文章可供参考, 有不少文章直接将N:M这类稀疏pattern直接称为剪枝. 但更多的文章会单独把稀疏拉出来当做一种模型优化方法. 在这篇博客里不打算对这些有过多探讨, 读者只需要知道N:M稀疏化的概念即可.

硬件加速原理

N:M 稀疏化的implementation必须要有硬件支持,NVIDIA Ampere/Hopper/Blackwell 架构的GPU支持的稀疏模式为4:2,即每四个连续的参数内只有两个参数不为0,稀疏比率为50%,理论上可以带来接近一倍的显存性能释放。NVIDIA为结构化稀疏提供了高效的内存访问模式,可以完成推理加速,并且4:2稀疏化模式的精度也可以在深度CNN网络中得到良好保障。

NVIDA Apex内的静态掩码稀疏

NV的GPU提供了对稀疏的硬件支持, 自然其需要提供相对应的稀疏化工具库. NVIDIA Apex便是NV提供的一个专门为Pytorch支持的提供训练测优化的工具库. 其中有些代码甚至已经被Pytorch官方认可并加入源码. 参考这篇NV的博客与apex库内的的ASP(Automatic SParsity)简介文档. 使用Apex库对Resnet50进行2:4稀疏的示例代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import copy
import torchvision
from apex.contrib.sparsity import ASP

# Load dense model
resnet50_dense = torchvision.models.resnet50(weights="ResNet50_Weights.IMAGENET1K_V1")

# Initialize sparsity mode before starting sparse training
resnet50_sparse = copy.deepcopy(resnet50_dense)
ASP.prune_trained_model(resnet50_sparse, optimizer)

# Re-train model
for e in range(0, epoch):
for i, (image, target) in enumerate(data_loader):
image, target = image.to(device), target.to(device)
output = resnet50_sparse(image)
loss = criterion(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Save model
torch.save(resnet50_sparse.state_dict(), "sparse_finetuned.pth")

实际上, 执行上述代码时, apex会默认使用所谓的”exhaustive search”方式去寻找每个可稀疏化层的最合适的稀疏化模式(默认为2:4模式). 通过查阅apex的源码, 这个”exhaustive method”最终调用的是sparse_masklib.py内的m4n2_1d(mat, density)方法. 其对应的源码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
""" reshape matrix into m-dimensional vectors: (h,w) -> (hw/m, m) """
def reshape_1d(matrix, m):
# If not a nice multiple of m, fill with zeroes.
if matrix.shape[1] % m > 0:
mat = torch.cuda.FloatTensor(matrix.shape[0], matrix.shape[1] + (m-matrix.shape[1]%m)).fill_(0)
mat[:, :matrix.shape[1]] = matrix
shape = mat.shape
return mat.view(-1,m),shape
else:
return matrix.view(-1,m), matrix.shape

""" return all possible m:n patterns in a 1d vector """
valid_m4n2_1d_patterns = None
def compute_valid_1d_patterns(m,n):
# Early exit if patterns was already created.
global valid_m4n2_1d_patterns

if m==4 and n==2 and valid_m4n2_1d_patterns is not None: return valid_m4n2_1d_patterns
patterns = torch.zeros(m)
patterns[:n] = 1
valid_patterns = torch.tensor(list(set(permutations(patterns.tolist()))))
if m == 4 and n == 2: valid_m4n2_1d_patterns = valid_patterns
return valid_patterns

""" m:n 1d structured best """
def mn_1d_best(matrix, m, n):
# Find all possible patterns.
patterns = compute_valid_1d_patterns(m,n).cuda()

# Find the best m:n pattern (sum of non-masked weights).
mask = torch.cuda.IntTensor(matrix.shape).fill_(1).view(-1,m)
mat,shape = reshape_1d(matrix,m)
pmax = torch.argmax(torch.matmul(mat.abs(),patterns.t()), dim=1)
mask[:] = patterns[pmax[:]]
mask = mask.view(matrix.shape)
return mask

def m4n2_1d(mat, density):
return mn_1d_best(mat, 4, 2)

不难看出, 这段代码的入口函数m4n2_1d主要用途是给定一个参数矩阵与给定密度(一定是50%), 返回其对应的掩码矩阵(mask). 这里的mask是一个参数矩阵对应的稀疏化掩码, 掩码矩阵内为1的位置对应原参数矩阵位置的参数会保留, 为0的位置对应会被置0. 寻找mask矩阵的方式便是列出每四个元素可能组成的2:4 pattern, 随后计算参数矩阵内的每4个元素(的绝对值)在不同pattern下的L1 Norm(即简单相乘相加). 随后用argmax找到每四个元素对应的L1 Norm最大的pattern作为对应位置的最终掩码结果. 当然计算过程是直接使用的矩阵相乘, 示意图如下:

L1 Norm方式寻找最佳mask矩阵

至于为何使用L1 Norm计算掩码矩阵, 在神经网络内理论上参数的绝对值越大(scale越大), 其重要性也越大, 所以用此种方式保留那些绝对值较大的参数. 在得到掩码矩阵后自然便可以开始进行微调使得模型参数自调整去适应该pattern. 在训练过程中此掩码矩阵不会再进行改动. 最终微调得到的稀疏模式和初始搜寻模式时得到的相同. 这种做法在最初寻找pattern时会付出较大的计算代价, 同时在训练时由于固定了pattern, 使得模型参数必须去适应固定的掩码矩阵, 在微调时效果可能会欠佳.

基于直通估计的动态掩码精炼化稀疏(SR-STE)

2021年, Zhou Aj.等人发表了论文Learning N:M Fine-grained Structured Sparse Neural Networks From Scratch. 在其中提到了Sparse-refined straight-through estimator. 基于这篇论文我们可以自己实现各种算子对应的稀疏化版本. 同时由于伪量化算子在做梯度下降时同样也会使用straight-through estimator. 我们进一步可以把稀疏化和量化放在一起实现伪INT8稀疏算子. TensorRT支持诸如卷积、线性层这类常用层的INT8稀疏实现, 对于常见视觉模型, 将稀疏和量化结合起来可以使模型加速数倍不止.

基于直通估计的稀疏化可以用论文内的一张图来表示其原理, 如下图所示. 注意图中的红色虚线部分, 在左侧即普通的直通算子内, mask并不参与梯度计算的过程, 直接将由loss计算得到的梯度return回去(直通). 而图右的SR-STE部分, mask参与gradient计算. 同时, mask还会在正向传播时重新计算.

SR-STE的正向计算与反向梯度传播过程

论文关联的github仓库内给出了SR-STE的forward和backward方法的实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
@staticmethod
def forward(ctx, weight, N, M, decay = 0.0002):
ctx.save_for_backward(weight)

output = weight.clone()
length = weight.numel()
group = int(length/M)

weight_temp = weight.detach().abs().reshape(group, M)
index = torch.argsort(weight_temp, dim=1)[:, :int(M-N)]

w_b = torch.ones(weight_temp.shape, device=weight_temp.device)
w_b = w_b.scatter_(dim=1, index=index, value=0).reshape(weight.shape)
ctx.mask = w_b
ctx.decay = decay

return output*w_b

@staticmethod
def backward(ctx, grad_output):
weight, = ctx.saved_tensors
return grad_output + ctx.decay * (1-ctx.mask) * weight, None, None

可见, 对于掩码矩阵内为0的部分(索引为i), 回传的梯度会加上一个附带的decay * weight_i部分. 同时在forward函数内, mask每次会基于weight的L1 Norm重新计算. 掩码矩阵内为0的参数部分由于得到了额外的梯度所以会更倾向于改变自身的值, 在正向过程中mask又会基于改变后的值重新计算得到. 这样就形成了一个反馈过程, 最终使得mask矩阵和参数都会在训练过程中收敛, 并达到较好的精度.

TL;DR: SR-STE相比普通的STE, 在反向传播中增加了基于掩码矩阵的部分, 使得在训练过程中动态更新掩码矩阵成为了可能. 而相比静态掩码矩阵+普通STE, 动态的掩码矩阵更能适应训练过程时的参数变化, 使得最终得到的稀疏pattern更友好, 在收敛的情况下理论上获得的精度也更高.

SR-STE与INT8量化的合并

如果大家有了解过INT8量化伪算子在pytorch内的实现原理, 可以知道量化算子的反向传播也是基于直通估计得到的, 也就是说所谓的伪量化算子的反向传播方法会直接返回输入的梯度. 如果我们想要让QAT和稀疏微调同时进行, 那便可以考虑自定义对应的量化SR-STE算子. 实际过程并不复杂, 我们只需要参考一下pytorch-quantization内对量化算子的实现, 随后将两者结合起来即可.

废话不说直接上代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import torch
from torch import autograd, nn
import torch.nn.functional as F
from pytorch_quantization.nn import QuantConv2d
from torch.nn.modules.utils import _pair

class Sparse_NHWC(autograd.Function):
"""" Prune the unimprotant edges for the forwards phase but pass the gradient to dense weight using SR-STE in the backwards phase"""

@staticmethod
def forward(ctx, weight, N, M, decay = 0.0002):

ctx.save_for_backward(weight)
output = weight.clone()
length = weight.numel()
group = int(length/M)

weight_temp = weight.detach().abs().permute(0,2,3,1).reshape(group, M)
index = torch.argsort(weight_temp, dim=1)[:, :int(M-N)]

w_b = torch.ones(weight_temp.shape, device=weight_temp.device)
w_b = w_b.scatter_(dim=1, index=index, value=0).reshape(weight.permute(0,2,3,1).shape)
w_b = w_b.permute(0,3,1,2)

ctx.mask = w_b

return output*w_b

@staticmethod
def backward(ctx, grad_output):

weight, = ctx.saved_tensors
return grad_output + ctx.decay * (1-ctx.mask) * weight, None, None


class SparseQuantConv2d(QuantConv2d):
"""
继承于pytorch_quantization内的QuantCov2d
其源码位于pytorch_quantization -> nn -> modules -> quant_conv.py
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode='zeros',
N=2,
M=4,
**kwargs):
self.N = N
self.M = M
super(SparseQuantConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, **kwargs)

def get_sparse_weights(self):
return Sparse_NHWC.apply(self.weight, self.N, self.M)


def forward(self, input):
"""
稀疏化算子和量化Function的forward和backword都已经完成定义
因此只需要将该module的forward定义好即可

注意要保留原有的self.weight
"""


# 获取稀疏化后的weight
sparsed_weight = self.get_sparse_weights()

# 获取量化后的weight
quant_input = self._input_quantizer(input)
quant_weight = self._weight_quantizer(sparsed_weight)

# 直接使用F.conv2d完成conv计算
if self.padding_mode == 'circular':
expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2,
(self.padding[0] + 1) // 2, self.padding[0] // 2)
output = F.conv2d(F.pad(quant_input, expanded_padding, mode='circular'),
quant_weight, self.bias, self.stride,
_pair(0), self.dilation, self.groups)
else:
output = F.conv2d(quant_input, quant_weight, self.bias, self.stride, self.padding, self.dilation,
self.groups)

完成后便可以尝试将一般CNN网络中的conv层替换为SR-STE并稀疏化后的版本. 建议先使用pytorch-quantization完成量化与校准后再替换为稀疏版本随后开始微调.

TensorRT是允许很多算子同时应用有稀疏与量化的, 对于显式插有QDQ的onnx, 在trtexec中开启flag --sparsity=enable后trt会同时传播INT8并应用那些具有N:M稀疏模式的参数层. 实际应用时可以考虑先使用原始onnx(没有显式量化与稀疏)导出一个engine并在日志中查看tensorRT推荐的稀疏层. 随后在完成
显式量化后再将trt推荐的层替换为SparseQuant版本并完成微调.


量化友好的SR-STE稀疏化训练
https://blog.bakeneko-kuro.com/2025/09/16/hpc/sparse-refined-straight-through-estimator/
作者
迷途黑猫
发布于
2025年9月16日
许可协议