SIMD学习 -- 用SSE2指令作点乘和累加计算

 这几天在做学校的一个学习小项目,需要用到SIMD指令计算提速。也是第一次碰这个,看了一些资料和代码,模仿着写了两个函数。

void sse_mul_float(float *A, float *B, int  cnt):两段内存float数据点乘,结果覆盖第一组内存。

float sse_acc_float(float *A, int cnt):一组内存float值累加。

注:

1. 没有考虑中间的精确问题,结果会有误差。

2. 每个函数包括指令操作部分和C++语句计算部分。本文附的代码注释介绍指令部分思路。

**3. 关于内存对齐,我不是很懂,所以下面的代码中判断是否对齐的相关语句我写的也不是很正确,所有后面都补上了一点C++的明白操作。

因此,有些指令操作也许没用上。

头文件

#include "time.h"
#include "stdafx.h"
#include
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
using namespace std;
View Code

 

 

sse_mul_float asm部分

  1 //MOV EAX,1               ;request CPU feature flags
  2     //CPUID                   ;0Fh, 0A2h CPUID instruction
  3     //TEST EDX,4000000h       ;test bit 26 (SSE2)
  4     //JNZ >L18                ;SSE2 available
  5 
  6     int cnt1;
  7     int cnt2;
  8     int cnt3;
  9 
 10     //we process  the majority by using SSE instructions
 11     if (((int)A % 16) || ((int)B % 16))      //如果内存不对齐
 12     {
 13 
 14         cnt1 = cnt / 16;                         //该loop一轮处理16个float*float
 15         cnt2 = (cnt - (16 * cnt1)) / 4;          //该loop一轮处理4个float*float
 16         cnt3 = (cnt - (16 * cnt1) - (4 * cnt2)); //该loop一轮处理1个float*float
 17 
 18         _asm
 19         {
 20 
 21             mov edi, A;                      //先将内存地址放入指针寄存器
 22             mov esi, B;    
 23             mov ecx, cnt1;                  //循环寄存器置值
 24             jecxz ZERO;                   //如果数据量不超过16个,则跳过L1
 25 
 26         L1:
 27 
 28                                           
 29             //xmm 寄存器有128bit
 30             //movups  XMM,XMM/m128
 31             //传128bit数据,不必对齐内存16字节.
 32             movups xmm0, [edi];
 33             movups xmm1, [edi + 16];        
 34             movups xmm2, [edi + 32];        
 35             movups xmm3, [edi + 48];    
 36             //为什么只载入4*4个float?  到上面看看这一轮需要处理多少数据
 37 
 38             movups xmm4, [esi];            
 39             movups xmm5, [esi + 16];        
 40             movups xmm6, [esi + 32];        
 41             movups xmm7, [esi + 48];        
 42 
 43             //mulps XMM,XMM/m128
 44             //寄存器按双字对齐, 
 45             //共4个单精度浮点数与目的寄存器里的4个对应相乘, 
 46             //结果送入目的寄存器, 内存变量必须对齐内存16字节.
 47             mulps xmm0, xmm4;            
 48             mulps xmm1, xmm5;            
 49             mulps xmm2, xmm6;            
 50             mulps xmm3, xmm7;            
 51 
 52             //(一个float占4字节,也就是32bit)
 53             //到这里,xmm0-3寄存器里都有了4个float的乘积结果
 54             //然后回载到相应内存
 55             movups[edi], xmm0;        
 56             movups[edi + 16], xmm1;        
 57             movups[edi + 32], xmm2;        
 58             movups[edi + 48], xmm3;        
 59 
 60             //记得给指针移位
 61             //64=16 * 4 
 62             //每一轮处理了16次float * float,每一个float占4字节
 63             //所以移位应该加64
 64             add edi, 64;
 65             add esi, 64;
 66 
 67             loop L1;                            
 68 
 69         ZERO:
 70             mov ecx, cnt2;
 71             jecxz ZERO1;
 72 
 73         L2:
 74 
 75             movups xmm0, [edi];           //对于4个float,一个xmm寄存器正好够用
 76             movups xmm1, [esi];        
 77             mulps xmm0, xmm1;           //对应相乘,结果在xmm0
 78             movups[edi], xmm0;           //由xmm0回载内存
 79             add edi, 16;               //指针移位
 80             add esi, 16;
 81 
 82             loop L2;
 83 
 84         ZERO1:
 85 
 86             mov ecx, cnt3;
 87             jecxz ZERO2;
 88 
 89             mov eax, 0;
 90 
 91         L3:                                
 92 
 93             movd eax, [edi];            //对于单个float * float,无需sse指令
 94             imul eax, [esi];
 95             movd[edi], eax;
 96             add esi, 4;
 97             add edi, 4;
 98 
 99             loop L3;
100 
101         ZERO2:
102 
103             EMMS;                       //清空
104 
105         }
106 
107     }
108     else
109     {
110 
111         cnt1 = cnt / 28;                          //该loop一轮处理28个float*float
112         cnt2 = (cnt - (28 * cnt1)) / 4;           //该loop一轮处理4个float*float
113         cnt3 = (cnt - (28 * cnt1) - (4 * cnt2));  //该loop一轮处理1个float*float
114 
115         _asm
116         {
117 
118             
119             mov edi, A;    
120             mov esi, B;    
121             mov ecx, cnt1;    
122             jecxz AZERO;
123 
124         AL1:
125 
126             //movaps XMM, XMM / m128 
127             //把源存储器内容值送入目的寄存器, 当有m128时, 必须对齐内存16字节, 也就是内存地址低4位为0.
128             movaps xmm0, [edi];        
129             movaps xmm1, [edi + 16];        
130             movaps xmm2, [edi + 32];        
131             movaps xmm3, [edi + 48];        
132             movaps xmm4, [edi + 64];        
133             movaps xmm5, [edi + 80];        
134             movaps xmm6, [edi + 96];
135             //7*4=28,处理28个float*float
136 
137             mulps xmm0, [esi];            //对应点乘
138             mulps xmm1, [esi + 16];        
139             mulps xmm2, [esi + 32];        
140             mulps xmm3, [esi + 48];        
141             mulps xmm4, [esi + 64];        
142             mulps xmm5, [esi + 80];        
143             mulps xmm6, [esi + 96];        
144 
145             movaps[edi], xmm0;            //回载
146             movaps[edi + 16], xmm1;        
147             movaps[edi + 32], xmm2;        
148             movaps[edi + 48], xmm3;        
149             movaps[edi + 64], xmm4;        
150             movaps[edi + 80], xmm5;        
151             movaps[edi + 96], xmm6;        
152 
153             add edi, 112;
154             add esi, 112;
155 
156             loop AL1;                            
157 
158         AZERO:
159             mov ecx, cnt2;
160             jecxz AZERO1;
161 
162         AL2:
163 
164             movaps xmm0, [edi];        
165             mulps xmm0, [esi];        
166             movaps[edi], xmm0;        
167             add edi, 16;
168             add esi, 16;
169 
170             loop AL2;
171 
172         AZERO1:
173 
174             mov ecx, cnt3;
175             jecxz AZERO2;
176 
177             mov eax, 0;
178 
179         AL3:                                
180 
181             movd eax, [edi];
182             imul eax, [esi];
183             movd[edi], eax;
184             add esi, 4;
185             add edi, 4;
186 
187             loop AL3;
188 
189         AZERO2:
190 
191             EMMS;
192 
193         }
194 
195     }
View Code

由于内存对齐的问题,导致末尾有部分数据不正常,特添加C++部分修复。
sse_mul_float c++部分

1 int start;
2     start = cnt - (cnt % 4);
3     for (int i = start; i < cnt; i++)
4     {
5         A[i] *= B[i];
6     }

 

用于累加的这个函数,分两块。一块是用指令把大部分数据处理掉,而极少部分数据使用C++语句,这样能各取所长。

sse_acc_float asm部分

  1 float temp = 0;
  2 
  3     int cnt1;
  4     int cnt2;
  5     int cnt3;
  6     int select = 0;
  7 
  8     //we process  the majority by using SSE instructions
  9     if (((int)A % 16))      //unaligned 如果这次调用,内存数据不对齐
 10     {
 11         select = 1;
 12         
 13         cnt1 = cnt / 24;
 14         cnt2 = (cnt - (24 * cnt1)) / 8;
 15         cnt3 = (cnt - (24 * cnt1) - (8 * cnt2));
 16 
 17         __asm
 18         {
 19             
 20             mov edi, A;            
 21             mov ecx, cnt1;        
 22             pxor xmm0, xmm0;    
 23             jecxz ZERO;
 24 
 25         L1:
 26 
 27             movups xmm1, [edi];
 28             movups xmm2, [edi + 16];
 29             movups xmm3, [edi + 32];
 30             movups xmm4, [edi + 48];
 31             movups xmm5, [edi + 64];
 32             movups xmm6, [edi + 80];
 33 
 34             //addps 对应相加
 35             //结果返回目的寄存器
 36             addps xmm1, xmm2;
 37             addps xmm3, xmm4;
 38             addps xmm5, xmm6;
 39 
 40             addps xmm1, xmm5;
 41             addps xmm0, xmm3;
 42 
 43             addps xmm0, xmm1;
 44             //至此,xmm0内4个float的和就是24个float的和
 45 
 46             add edi, 96;
 47 
 48             loop L1;                        
 49 
 50         ZERO:
 51 
 52 
 53             movd ebx, xmm0;      //低4个字节(第一个float)传入ebx
 54             psrldq xmm0, 4;      //xmm0右移4字节
 55             movd eax, xmm0;      //右移后,低4个字节(第二个float)传入eax
 56 
 57             movd xmm1, eax;      //第一个float传入xmm1低32bit
 58             movd xmm2, ebx;      //第二个float传入xmm2低32bit
 59             addps xmm1, xmm2;    //两个寄存器内4个float对应相加
 60             movd eax, xmm1;      //只取我们要的低位float,传入eax
 61             movd xmm3, eax;      //第一和第二个float的和存在xmm3低32位
 62             psrldq xmm0, 4;      //又截掉一个float
 63             
 64 
 65             movd ebx, xmm0;      //第三个float进ebx
 66             psrldq xmm0, 4;      //截掉第三个float
 67             movd eax, xmm0;      //第四个float进eax 
 68 
 69             movd xmm1, eax;      
 70             movd xmm2, ebx;
 71             addps xmm1, xmm2;    //第三和第四个float的和存在xmm1低32位
 72             movd eax, xmm1;
 73             movd xmm4, eax;
 74             addps xmm3, xmm4;    //4个float的和存在xmm3低32位
 75 
 76 
 77             movd eax, xmm3;
 78             mov temp, eax;       //这部分求和存在temp地址区
 79 
 80 
 81 
 82             EMMS;                            
 83 
 84         }
 85     }
 86     else              // aligned   如果这次调用,内存数据对齐
 87     {
 88         select = 2;
 89         
 90         cnt1 = cnt / 56;                     
 91         cnt2 = (cnt - (56 * cnt1)) / 8;        
 92         cnt3 = (cnt - (56 * cnt1) - (8 * cnt2)); 
 93 
 94         __asm
 95         {
 96             
 97             mov edi, A;            
 98             mov ecx, cnt1;        
 99             pxor xmm0, xmm0;    
100             jecxz ZZERO;
101 
102         LL1:
103 
104             movups xmm1, [edi];
105             movups xmm2, [edi + 16];
106             movups xmm3, [edi + 32];
107             movups xmm4, [edi + 48];
108             movups xmm5, [edi + 64];
109             movups xmm6, [edi + 80];
110 
111             addps xmm1, xmm2;
112             addps xmm3, xmm4;
113             addps xmm5, xmm6;
114             addps xmm1, xmm5;
115             addps xmm0, xmm3;
116             addps xmm0, xmm1;
117 
118             add edi, 96;
119 
120             movups xmm1, [edi];
121             movups xmm2, [edi + 16];
122             movups xmm3, [edi + 32];
123             movups xmm4, [edi + 48];
124             movups xmm5, [edi + 64];
125             movups xmm6, [edi + 80];
126 
127             addps xmm1, xmm2;
128             addps xmm3, xmm4;
129             addps xmm5, xmm6;
130             addps xmm1, xmm5;
131             addps xmm0, xmm3;
132             addps xmm0, xmm1;
133 
134             add edi, 96;
135 
136             movups xmm1, [edi];
137             movups xmm2, [edi + 16];
138 
139             addps xmm1, xmm2;
140             addps xmm0, xmm1;
141 
142             add edi, 32;
143 
144             loop LL1;                        
145 
146         ZZERO:
147 
148 
149             movd ebx, xmm0;
150             psrldq xmm0, 4;
151             movd eax, xmm0;
152 
153             movd xmm1, eax;
154             movd xmm2, ebx;
155             addps xmm1, xmm2;
156             movd eax, xmm1;
157             movd xmm3, eax;
158             psrldq xmm0, 4;
159             
160 
161             movd ebx, xmm0;
162             psrldq xmm0, 4;
163             movd eax, xmm0;
164 
165             movd xmm1, eax;
166             movd xmm2, ebx;
167             addps xmm1, xmm2;
168             movd eax, xmm1;
169             movd xmm4, eax;
170             addps xmm3, xmm4;
171 
172 
173             movd eax, xmm3;
174             mov temp, eax;
175 
176             EMMS;                            
177 
178         }
179     }
View Code

 

sse_acc_float   c++部分

//上面的select记录本次调用sse_acc_float时,数据是否对齐内存
    //下面分情况把剩余的和累加
    int start;
    float c = 0.0f;
    if (select == 1)
    {
        
        start = cnt - (cnt % 24);
        for (int i = start; i < cnt; i++)
        {
            c += A[i];
        }
        
    }
    else
    {
        start = cnt - (cnt % 56);
        for (int i = start; i < cnt; i++)
        {
            c += A[i];
        }
        
    }

    //temp 是用指令计算 ,大部分数据的和
    //c    是用C++语句计算, 所有数据模24或者56剩余部分数据的和
    return(temp + c);
View Code

 

 推荐参考:SIMD(单道指令多道数据流)指令(MMX/SSE1/SSE2)详解(中文).

 我是一名编程菜鸟,有什么技术上的问题,欢迎讨论和交流指正。谢谢!

获取全部源码:点此  dot_acc.cpp

 

转载于:https://www.cnblogs.com/errorplayer/p/6616091.html

你可能感兴趣的:(SIMD学习 -- 用SSE2指令作点乘和累加计算)