当前位置: 主页 > 资讯中心 > 常见问题 » GEMM算法及优化流程详解
目录
神经网络前向耗时主要由卷积的耗时决定,参考賈杨青毕业论文,那么如何对卷积加速便成了重要的一个点,主流的加速方法有
以下几种:
im2col+GEMM:目前几乎所有的主流计算框架包括 Caffe, MXNet 等都实现了该方法. 该方法把整个卷积过程转化成了GEMM过程,而GEMM在各种 BLAS 库中都是被极致优化的,一般来说,速度较快。
Winograd: Winograd 是存在已久最近被重新发现的方法,在大部分场景中, Winograd方法都显示和较大的优势,目前cudnn中计算卷积就使用了该方法。
Strassen:1969年,Volker Strassen提出了第一个时间复杂度低于O(N^3)的算法,其复杂度为O(N^(2^(log2(7)))),但这种方法只在大卷积核情况下优势才比较明显,目前还没有在开源框架中见到这种方法。
FFT:傅里叶变换和快速傅里叶变化是在经典图像处理里面经常使用的计算方法,但是,在 ConvNet中通常不采用,主要是因为在 ConvNet 中的卷积模板通常都比较小,例如?3×3?等,这种情况下,FFT 的时间开销反而更大,所以很少在CNN中利用FFT实现卷积。
很高兴你看完前言:最近发现这篇文章写的很好,阿里那边的,《支付宝如何优化移动端深度学习引擎》推荐给大家~
?
GEMM在深度学习中是十分重要的,全连接层以及卷积层基本上都是通过GEMM来实现的,而网络中大约90%的运算都是在这两层中。而一个良好的GEMM的实现可以充分利用系统的多级存储结构和程序执行的局部性来充分加速运算。
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?常规的卷积操作为:
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?? ? ? ?
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? 3维卷积运算执行完毕,得一个2维的平面:
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?
将卷积操作的3维立体变为二维矩阵乘法,可以调用BLAS中的GEMM库,按 [kernel_height, kernel_width, kernel_depth] ? 将输入分成 3 维的 patch,并将其展成一维向量:
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?
此时的卷积操作就可转化为矩阵乘法:
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?
下面我们将以M=K=N=600为例说明GEMM算法的优化过程:
?
直接暴力卷积:
上述公式总计算量为2MNK FLOPs(其中 𝑀、𝑁、𝐾 分别指代三层循环执行的次数,2 指代循环最内层的一次乘法和加法) ,内存访问操作总数为 4MNK(其中 2MNK 指代对 𝐶 的内存访问,𝐶 需要先读取内存、累和再存储)。GEMM 的优化均以此为基点。
耗时分析:上述暴力gemm代码耗时约为872ms
?
首先能想到的就是减少C矩阵的访存次数,将C[m][n]放到外面,全部累和之后再赋值即可:
上述公式总计算量依然为2MNK FLOPs,内存访问操作总数为 2MNK+2MN(其中 2MN?指代对 𝐶 的内存访问,𝐶 需要先读取内存、累加完毕在存储)。
耗时分析:上述代码耗时约为791ms,耗时变少的原因是减少了部分C的访存
?
将输出的计算拆分为 1×4 的小块,即将 𝑁 维度拆分为两部分。计算该块输出时,需要使用 𝐴 矩阵的 1 行,和 𝐵 矩阵的 4 列。
? ? ? ? ? ? ? ? ? ? ? ? ? ? ?
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?图一:矩阵乘计算?1×4输出
下面是该计算的伪代码表示,这里已经将 1×4 中 N 维度的内部拆分进行了展开。这里的计算量仍然是 2𝑀𝑁𝐾 ,这一点在本文中不会有变化。
简单的观察即可发现,上述伪代码的最内侧计算使用的矩阵 𝐴 的元素是一致的。因此可以将 𝐴[𝑚][𝑘] 读取到寄存器中,从而实现 4 次数据复用(这里不再给出示例)。一般将最内侧循环称作计算核(micro kernel)。进行这样的优化后,内存访问操作数量变为 2MN+5/4MNK,访存约为上面的5/8。
耗时分析:本优化耗时约为473ms,相比暴力耗时减少300ms左右,可能的两个原因:1、由于B是行优先排列,1x4方法能够减少数据从内存到cache的加载次数;2、合理利用寄存器,减少对𝐴矩阵访存次数
?
类似地,我们可以继续拆分输出的 𝑀 维度,从而在内侧循环中计算 4×4 输出,如图二。
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?图二:矩阵乘计算?4×4输出
同样地,将计算核心展开,可以得到下面的伪代码。由于乘数效应,4×4 的拆分可以将对输入数据的访存缩减到 MN/16*(16*2+8K)=2MN+1/2*MNK。这相对于最开始的 4MNK 已经得到了 8X 的改进,这些改进都是通过展开循环后利用寄存器存储数据减少访存得到的。
耗时分析:本优化耗时约为354ms,相比1x4耗时减少120ms左右