koturnの日記

普通の人です.ブログ上のコードはコピペ自由です.

Unityのシェーダーにおいて,定数の整数乗を行う場合の繰り返し二乗法はループ展開されるか?

TL;DR

繰り返し二乗法の関数を定義し,指数に定数を指定した場合はループ展開される. しかし,pow(x, 5.0) のように指数が比較的小さい整数定数を指定した場合でも繰り返し二乗法をループ展開したものと同一のコード生成がされるため,自前で繰り返し二乗法の関数を設ける必要はないかもしれない.

前置き

Unityのシェーダーの標準ライブラリのコードを眺めていて,以下のようなものを見つけた.

  • UnityStandardBRDF.cginc
inline half Pow4 (half x)
{
    return x*x*x*x;
}

inline float2 Pow4 (float2 x)
{
    return x*x*x*x;
}

inline half3 Pow4 (half3 x)
{
    return x*x*x*x;
}

inline half4 Pow4 (half4 x)
{
    return x*x*x*x;
}

// Pow5 uses the same amount of instructions as generic pow(), but has 2 advantages:
// 1) better instruction pipelining
// 2) no need to worry about NaNs
inline half Pow5 (half x)
{
    return x*x * x*x * x;
}

inline half2 Pow5 (half2 x)
{
    return x*x * x*x * x;
}

inline half3 Pow5 (half3 x)
{
    return x*x * x*x * x;
}

inline half4 Pow5 (half4 x)
{
    return x*x * x*x * x;
}

pow() 関数による累乗の計算は整数ではない指数による累乗にも対応している. しかし,実装としては log2()exp2() を利用したものとなるため,Pow5() のコメントにあるように,命令数は同じでも(乗算3回か log, mul, exp の3命令)計算負荷は単純な乗算の方が軽いはずである.

float3 pow(float3 x, float3 y)

    return exp2(log2(x) * y);
}

4乗と5乗で同じように処理を記述しているのはコード量が膨れ上がると思った. また,2乗や3乗のバリエーションも欲しい.

指数が非負整数の累乗は繰り返し二乗法で効率的に計算可能であることはよく知られている. 以下のように繰り返し二乗法の関数を定義し,累乗部分に定数を指定することで,良い感じにループ展開されるかどうかを調べた. (すなわち,Pow4()Pow5() の出力アセンブリと一致,または同等のものになるかどうか調べた)

inline float pown(float x, int n)
{
    float v = 1.0;
    UNITY_UNROLL
    for (; n > 0; n >>= 1) {
        v *= (n & 1) == 0 ? 1.0 : x;
        x *= x;
    }
    return v;
}

inline float2 pown(float2 x, int n)
{
    static const float2 ones = float2(1.0, 1.0);

    float2 v = ones;
    UNITY_UNROLL
    for (; n > 0; n >>= 1) {
        v *= (n & 1) == 0 ? ones : x;
        x *= x;
    }
    return v;
}

inline float3 pown(float3 x, int n)
{
    static const float3 ones = float3(1.0, 1.0, 1.0);

    float3 v = ones;
    UNITY_UNROLL
    for (; n > 0; n >>= 1) {
        v *= (n & 1) == 0 ? ones : x;
        x *= x;
    }
    return v;
}

inline float4 pown(float4 x, int n)
{
    static const float4 ones = float4(1.0, 1.0, 1.0, 1.0);

    float4 v = ones;
    UNITY_UNROLL
    for (; n > 0; n >>= 1) {
        v *= (n & 1) == 0 ? ones : x;
        x *= x;
    }
    return v;
}

繰り返し二乗法

繰り返し二乗法の再帰的な定義は下記の通り(簡単のため $n$ は非負整数とする).

\begin{equation} x^{n} = \begin{cases} x (x^{\frac{n - 1}{2}})^2 & \text{where} ~ n \equiv 1 \pmod 2 \\ (x^{\frac{n}{2}})^2 & \text{where} ~ n \equiv 0 \pmod 2 \end{cases} \end{equation}

乗算命令の回数は$O(\log n)$であるが,正確には下記の通り(ループ展開した場合).

\begin{equation} C(n) = \lfloor \log_2 n \rfloor + popcnt(n) - 1 \label{numberOfMulExpBySquaring} \end{equation}

$popcnt$ は立っているビット数を数える関数であり,あえて数式で記述するならば下記のような定義となる.

\begin{equation} popcnt(n) = \begin{cases} 0 & \text{where} ~ n = 0 \\ popcnt \left(\dfrac{n}{2} \right) & \text{where} ~ n > 0 ~ \text{and} ~ n \equiv 0 \pmod 2 \\ 1 + popcnt \left( \dfrac{n - 1}{2} \right) & \text{where} ~ n > 0 ~ \text{and} ~ n \equiv 1 \pmod 2 \end{cases} \end{equation}

確認用シェーダー

指数は5,ベクトルの次元は3で確認すれば十分だろう. multi_compile 用のキーワードを設け,それぞれ下記のコードになるようにした.

キーワード コード詳細
NAIVE pow() 組み込み関数を利用
NAIVE_MUL 自前の乗算5回を行う関数を利用
ITERATIVE_SQUARE 自前の繰り返し二乗法を行う関数を利用
  • IntPower.shader
Shader "koturn/IntPower"
{
    Properties
    {
        [KeywordEnum(NAIVE, NAIVE_MUL, ITERATIVE_SQUARE)]
        _ExpMethod("Exp method", Int) = 2  // Default: ITERATIVE_SQUARE
    }

    SubShader
    {
        Pass
        {
            CGPROGRAM
            #pragma vertex vert
            #pragma fragment frag
            #pragma target 3.0
            #pragma multi_compile_local_fragment _EXPMETHOD_NAIVE _EXPMETHOD_NAIVE_MUL _EXPMETHOD_ITERATIVE_SQUARE

            #include "UnityCG.cginc"

            struct appdata
            {
                float4 vertex : POSITION;
                float2 uv : TEXCOORD0;
                float4 color : COLOR;
            };

            struct v2f
            {
                float4 vertex : SV_POSITION;
                float2 uv : TEXCOORD0;
                float4 color : TEXCOORD1;
            };


            inline float3 pow5(float3 x)
            {
                return x * x * x * x * x;
            }

            inline float3 pown(float3 x, int n)
            {
                static const float3 ones = float3(1.0, 1.0, 1.0);

                float3 v = ones;
                UNITY_UNROLL
                for (; n > 0; n >>= 1) {
                    v *= (n & 1) == 0 ? ones : x;
                    x *= x;
                }
                return v;
            }

            v2f vert(appdata v)
            {
                v2f o;
                o.vertex = UnityObjectToClipPos(v.vertex);
                o.uv = v.uv;
                o.color = v.color;

                return o;
            }

            fixed4 frag(v2f i) : SV_Target
            {
#if defined(_EXPMETHOD_NAIVE)
                // return float4(pow(i.color.rgb, 5.0), 1.0);
                // return float4(exp2(log2(i.color.rgb) * 5.0)), 1.0);
                return float4(exp2(i.color.rgb * log2(5.0))), 1.0);
#elif defined(_EXPMETHOD_NAIVE_MUL)
                return float4(pow5(i.color.rgb), 1.0);
#else
                return float4(pown(i.color.rgb, 5), 1.0);
#endif
            }
            ENDCG
        }
    }
}

確認用シェーダーの生成コード

各キーワードに対し,それぞれ下記のコードが生成されていた. すなわち,どれも全く同じ繰り返し二乗法をループ展開したコードとなっていた. となると,繰り返し二乗法の関数を自前で定義する必要はないかもしれない?

  • _EXPMETHOD_NAIVE
//////////////////////////////////////////////////////
Global Keywords: <none>
Local Keywords: _EXPMETHOD_NAIVE
-- Vertex shader for "d3d11":
// No shader variant for this keyword set. The closest match will be used instead.

-- Hardware tier variant: Tier 1
-- Fragment shader for "d3d11":
// Stats: 3 math, 1 temp registers
Shader Disassembly:
//
// Generated by Microsoft (R) D3D Shader Disassembler
//
//
// Input signature:
//
// Name                 Index   Mask Register SysValue  Format   Used
// -------------------- ----- ------ -------- -------- ------- ------
// SV_POSITION              0   xyzw        0      POS   float
// TEXCOORD                 0   xy          1     NONE   float
// TEXCOORD                 1   xyzw        2     NONE   float   xyz
//
//
// Output signature:
//
// Name                 Index   Mask Register SysValue  Format   Used
// -------------------- ----- ------ -------- -------- ------- ------
// SV_Target                0   xyzw        0   TARGET   float   xyzw
//
      ps_4_0
      dcl_input_ps linear v2.xyz
      dcl_output o0.xyzw
      dcl_temps 1
   0: mul r0.xyz, v2.xyzx, v2.xyzx
   1: mul r0.xyz, r0.xyzx, r0.xyzx
   2: mul o0.xyz, r0.xyzx, v2.xyzx
   3: mov o0.w, l(1.000000)
   4: ret
// Approximately 0 instruction slots used
  • _EXPMETHOD_NAIVE_MUL
//////////////////////////////////////////////////////
Global Keywords: <none>
Local Keywords: _EXPMETHOD_NAIVE_MUL
-- Vertex shader for "d3d11":
// No shader variant for this keyword set. The closest match will be used instead.

-- Hardware tier variant: Tier 1
-- Fragment shader for "d3d11":
// Stats: 3 math, 1 temp registers
Shader Disassembly:
//
// Generated by Microsoft (R) D3D Shader Disassembler
//
//
// Input signature:
//
// Name                 Index   Mask Register SysValue  Format   Used
// -------------------- ----- ------ -------- -------- ------- ------
// SV_POSITION              0   xyzw        0      POS   float
// TEXCOORD                 0   xy          1     NONE   float
// TEXCOORD                 1   xyzw        2     NONE   float   xyz
//
//
// Output signature:
//
// Name                 Index   Mask Register SysValue  Format   Used
// -------------------- ----- ------ -------- -------- ------- ------
// SV_Target                0   xyzw        0   TARGET   float   xyzw
//
      ps_4_0
      dcl_input_ps linear v2.xyz
      dcl_output o0.xyzw
      dcl_temps 1
   0: mul r0.xyz, v2.xyzx, v2.xyzx
   1: mul r0.xyz, r0.xyzx, r0.xyzx
   2: mul o0.xyz, r0.xyzx, v2.xyzx
   3: mov o0.w, l(1.000000)
   4: ret
// Approximately 0 instruction slots used
  • _EXPMETHOD_ITERATIVE_SQUARE
//////////////////////////////////////////////////////
Global Keywords: <none>
Local Keywords: _EXPMETHOD_ITERATIVE_SQUARE
-- Vertex shader for "d3d11":
// No shader variant for this keyword set. The closest match will be used instead.

-- Hardware tier variant: Tier 1
-- Fragment shader for "d3d11":
// Stats: 3 math, 1 temp registers
Shader Disassembly:
//
// Generated by Microsoft (R) D3D Shader Disassembler
//
//
// Input signature:
//
// Name                 Index   Mask Register SysValue  Format   Used
// -------------------- ----- ------ -------- -------- ------- ------
// SV_POSITION              0   xyzw        0      POS   float
// TEXCOORD                 0   xy          1     NONE   float
// TEXCOORD                 1   xyzw        2     NONE   float   xyz
//
//
// Output signature:
//
// Name                 Index   Mask Register SysValue  Format   Used
// -------------------- ----- ------ -------- -------- ------- ------
// SV_Target                0   xyzw        0   TARGET   float   xyzw
//
      ps_4_0
      dcl_input_ps linear v2.xyz
      dcl_output o0.xyzw
      dcl_temps 1
   0: mul r0.xyz, v2.xyzx, v2.xyzx
   1: mul r0.xyz, r0.xyzx, r0.xyzx
   2: mul o0.xyz, r0.xyzx, v2.xyzx
   3: mov o0.w, l(1.000000)
   4: ret
// Approximately 0 instruction slots used

pow()関数はどこまでを繰り返し二乗法のコードとするか

繰り返し二乗法の乗算回数は式 \eqref{numberOfMulExpBySquaring} の通りであるため,かなり多い乗算回数よりも log, mul, exp の3命令を用いた方が好ましいと考えられる. シェーダーコンパイラが整数定数を指定した pow() 関数から log, mul, exp の3命令を生成する閾値を調べてみることにした.

結果は下記の表の通り. この結果から指数 $n$ ではなく,乗算命令数 $C(n)$ が閾値となっており,7回以下なら繰り返し二乗法,8回以上なら log, mul, exp の3命令を生成することがわかる.

$$n$$ $$C(n)$$ 複数の乗算命令?
1 0 movのみ
2 1
3 2
4 2
5 3
6 3
7 4
8 3
9 4
10 4
11 5
12 4
13 5
14 5
15 6
16 4
17 5
18 5
19 6
20 5
21 6
22 6
23 7
24 5
25 6
26 6
27 7
28 6
29 7
30 7
31 8
32 5

まとめ

繰り返し二乗法の関数を用意し,指数に定数を指定することで,単純な乗算の命令が並ぶことが確認できた. これにより,2乗,3乗,4乗,...の関数を個別に用意せずとも,繰り返し二乗法の関数だけ用意するだけでよいことになる.

しかし,Direct3D11の出力アセンブラを確認する限り,組み込み関数 pow() の指数に定数を指定することでも繰り返し二乗法のコード生成をすることがわかった. しかも,乗算命令数によっては log, mul, exp の3命令にするようだ. このことから,わざわざ自前で繰り返し二乗法の関数を用意する必要はないかもしれない.

参考文献