koturnの日記

転職したい社会人1年生の技術系日記

AVXとAVX-512のインタリーブ

はじめに

前回の記事では,Intel系のCPUとARM系のCPUのSIMD命令紹介した. 記事中のサンプルコードで,画像の2倍の拡大を行うコードがあり,その中でインタリーブを行っていた.

SSEであれば,単純にunpack命令を実行するだけでよかった. 簡単なサンプルコードと出力結果は以下の通り.

// $ g++ -std=gnu++14 -march=native main.cpp -o main
#include <iostream>
#ifdef _MSC_VER
#  include <intrin.h>
#else
#  include <x86intrin.h>
#endif  // _MSC_VER


static inline void
showAVX128(const __m128i& v128) noexcept
{
  alignas(alignof(__m128i)) unsigned char z[sizeof(__m128i)] = {0};
  _mm_store_si128(reinterpret_cast<__m128i*>(z), v128);
  for (const auto& b : z) {
    std::cout << static_cast<unsigned int>(b) << ' ';
  }
  std::cout << "\n" << std::endl;
}


int
main()
{
  alignas(alignof(__m128i)) unsigned char z[sizeof(__m128i)] = {0};

  for (std::size_t i = 0; i < sizeof(z) / sizeof(z[0]); i++) {
    z[i] = static_cast<unsigned char>(i);
  }

  __m128i v128 = _mm_load_si128(reinterpret_cast<const __m128i*>(z));

  std::cout << "==================== low ====================" << std::endl;
  showAVX128(_mm_unpacklo_epi8(v128, v128));

  std::cout << "==================== high ====================" << std::endl;
  showAVX128(_mm_unpackhi_epi8(v128, v128));

  return 0;
}
==================== low ====================
0 0 1 1 2 2 3 3 4 4 5 5 6 6 7 7

==================== high ====================
8 8 9 9 10 10 11 11 12 12 13 13 14 14 15 15

しかし,AVX,AVX-512のインタリーブは一筋縄ではいかない. 例えば,以下の2つコード,

// $ g++ -std=gnu++14 -march=native main.cpp -o main
#include <iostream>
#ifdef _MSC_VER
#  include <intrin.h>
#else
#  include <x86intrin.h>
#endif  // _MSC_VER


static inline void
showAVX256(const __m256i& v256) noexcept
{
  alignas(alignof(__m256i)) unsigned char z[sizeof(__m256i)] = {0};
  _mm256_store_si256(reinterpret_cast<__m256i*>(z), v256);
  for (const auto& b : z) {
    std::cout << static_cast<unsigned int>(b) << ' ';
  }
  std::cout << "\n" << std::endl;
}


int
main()
{
  alignas(alignof(__m256i)) unsigned char z[sizeof(__m256i)] = {0};

  for (std::size_t i = 0; i < sizeof(z) / sizeof(z[0]); i++) {
    z[i] = static_cast<unsigned char>(i);
  }

  __m256i v256 = _mm256_load_si256(reinterpret_cast<const __m256i*>(z));

  std::cout << "==================== low ====================" << std::endl;
  showAVX256(_mm256_unpacklo_epi8(v256, v256));

  std::cout << "==================== high ====================" << std::endl;
  showAVX256(_mm256_unpackhi_epi8(v256, v256));

  return 0;
}
// $ g++ -std=gnu++14 -march=native -mavx512vbmi main.cpp -o main
#include <iostream>
#ifdef _MSC_VER
#  include <intrin.h>
#else
#  include <x86intrin.h>
#endif  // _MSC_VER


static inline void
showAVX512(const __m512i& v512) noexcept
{
  alignas(alignof(__m512i)) unsigned char z[sizeof(__m512i)] = {0};
  _mm512_store_si512(reinterpret_cast<__m512i*>(z), v512);
  for (const auto& b : z) {
    std::cout << static_cast<unsigned int>(b) << ' ';
  }
  std::cout << "\n" << std::endl;
}


int
main()
{
  alignas(alignof(__m512i)) unsigned char z[sizeof(__m512i)] = {0};

  for (std::size_t i = 0; i < sizeof(z) / sizeof(z[0]); i++) {
    z[i] = static_cast<unsigned char>(i);
  }

  __m512i v512 = _mm512_load_si512(reinterpret_cast<const __m512i*>(z));

  std::cout << "==================== low ====================" << std::endl;
  showAVX512(_mm512_unpacklo_epi8(v512, v512));

  std::cout << "==================== high ====================" << std::endl;
  showAVX512(_mm512_unpackhi_epi8(v512, v512));

  return 0;
}

の出力はそれぞれ次のようなものを期待する.

==================== low ====================
0 0 1 1 2 2 3 3 4 4 5 5 6 6 7 7 8 8 9 9 10 10 11 11 12 12 13 13 14 14 15 15

==================== high ====================
16 16 17 17 18 18 19 19 20 20 21 21 22 22 23 23 24 24 25 25 26 26 27 27 28 28 29 29 30 30 31 31
==================== low ====================
0 0 1 1 2 2 3 3 4 4 5 5 6 6 7 7 8 8 9 9 10 10 11 11 12 12 13 13 14 14 15 15 16 16 17 17 18 18 19 19 20 20 21 21 22 22 23 23 24 24 25 25 26 26 27 27 28 28 29 29 30 30 31 31

==================== high ====================
32 32 33 33 34 34 35 35 36 36 37 37 38 38 39 39 40 40 41 41 42 42 43 43 44 44 45 45 46 46 47 47 48 48 49 49 50 50 51 51 52 52 53 53 54 54 55 55 56 56 57 57 58 58 59 59 60 60 61 61 62 62 63 63

しかし,実際の出力はそれぞれ以下の通り.

==================== low ====================
0 0 1 1 2 2 3 3 4 4 5 5 6 6 7 7 16 16 17 17 18 18 19 19 20 20 21 21 22 22 23 23

==================== high ====================
8 8 9 9 10 10 11 11 12 12 13 13 14 14 15 15 24 24 25 25 26 26 27 27 28 28 29 29 30 30 31 31
==================== low ====================
0 0 1 1 2 2 3 3 4 4 5 5 6 6 7 7 16 16 17 17 18 18 19 19 20 20 21 21 22 22 23 23 32 32 33 33 34 34 35 35 36 36 37 37 38 38 39 39 48 48 49 49 50 50 51 51 52 52 53 53 54 54 55 55

==================== high ====================
8 8 9 9 10 10 11 11 12 12 13 13 14 14 15 15 24 24 25 25 26 26 27 27 28 28 29 29 30 30 31 31 40 40 41 41 42 42 43 43 44 44 45 45 46 46 47 47 56 56 57 57 58 58 59 59 60 60 61 61 62 62 63 63

実行結果を観察する限り,AVX,AVX-512のunpackは128bitごとに区切り,SSEのunpack命令を実行しているようになっている.

期待通りの出力を得る,すなわち,レジスタ全体に渡ってインタリーブを行うためには,

  • AVXの場合,unpack命令の後に並び換え
  • AVX-512の場合,8bit毎にレジスタから値を選択

する必要がある.

具体的には以下のようなコードにするとよい. AVXの場合は, _mm256_permute2f128_si256() ,AVX-512の場合は, _mm512_permutex2var_epi8() を用いるとよい. AVXの場合,128bit毎に整列を行う命令が見当たらなかったので,unpackを行わず,直接値を選択する形となる.

// $ g++ -std=gnu++14 -march=native main.cpp -o main
#include <iostream>
#ifdef _MSC_VER
#  include <intrin.h>
#else
#  include <x86intrin.h>
#endif  // _MSC_VER


static inline void
showAVX256(const __m256i& v256) noexcept
{
  alignas(alignof(__m256i)) unsigned char z[sizeof(__m256i)] = {0};
  _mm256_store_si256(reinterpret_cast<__m256i*>(z), v256);
  for (const auto& b : z) {
    std::cout << static_cast<unsigned int>(b) << ' ';
  }
  std::cout << "\n" << std::endl;
}


int
main()
{
  alignas(alignof(__m256i)) unsigned char z[sizeof(__m256i)] = {0};

  for (std::size_t i = 0; i < sizeof(z) / sizeof(z[0]); i++) {
    z[i] = static_cast<unsigned char>(i);
  }

  __m256i v256 = _mm256_load_si256(reinterpret_cast<const __m256i*>(z));
  // unpackの結果を保存
  __m256i vlo = _mm256_unpacklo_epi8(v256, v256);
  __m256i vhi = _mm256_unpackhi_epi8(v256, v256);

  // _mm256_permute2f128_si256() で整列を行う
  std::cout << "==================== permute vlo ====================" << std::endl;
  showAVX256(_mm256_permute2f128_si256(vlo, vhi, 0x20));

  std::cout << "==================== permute vhi ====================" << std::endl;
  showAVX256(_mm256_permute2f128_si256(vlo, vhi, 0x31));

  return 0;
}
// $ g++ -std=gnu++14 -march=native -mavx512vbmi main.cpp -o main
#include <iostream>
#ifdef _MSC_VER
#  include <intrin.h>
#else
#  include <x86intrin.h>
#endif  // _MSC_VER


static inline void
showAVX512(const __m512i& v512) noexcept
{
  alignas(alignof(__m512i)) unsigned char z[sizeof(__m512i)] = {0};
  _mm512_store_si512(reinterpret_cast<__m512i*>(z), v512);
  for (const auto& b : z) {
    std::cout << static_cast<unsigned int>(b) << ' ';
  }
  std::cout << "\n" << std::endl;
}


int
main()
{
  static const __m512i LOIDX = _mm512_setr_epi64(
      0x4303420241014000,
      0x4707460645054404,
      0x4b0b4a0a49094808,
      0x4f0f4e0e4d0d4c0c,
      0x5313521251115010,
      0x5717561655155414,
      0x5b1b5a1a59195818,
      0x5f1f5e1e5d1d5c1c);
  static const __m512i HIIDX = _mm512_setr_epi64(
      0x6323622261216020,
      0x6727662665256424,
      0x6b2b6a2a69296828,
      0x6f2f6e2e6d2d6c2c,
      0x7333723271317030,
      0x7737763675357434,
      0x7b3b7a3a79397838,
      0x7f3f7e3e7d3d7c3c);
  alignas(alignof(__m512i)) unsigned char z[sizeof(__m512i)] = {0};

  for (std::size_t i = 0; i < sizeof(z) / sizeof(z[0]); i++) {
    z[i] = static_cast<unsigned char>(i);
  }

  __m512i v512 = _mm512_load_si512(reinterpret_cast<const __m512i*>(z));

  // _mm512_permutex2var_epi8() で整列を行う
  std::cout << "==================== permute vlo ====================" << std::endl;
  showAVX512(_mm512_permutex2var_epi8(v512, LOIDX, v512));

  std::cout << "==================== permute vhi ====================" << std::endl;
  showAVX512(_mm512_permutex2var_epi8(v512, HIIDX, v512));

  return 0;
}

これで期待通りの出力を得ることが可能である.

_mm256_permute2f128_si256()_mm512_permutex2var_epi8() については,Intelのガイドページより検索するとよい.

をそれぞれ見るとよい. 疑似コードは動作の理解の助けになるはずだ. 特に, _mm256_permute2f128_si256() の第3引数, _mm512_permutex2var_epi8() の第二引数については,疑似コードを見なければ理解しづらい.

_mm256_permute2f128_si256() の第三引数は,下位4byteが出力先の下位4byte,上位4byteが出力先の上位4byteに対応しており,第一引数と第二引数のAVXレジスタを128bitに区切ったセクションのうち,どこから値を取るかを決定する.

なお,AVX-512のインタリーブに関しては,AVXと同様,unpackを行った後,並び換えてもよい. その際,64bit毎に並び替えを行う _mm512_permutex2var_epi64() を用いるとよさそうだ. _mm512_permutex2var_epi8() は8bitごと, _mm512_permutex2var_epi64() は64bitごとに,第一引数と第三引数のAVX-512レジスタのうち,どちらの値を取るかを指定する命令である.

// $ g++ -std=gnu++14 -march=native -mavx512vbmi main.cpp -o main
#include <iostream>
#ifdef _MSC_VER
#  include <intrin.h>
#else
#  include <x86intrin.h>
#endif  // _MSC_VER


static inline void
showAVX512(const __m512i& v512) noexcept
{
  alignas(alignof(__m512i)) unsigned char z[sizeof(__m512i)] = {0};
  _mm512_store_si512(reinterpret_cast<__m512i*>(z), v512);
  for (const auto& b : z) {
    std::cout << static_cast<unsigned int>(b) << ' ';
  }
  std::cout << "\n" << std::endl;
}


int
main()
{
  // 各引数は64bit整数.各引数の2byte目より上位は全て0なので,上位は書いていない.
  static const __m512i LOIDX = _mm512_setr_epi64(0x00, 0x01, 0x08, 0x09, 0x02, 0x03, 0x0a, 0x0b);
  static const __m512i HIIDX = _mm512_setr_epi64(0x04, 0x05, 0x0c, 0x0d, 0x06, 0x07, 0x0e, 0x0f);
  alignas(alignof(__m512i)) unsigned char z[sizeof(__m512i)] = {0};

  for (std::size_t i = 0; i < sizeof(z) / sizeof(z[0]); i++) {
    z[i] = static_cast<unsigned char>(i);
  }

  __m512i v512 = _mm512_load_si512(reinterpret_cast<const __m512i*>(z));
  // unpackの結果を保存
  __m512i vlo = _mm512_unpacklo_epi8(v512, v512);
  __m512i vhi = _mm512_unpackhi_epi8(v512, v512);

  // _mm512_permutex2var_epi64() で整列を行う
  std::cout << "==================== permute vlo ====================" << std::endl;
  showAVX512(_mm512_permutex2var_epi64(vlo, LOIDX, vhi));

  std::cout << "==================== permute vhi ====================" << std::endl;
  showAVX512(_mm512_permutex2var_epi64(vlo, HIIDX, vhi));

  return 0;
}

まとめ

AVX,AVX-512の場合,unpack命令だけではレジスタ全体のインタリーブを行うことができない. unpack後に,並び換えを行う必要がある.

参考文献