SSE图像算法优化系列十六:经典USM锐化中的分支判断语句SSE实现的几种方法尝试。

  分支判断的语句一般来说是不太适合进行SSE优化的,因为他会破坏代码的并行性,但是也不是所有的都是这样的,在合适的场景中运用SSE还是能对分支预测进行一定的优化的,我们这里以某一个算法的部分代码为例进行讲解。

  在某一个版本的USM锐化算法中有这样的一段代码:

int IM_UnsharpMask(unsigned char *Src, unsigned char *Dest, int Width, int Height, int Stride, int Radius, int Amount, int Threshold)
{
    int Channel = Stride / Width;
    if ((Src == NULL) || (Dest == NULL))                                return IM_STATUS_NULLREFRENCE;
    if ((Width <= 0) || (Height <= 0))                                  return IM_STATUS_INVALIDPARAMETER;
    if ((Channel != 1) && (Channel != 3) && (Channel != 4))             return IM_STATUS_INVALIDPARAMETER;
    int Status = IM_STATUS_OK;

    Status = IM_ExpBlur(Src, Dest, Width, Height, Stride, Radius);    //  这里标准过程是用IM_GaussBlur代替
    if (Status != IM_STATUS_OK)    return Status;

    const float Inv255 = 1.0f / 255.0f;
    int *Table = (int *)malloc(511 * 256 * sizeof(int));
    if (Table == NULL)    return IM_STATUS_OUTOFMEMORY;

    for (int Y = 0; Y < 256; Y++)
    {
        float TempUp = Amount * sqrtf(1.0f - Y * Inv255) / 100.0f;
        float TempDown = Amount * sqrtf(Y * Inv255) / 100.0f;
        for (int X = -255; X <= 255; X++)
        {
            int Diff = X;
            if (Diff >= Threshold)
            {
                Diff -= Threshold;
                Table[((X + 255) << 8) + Y] = IM_ClampToByte(int(Diff * TempUp + 0.5f) + Y);
            }
            else if (Diff < -Threshold)
            {
                Diff += Threshold;
                Table[((X + 255) << 8) + Y] = IM_ClampToByte(int(Diff * TempDown + 0.5f) + Y);
            }
            else
            {
                Table[((X + 255) << 8) + Y] = Y;        //    不做变化
            }
        }
    }
    for (int Y = 0; Y < Height * Stride; Y++)            //    分四路并行速度只有一点点提高
    {
        Dest[Y] = Table[((Src[Y] - Dest[Y] + 255) << 8) + Src[Y]];
    }

    free(Table);
    return IM_STATUS_OK;
}

  这个USM锐化的算法参考自:https://github.com/pluginguy/plugins/tree/master/USM2,源代码中的算法还提供了对高光、暗调和中间调进行不同调节的参数,我这里对他那个代码进行了适度的修改和简化,并且用查找表进行了优化。这个github的作者还提供了关于高斯模糊方面的资料,是个不错的参考点。

  上述代码起始已经很高效了,复杂的浮点和开方计算都已经用查表的形式进行了简化,实测一副1080P的24位图像大处理时间大约在14.5ms左右,而其中的IM_ExpBlur耗时约有6.75ms,建立查找表花了0.75ms,后面的遍历图像进行查找表替换使用了7ms,注意前面的IM_ExpBlur的时间是已经进行了SSE编码后的优化时间。

  查找表其实本身也是个耗时的工作,因为这个可能有着严重的cache miss,特别是查找表比较大的时候。但是查找表本身呢在目前SIMD框架下是无法使用SSE优化的(除非是16个字节的查找表,可以使用_mm_shuffle_epi8来优化),因此,如果查找表本身的建立算法并不特别复杂,是可以考虑使用SSE来对表中每个元素进行直接的实现的,鉴于此,我们来考虑上述代码的查找表的直接SSE实现。

  为了表示清楚,我们把上述算法的非查找表方式实现的代码整理出来如下:

int IM_UnsharpMask(unsigned char *Src, unsigned char *Dest, int Width, int Height, int Stride, int Radius, int Amount, int Threshold)
{
    int Channel = Stride / Width;
    if ((Src == NULL) || (Dest == NULL))                                  return IM_STATUS_NULLREFRENCE;
    if ((Width <= 0) || (Height <= 0))                                     return IM_STATUS_INVALIDPARAMETER;
    if ((Channel != 1) && (Channel != 3) && (Channel != 4))                 return IM_STATUS_INVALIDPARAMETER;
    int Status = IM_STATUS_OK;

    Status = IM_ExpBlur(Src, Dest, Width, Height, Stride, Radius);        //    这里标准过程是用IM_GaussBlur代替
    if (Status != IM_STATUS_OK)    return Status;

    float Adjust = Amount / 100.0f / sqrtf(255.0f);
    for (int Y = 0; Y < Height * Stride; Y++)                        
    {
        int Diff = Src[Y] - Dest[Y];
        if (Diff >= Threshold)
        {
            Dest[Y] = IM_ClampToByte(int((Diff - Threshold) * Adjust * sqrtf(255.0f - Src[Y]) + 0.5f) + Src[Y]);
        }
        else if (Diff < -Threshold)
        {
            Dest[Y] = IM_ClampToByte(int((Diff + Threshold) * Adjust * sqrtf((float)Src[Y]) + 0.5f) + Src[Y]);
        }
        else
        {
            Dest[Y] = Src[Y];        //    不做变化
        }
    }
    return IM_STATUS_OK;
}

  注意为减少计算我已经把一些重复的计算提取到Adjust变量中,其中的/sqrtf(255.0f)可以让循环内部的sqrtf的参数少一次乘法计算,并且在后面我们还可以看到他起到了另外一个特殊的作用。运行上述代码的同参数同照片耗时变为了55ms左右,可见查找表的优化也是很给力的。

  我注意到这段代码已经有很久了,也一直想使用SSE优化他们,但苦于能力,一直未得良方,不过最近过年重新审视这段代码,发现只要手指按住键盘,总会有新大陆发现的。

  第一方案:既然SSE不太好做分支判断,我就把所有分支的结果都计算出来,最后再根据分支条件做数据融合不就可以了吗,可以肯定SSE计算每个分支的速度肯定比C快,但是如果要每个分支都计算,这个增加的耗时和加速的时间比例如何呢,只有实践才知道,于是我硬着头皮把他们用SSE做个硬编码,代码如下所示:

//    实在没有好的办法,极端情况下把所有的分支的结果都算出来,然后在最后根据判断条件合成,比如下面的代码,写出来后比原始的查找表方式也还是要快一点的。
int IM_UnsharpMask(unsigned char *Src, unsigned char *Dest, int Width, int Height, int Stride, int Radius, int Amount, int Threshold)
{
    int Channel = Stride / Width;
    if ((Src == NULL) || (Dest == NULL))                                return IM_STATUS_NULLREFRENCE;
    if ((Width <= 0) || (Height <= 0))                                    return IM_STATUS_INVALIDPARAMETER;
    if ((Channel != 1) && (Channel != 3) && (Channel != 4))                return IM_STATUS_INVALIDPARAMETER;
    int Status = IM_STATUS_OK;

    Status = IM_ExpBlur(Src, Dest, Width, Height, Stride, Radius);
    if (Status != IM_STATUS_OK)    return Status;

    const float Adjust = Amount / 100.0f / sqrt(255.0f);
    const int BlockSize = 8;
    int Block = (Height * Stride) / BlockSize;

    const __m128i Zero = _mm_setzero_si128();
    const __m128i ThresholdV = _mm_set1_epi16(Threshold);
    const __m128i MinusThresholdV = _mm_set1_epi16(-Threshold);
    const __m128i One = _mm_set1_epi16(1);
    const __m128i MinusOne = _mm_set1_epi16(-1);
    const __m128 Const255 = _mm_set1_ps(255.0f);
    const __m128 AdjustV = _mm_set1_ps(Adjust);

    for (int Y = 0; Y < Block * BlockSize; Y += BlockSize)
    {
        __m128i SrcV = _mm_unpacklo_epi8(_mm_loadl_epi64((__m128i *)(Src + Y)), Zero);
        __m128i DstV = _mm_unpacklo_epi8(_mm_loadl_epi64((__m128i *)(Dest + Y)), Zero);
        __m128 SrcL = _mm_cvtepi32_ps(_mm_unpacklo_epi8(SrcV, Zero));
        __m128 SrcH = _mm_cvtepi32_ps(_mm_unpackhi_epi8(SrcV, Zero));
        __m128i Diff = _mm_sub_epi16(SrcV, DstV);
        __m128i DiffA = _mm_add_epi16(Diff, ThresholdV);
        __m128i DiffS = _mm_sub_epi16(Diff, ThresholdV);
        __m128 DiffL = _mm_cvtepi32_ps(_mm_cvtepi16_epi32(Diff));
        __m128 DiffH = _mm_cvtepi32_ps(_mm_cvtepi16_epi32(_mm_srli_si128(Diff, 8)));

        __m128 UpL = _mm_mul_ps(AdjustV, _mm_sqrt_ps(_mm_sub_ps(Const255, SrcL)));
        __m128 UpH = _mm_mul_ps(AdjustV, _mm_sqrt_ps(_mm_sub_ps(Const255, SrcH)));
        __m128 DownL = _mm_mul_ps(AdjustV, _mm_sqrt_ps(SrcL));
        __m128 DownH = _mm_mul_ps(AdjustV, _mm_sqrt_ps(SrcH));

        __m128 DiffUpL = _mm_mul_ps(_mm_cvtepi32_ps(_mm_cvtepi16_epi32(DiffS)), UpL);
        __m128 DiffUpH = _mm_mul_ps(_mm_cvtepi32_ps(_mm_cvtepi16_epi32(_mm_srli_si128(DiffS, 8))), UpH);
        __m128 DiffDownL = _mm_mul_ps(_mm_cvtepi32_ps(_mm_cvtepi16_epi32(DiffA)), DownL);
        __m128 DiffDownH = _mm_mul_ps(_mm_cvtepi32_ps(_mm_cvtepi16_epi32(_mm_srli_si128(DiffA, 8))), DownH);

        __m128i DiffUp = _mm_adds_epi16(_mm_packs_epi32(_mm_cvtps_epi32(DiffUpL), _mm_cvtps_epi32(DiffUpH)), SrcV);
        __m128i DiffDown = _mm_adds_epi16(_mm_packs_epi32(_mm_cvtps_epi32(DiffDownL), _mm_cvtps_epi32(DiffDownH)), SrcV);

        __m128i DestV = _mm_blendv_si128(SrcV, DiffUp, _mm_cmpgt_epi16(Diff, ThresholdV));
        DestV = _mm_blendv_si128(DestV, DiffDown, _mm_cmplt_epi16(Diff, MinusThresholdV));
_mm_storel_epi64((__m128i
*)(Dest + Y), _mm_packus_epi16(DestV, Zero)); } for (int Y = Block * BlockSize; Y < Height * Stride; Y++) { int Diff = Src[Y] - Dest[Y]; if (Diff >= Threshold) { Dest[Y] = IM_ClampToByte(int((Diff - Threshold) * Adjust * sqrtf(255.0f - Src[Y]) + 0.5f) + Src[Y]); } else if (Diff < -Threshold) { Dest[Y] = IM_ClampToByte(int((Diff + Threshold) * Adjust * sqrtf(0.0f + Src[Y]) + 0.5f) + Src[Y]); } else { Dest[Y] = Src[Y]; } } return IM_STATUS_OK; }

  上述代码基本就是普通C语言的翻译,这里讲几个需要注意的地方。

  第一、_mm_cvtepi16_epi32这是个讲signed short转换为signed int的函数,只处理XMM寄存的低8位,如果需要将高8位也进行转换,就必须得配合_mm_srli_si128一起使用,如果需要转换的signed short能确认是大于等于0的,也可以使用_mm_unpacklo_epi16及_mm_unpackhi_epi16配合_mm_setzero_si128来实现,比如上面的SrcL和SrcH就是使用的这个技巧,但是如果有小于0的情况出现,一定只能用_mm_cvtepi16_epi32来实现,比如上面的DiffL和DiffH,我以前在这个上面吃过很多亏。

  第二、在计算DiffUp和DiffDown这两个结果时,注意需要使用_mm_packs_epi32,而不是_mm_packus_epi32,因为计算结果是有负数存在的。

  第三、结果的融合这里的技巧很好,我们知道SSE4提供了两个__m128i变量融合的函数,比如_mm_blendv_epi8,但是他要求最后的融合选项是个常数,而我们这里的融合选项是变化的,所以无法使用,我们使用了一个叫做_mm_blendv_si128的内联函数,这个函数用一个__m128i变量作为融合参数,对128个位进行融合,其代码如下:

static inline __m128i _mm_blendv_si128(__m128i x, __m128i y, __m128i mask)
{
    return _mm_or_si128(_mm_andnot_si128(mask, x), _mm_and_si128(mask, y));
}

  当mask的某一位为0时,选择x中的对应位的值,否则选择y中对应位的值。

  这个函数正是我需要的,而且恰好前几天在浏览文章:A few missing SSE intrinsics发现了他,有的时候真的觉得处处留心皆学问啊。

  这时我们来看下上面的融合的代码:__m128i DestV = _mm_blendv_si128(SrcV, DiffUp, _mm_cmpgt_epi16(Diff, ThresholdV));

  后面的_mm_cmpgt_epi16的比较函数会返回一个__m128i变量,当Diff > Threshold时,对应的16位数据为0xFFFF,否则为0,这样我们使用_mm_blendv_si128融合时,满足条件的部分结果就为DiffUp了,其他部分还保持SrcV不变。

  接着 DestV = _mm_blendv_si128(DestV, DiffDown, _mm_cmplt_epi16(Diff, MinusThresholdV)); 使用Diff < -Threshold作为判断条件,因为该条件和Diff > Threshold不可能同时成立,所以_mm_cmplt_epi16的返回结果中的为true的部分和_mm_cmpgt_epi16返回的true部分的值不可能重叠,因此,再次执行_mm_blendv_si128混合的值就是我们融合的正确结果。

  那么我们最关心的速度来了,经过测试,上述算法对1080P彩色图能达到约14ms的执行速度,和查找表的C语言版本速度差不多,唯一的优势就是运算时少占用了一部分内存。但是同时也说明SSE的计算能力真的不是盖的,算一算,正正的SSE执行时间实际上只有14-6.75 =7.25ms,而不用查找表的C代码的用时为55-6.75=48.25ms,达到了进7倍的提速比,但这就是我们的终点了吗?

  第二方案:我们在仔细观察下Diff > Threshold和Diff < -Threshold时计算的不同,第一个不同是Diff > Threshold时使用了Diff - Threshold,而Diff < -Threshold时使用了Diff + Threshold;第二个不同为Diff > Threshold时使用了255.0f - Src[Y]作为开平方的算式,而Diff < -Threshold时使用了 Src[Y]。关于第一个不同,我们可以看到仅仅是个符号位不同,如果在Threshold前面根据不同的条件加个符号位在进行乘法不就可以了,也就是说如果我们根据Diff和Threshold的关系构建一个-1和1的中间变量,则可以把他们写在一个式子里,那这样的符号为要如何构建呢?

  自然而然我们又想到了上述方法的_mm_blendv_si128,简单的方式如下所示:

__m128i Sign = _mm_blendv_si128(Zero, MinusOne, _mm_cmpgt_epi16(Diff, ThresholdV));
        Sign = _mm_blendv_si128(Sign, One, _mm_cmplt_epi16(Diff, MinusThresholdV));

  Zero,MinusOne,One这个还需不需要解释,上面的代码还需不需要解释?

  第二个不同,我们这样看,我们把它们放在一起 255.0f - Src[Y]  |  Src[Y],稍微改写一下255 - Src[Y]  | 0 -  Src[Y],后面的+和-可以用类似前面的同样的方法处理,我们还需处理255和0,如果我们能够根据判断条件构造出255 和 0这样的序列,那是不是就解决问题了,如何构造?

  前面说过,_mm_cmpgt_epi16会返回0xFFFF和0,看成unsigned short类型则为65535和0, 如果我们把这个返回结果右移8位,是不是就变为了255和0呢,明白了吗?

  最后我们注意一点,当-Threshold < Diff

  我们还来在说下前面的符号问题,正或者负某个数,直接用符号位加乘法固然是可以实现的,但是有么有其他的方式更好的实现呢,翻一番SSE的手册,我们会发现有_mm_sign_epi8 、_mm_sign_epi16 、_mm_sign_epi32 这样的函数,他们是干什么的呢,我们以_mm_sign_epi16为例,看看他的文档说明:

extern __m128i _mm_sign_epi16 (__m128i a, __m128i b); 
Negate packed words in a if corresponding sign in b is less than zero. 
Interpreting a, b, and r as arrays of signed 16-bit integers: 
for (i = 0; i < 8; i++)
{ 
    if (b[i] < 0) 
    { 
        r[i] = -a[i]; 
    } 
    else if (b[i] == 0)
    { 
        r[i] = 0; 
    } 
    else 
    { 
        r[i] = a[i]; 
    } 
}

  什么意思,就是以参数b的符号位来决定a的值,当b为负数是,对a求反,当b为0时,a也为0,否则a值保持不变。这不就可以直接实现上述的符号位的问题了吗?

  说了那么多,我贴出代码大家看一看:

int IM_UnsharpMask(unsigned char *Src, unsigned char *Dest, int Width, int Height, int Stride, int Radius, int Amount, int Threshold)
{
    int Channel = Stride / Width;
    if ((Src == NULL) || (Dest == NULL))                                return IM_STATUS_NULLREFRENCE;
    if ((Width <= 0) || (Height <= 0))                                    return IM_STATUS_INVALIDPARAMETER;
    if ((Channel != 1) && (Channel != 3) && (Channel != 4))                return IM_STATUS_INVALIDPARAMETER;
    int Status = IM_STATUS_OK;

    Status = IM_ExpBlur(Src, Dest, Width, Height, Stride, Radius);
    if (Status != IM_STATUS_OK)    return Status;

    const float Adjust = Amount / 100.0f / sqrt(255.0f);
    const int BlockSize = 8;
    int Block = (Height * Stride) / BlockSize;

    const __m128i Zero = _mm_setzero_si128();
    const __m128i ThresholdV = _mm_set1_epi16(Threshold);
    const __m128i MinusThresholdV = _mm_set1_epi16(-Threshold);
    const __m128i MinusOne = _mm_set1_epi16(-1);
    const __m128 AdjustV = _mm_set1_ps(Adjust);
    const __m128i One = _mm_set1_epi16(1);
    for (int Y = 0; Y < Block * BlockSize; Y += BlockSize)
    {
        __m128i SrcV = _mm_unpacklo_epi8(_mm_loadl_epi64((__m128i *)(Src + Y)), Zero);
        __m128i DstV = _mm_unpacklo_epi8(_mm_loadl_epi64((__m128i *)(Dest + Y)), Zero);
        __m128i Diff = _mm_sub_epi16(SrcV, DstV);                                                //    int Diff = Src[Y] - Dest[Y];
        
        //    当Diff > ThresholdV时,Sign设置为负数,当Diff < -ThresholdV时,Sign设置为正数,
        //    介于-ThresholdV和ThresholdV之间时为0,这里One和MinusOne只是取得一个代表性的值

        __m128i SignA = _mm_cmpgt_epi16(Diff, ThresholdV);                           
        __m128i SignB = _mm_cmplt_epi16(Diff, MinusThresholdV);                        
        __m128i Sign = _mm_blendv_si128(Zero, MinusOne, SignA);
        Sign = _mm_blendv_si128(Sign, One, SignB);
            
        //    Diff 为不同值时,NewDiff需要带上不同符号,利用上面的Sign配合_mm_sign_epi16能很好的解决问题
        __m128i NewDiff = _mm_add_epi16(Diff, _mm_sign_epi16(ThresholdV, Sign));

        //    _mm_cmpgt_epi16返回0xfffff和0两种值,我们这里需要的是0xff和0,因此需要进行下移位,注意此时在Diff < Threshold(Sign为0或者1时)
        //    _mm_add_epi16的第一个参数都是0,而第二个参数对于Sign为0的情况则也返回0,这样0+0正好为0,Sqrt后也为0,对结果正好没有影响(巧合还是天意?)
        __m128i NewPower = _mm_add_epi16(_mm_srli_epi16(SignA, 8), _mm_sign_epi16(SrcV, Sign));

        //    注意这里有负数存在,则必须用这种强制转换函数
        __m128 NewDiffL = _mm_cvtepi32_ps(_mm_cvtepi16_epi32(NewDiff));                                    
        __m128 NewDiffH = _mm_cvtepi32_ps(_mm_cvtepi16_epi32(_mm_srli_si128(NewDiff, 8)));

        //    都是正数就可以这样转化
        __m128 NewPowerL = _mm_cvtepi32_ps(_mm_unpacklo_epi16(NewPower, Zero));                            
        __m128 NewPowerH = _mm_cvtepi32_ps(_mm_unpackhi_epi16(NewPower, Zero));

        //    按公式计算结果
        __m128 DstL = _mm_mul_ps(_mm_mul_ps(AdjustV, NewDiffL), _mm_sqrt_ps(NewPowerL));
        __m128 DstH = _mm_mul_ps(_mm_mul_ps(AdjustV, NewDiffH), _mm_sqrt_ps(NewPowerH));

        //    合成到16位的结果,注意这里不要用_mm_packus_epi32,因为后面还有一个加法要进行
        __m128i Result = _mm_packs_epi32(_mm_cvtps_epi32(DstL), _mm_cvtps_epi32(DstH));                    

        //    合成到8位的结果,注意这要用抗饱和的加法_mm_adds_epi16
        _mm_storel_epi64((__m128i *)(Dest + Y), _mm_packus_epi16(_mm_adds_epi16(Result, SrcV), Zero));
    }

    for (int Y = Block * BlockSize; Y < Height * Stride; Y++)
    {
        int Diff = Src[Y] - Dest[Y];
        if (Diff >= Threshold)
        {
            Dest[Y] = IM_ClampToByte(int((Diff - Threshold) * Adjust * sqrtf(255.0f - Src[Y]) + 0.5f) + Src[Y]);
        }
        else if (Diff < -Threshold)
        {
            Dest[Y] = IM_ClampToByte(int((Diff + Threshold) * Adjust * sqrtf(0.0f + Src[Y]) + 0.5f) + Src[Y]);
        }
        else
        {
            Dest[Y] = IM_ClampToByte(int(Diff * Adjust * sqrtf(0.0f + 0.0f) + 0.5f) + Src[Y]);        //    不做变化
        }
    }

    return IM_STATUS_OK;
}

  最后回到我们关心的速度问题上去,经过上述优化后能达到的速度平均值在11.5ms左右,比查找表版本的还要快了3ms左右。

  实际上上述求Sign的过程还有更为简单的优化过程的,想通了也很有道理,这个留个读者自行去研究,大概能加快0.4ms左右的速度。

  关于分支预测的SSE优化,目前我掌握的技巧也就这么多,管件还是要看算法本身,有的时候要脱离原始算法,为了能用SSE而稍微改变下算法的外表。这就各位神仙各显神通了,当然有很多分支预测由于太复杂还是不能够用SIMD指令优化的。

  最后说一句,关于Photoshop的标准USM锐化并不是使用的上述算法,其原理应该说比上面的还要简单,但也不是网络上流行的那个计算公式,我已经通过测试推到得到了和其一模一样的计算式,这里不提,不过呢,为什么非要一样呢,这里的这个算法也是不错的。

  算法Demo下载地址:https://files.cnblogs.com/files/Imageshop/SSE_Optimization_Demo.rar

 

      SSE图像算法优化系列十六:经典USM锐化中的分支判断语句SSE实现的几种方法尝试。_第1张图片

 

你可能感兴趣的:(SSE图像算法优化系列十六:经典USM锐化中的分支判断语句SSE实现的几种方法尝试。)