koturnの日記

転職したい社会人2年生の技術系日記.ブログ上のコードはコピペ自由です.

C++のラムダで再帰する

はじめに

C++において,ラムダで再帰したいと考えることはたまにある. この記事ではラムダで再帰する手法をいくつか紹介する. 例として扱う再帰関数はフィボナッチ数列の関数(もっとも単純な実装)とする.

int
fib(int n) noexcept
{
  return n < 2 ? n : (fib(n - 1) + fib(n - 2));
}

生成コード

ラムダで再帰を行いたいと考えるのは変数のキャプチャをしつつ再帰をしたい場面であると思う. 例えば,メモ化再帰を行うと考えた場合,

  1. グローバル変数としてメモ化配列を用意する
  2. メモ化再帰用のクラスを用意する(メンバにメモ化配列を用意する)
  3. 再帰関数の引数にメモ化配列を与える

といった方法が考えられるが,グローバル変数や専用のクラス定義が必要であり,あまりやりたくない. 引数にメモ化配列を与えるとなると,全ての再帰呼び出し箇所の記述が冗長になってしまう. ラムダで再帰ができれば,ローカル変数にメモ化用の配列を用意してもキャプチャできるようになるため,スッキリすると思われる.

3つの手法

この記事ではラムダで再帰を実現する手法を3つ紹介する.

1. std::function による無名再帰

C++11でない限り,絶対に使うべきではない

実装としては以下の通り.

#include <functional>
#include <iostream>
#include <utility>


int
main()
{
  std::function<int(int)> fib = [&fib](int n) -> int {
    return n < 2 ? n : (fib(n - 1) + fib(n - 2));
  };
  auto result = fib(10);
  std::cout << result << std::endl;
}

生成コード

利点と欠点は以下の通り.

利点

  1. C++11でも使用可能

欠点

  1. std::function を用いるため,実行コストが大きい(速度)
  2. std::function を用いるため,再帰可能深度が小さい
  3. 無名にできない
  4. std::function とラムダそのものに引数等の返り値を書く必要があり冗長

なぜかC++関連の記事を見ていると,この手法を紹介している記事が目につくのだが,言うまでもなく std::function を用いているのがとにかく良くない. よく知られている通り,std::function は型消去を行うため,与えた関数をインライン展開できない,実行コストが大きいといった欠点を持つ. std::functionoperator() の実装,あるいは生成されるコードを見ればわかるように( std::function 自体の operator() はインライン展開されるので,g++で -S --verbose-asm のようなオプションを付けてアセンブリを吐かせ,呼び出し部分のコードを確認すれば見てとれる),

  1. 保持している関数が有効かどうか(関数を保持しているかどうか)をチェックするコードが operator() に含まれる
  2. 関数が有効でない場合の, std::bad_function_call 例外を投げるためのコードが含まれる

といった欠点がある.これらの理由により,1回の関数呼び出しに時間がかかるだけでなく,消費スタック量が大きく,再帰可能深度が小さくなる. そのため,C++11でない限りは使うべきではない. 強いて利点を挙げるならば,下準備に必要な関数およびクラスがないため,ゼロからコードを書く場合には楽という利点はあるかもしれない.

本記事の「ラムダで再帰」に限った話ではなく,std::function を使用するべきかどうかは慎重にならなければならない. autoやテンプレートの型パラメータを用いることによって,いかにラムダをラムダのまま保持するかを考えるべきだ.

2. ジェネリックラムダを用いる方法

実装としては以下の通り. 可変引数はとりあえず完全転送すればよい.

#include <iostream>
#include <utility>


template<typename F>
static inline constexpr decltype(auto)
fix(F&& f) noexcept
{
  return [f = std::forward<F>(f)](auto&&... args) {
    return f(f, std::forward<decltype(args)>(args)...);
  };
}


int
main()
{
  auto result = fix([](auto f, int n) -> int {
    return n < 2 ? n : (f(f, n - 1) + f(f, n - 2));
  })(10);
  std::cout << result << std::endl;
}

生成コード

利点と欠点は以下の通り.

利点

  1. 実行コストは小さい
  2. 無名再帰可能
  3. C++17以降ではコンパイル時計算可能なラムダによる再帰関数を記述可能

欠点

  1. C++14以降
  2. 再帰関数の引数にその再帰関数自身を与える必要があるため,関数本体の記述が冗長.

この手法はこの記事のものである. ジェネリックラムダを用いる必要があるため,C++14以降でなければならない. 2018年の今となっては困ることはないと思うが,例えば競プロのジャッジサーバがC++11までしか対応していない場面では利用できない.

std::function を用いる方法と違い,ラムダをラムダのまま保持しているので,関数呼び出しのコストは小さい. C++17以降では,ラムダが constexpr に対応するため,コンパイル時計算が可能となる. これは上記コードの auto resultconstexpr auto result に書き換えて,C++17対応のコンパイラコンパイルすれば確認できる.

3. ジェネリックラムダと operator() を用いる方法

実装としては以下の通り.

#include <iostream>
#include <utility>


template<typename F>
class
#if defined(__has_cpp_attribute) && __has_cpp_attribute(nodiscard)
[[nodiscard]]
#elif defined(__GNUC__) && __GNUC_PREREQ(3, 4)
__attribute__((warn_unused_result))
#endif  // defined(__has_cpp_attribute) && __has_cpp_attribute(nodiscard)
FixPoint : F
{
public:
  explicit constexpr FixPoint(F&& f) noexcept
    : F(std::forward<F>(f))
  {}

  template<typename... Args>
  constexpr decltype(auto)
  operator()(Args&&... args) const
  {
    return F::operator()(*this, std::forward<Args>(args)...);
  }
};  // class FixPoint


template<typename F>
static inline constexpr decltype(auto)
makeFixPoint(F&& f) noexcept
{
  return FixPoint<F>{std::forward<F>(f)};
}


int
main()
{
  auto result = makeFixPoint([](auto f, int n) -> int {
    return n < 2 ? n : (f(n - 1) + f(n - 2));
  })(10);
  std::cout << result << std::endl;
}

生成コード

利点

  1. 実行コストは小さい
  2. 無名再帰可能
  3. C++17以降ではコンパイル時計算可能なラムダによる再帰関数を記述可能

欠点

  1. C++14以降

手法としてはこれがベストである. 2つ目の手法異なり,再帰関数の引数にその再帰関数自身を与える必要もない.

類似する実装として,こういうものがあるが,これでは2つ目の手法と変わりない. ラムダの第一引数に FixPoint クラスのオブジェクトを与える(すなわち, *this を与える)ように, FixPoint::operator() を定義するのがポイントである.

C++17以降では constexpr なラムダでコンパイル時計算可能であるが,C++14でも constexpr な関数オブジェクトを用意すれば一応コンパイル時計算は可能である. ラムダではないので,本記事の趣旨とは外れるが....

#include <iostream>
#include <utility>


template<typename F>
class
#if defined(__has_cpp_attribute) && __has_cpp_attribute(nodiscard)
[[nodiscard]]
#elif defined(__GNUC__) && __GNUC_PREREQ(3, 4)
__attribute__((warn_unused_result))
#endif  // defined(__has_cpp_attribute) && __has_cpp_attribute(nodiscard)
FixPoint : F
{
public:
  explicit constexpr FixPoint(F&& f) noexcept
    : F(std::forward<F>(f))
  {}

  template<typename... Args>
  constexpr decltype(auto)
  operator()(Args&&... args) const
  {
    return F::operator()(*this, std::forward<Args>(args)...);
  }
};  // class FixPoint


template<typename F>
static inline constexpr decltype(auto)
makeFixPoint(F&& f) noexcept
{
  return FixPoint<F>{std::forward<F>(f)};
}


class Fib
{
public:
  template<typename F>
  constexpr int
  operator()(F&& f, int n) const noexcept
  {
    return n < 2 ? n : (f(n - 1) + f(n - 2));
  }
};  // class Fib


int
main()
{
  constexpr auto result = makeFixPoint(Fib())(10);
  std::cout << result << std::endl;
}

C++17以降ではクラステンプレートのテンプレートパラメータ推論が可能になったため,推論のためのmake関数も不要になる.

#include <iostream>
#include <utility>


template<typename F>
class
#if defined(__has_cpp_attribute) && __has_cpp_attribute(nodiscard)
[[nodiscard]]
#elif defined(__GNUC__) && __GNUC_PREREQ(3, 4)
__attribute__((warn_unused_result))
#endif  // defined(__has_cpp_attribute) && __has_cpp_attribute(nodiscard)
FixPoint : F
{
public:
  explicit constexpr FixPoint(F&& f) noexcept
    : F(std::forward<F>(f))
  {}

  template<typename... Args>
  constexpr decltype(auto)
  operator()(Args&&... args) const
  {
    return F::operator()(*this, std::forward<Args>(args)...);
  }
};  // class FixPoint


int
main()
{
  auto result = FixPoint{[](auto f, int n) -> int {
    return n < 2 ? n : (f(n - 1) + f(n - 2));
  }}(10);
  std::cout << result << std::endl;
}

おまけの話題

FixPoint の実装

3つ目の手法ではラムダを継承して FixPoint を実装したが,別にメンバーにラムダを持つ実装でもよい. 僕個人としては,ラムダを継承する方がスッキリしていると感じたのと,次の章のオーバーロード実装を考えると継承の方を紹介するのが自然と考えた.

実は継承を用いる方法はいなむ神に教えていただいた

#include <iostream>
#include <utility>


template<typename F>
class
#if defined(__has_cpp_attribute) && __has_cpp_attribute(nodiscard)
[[nodiscard]]
#elif defined(__GNUC__) && __GNUC_PREREQ(3, 4)
__attribute__((warn_unused_result))
#endif  // defined(__has_cpp_attribute) && __has_cpp_attribute(nodiscard)
FixPoint
{
public:
  explicit constexpr FixPoint(F&& f) noexcept
    : m_f(std::forward<F>(f))
  {}

  template<typename... Args>
  constexpr decltype(auto)
  operator()(Args&&... args) const
  {
    return m_f(*this, std::forward<Args>(args)...);
  }

private:
  const F m_f;
};  // class FixPoint


int
main()
{
  auto result = FixPoint{[](auto f, int n) -> int {
    return n < 2 ? n : (f(n - 1) + f(n - 2));
  }}(10);
  std::cout << result << std::endl;
}

再帰ラムダのオーバーロード

これもいなむ神に教えていただいたものである. C++17のusing宣言のパック展開を用いると,オーバーロードが可能となる.

#include <iostream>
#include <utility>


template<typename... Fs>
class
#if defined(__has_cpp_attribute) && __has_cpp_attribute(nodiscard)
[[nodiscard]]
#elif defined(__GNUC__) && __GNUC_PREREQ(3, 4)
__attribute__((warn_unused_result))
#endif  // defined(__has_cpp_attribute) && __has_cpp_attribute(nodiscard)
FixPoint : Fs...
{
  using Fs::operator()...;

public:
  explicit constexpr FixPoint(Fs&&... fs) noexcept
    : Fs(std::forward<Fs>(fs))...
  {}

  template<typename... Args>
  constexpr decltype(auto)
  operator()(Args&&... args) const
  {
    return operator()(*this, std::forward<Args>(args)...);
  }
};  // class FixPoint


int
main()
{
  auto fib = FixPoint{
    [](auto f, int n) -> int {
      return n < 2 ? n : (f(n - 1) + f(n - 2));
    },
    [](auto f, double n) -> double {
      return n < 2 ? n : (f(n - 1) + f(n - 2));
    }};
  std::cout << fib(10) << std::endl;
  std::cout << fib(10.0) << std::endl;
}

上記の例では double 型のオーバーロードを用意しているが,単にオーバーロード可能ということを示すためのものにすぎない.

追記 (2018/06/10)

いなむ神にラムダのみで再帰する方法を提示していただいた. ラムダのみでYコンビネータとZコンビネータを実現しており,非常に面白い. この記事で散々取り扱っているフィボナッチ数列の関数を,上記記事のYコンビネータとZコンビネータを用いて実装すると,それぞれ以下のようになる.

ラムダオンリーのYコンビネータによるフィボナッチ数列の関数の実装

#include <iostream>
#include <utility>


int
main()
{
  auto result = [g=[](auto f, int n) -> int {
    return n < 2 ? n : (f(f, n - 1) + f(f, n - 2));
  }](auto&&... args) {
    return g(g, std::forward<decltype(args)>(args)...);
  }(10);
  std::cout << result << std::endl;
}

生成コード

なお,MSVCでコンパイルしたとき,上記の実装だとMSVCがポンコツなためICEとなってコンパイルが通らないので,以下の実装を用いる方が安全かもしれない. MSVCでは初期化キャプチャでラムダをキャプチャすると,ICEとなるようだ.

#include <iostream>
#include <utility>


int
main()
{
  auto ret = [](auto f) {
    return [=](auto&&... args) {
      return f(f, std::forward<decltype(args)>(args)...);
    };
  }([](auto f, int n) -> int {
    return n < 2 ? n : f(f, n - 1) + f(f, n - 2);
  })(10);
  std::cout << ret << std::endl;
}

f:id:koturn:20180621043652p:plain

ラムダオンリーのZコンビネータによるフィボナッチ数列の関数の実装

#include <iostream>
#include <utility>


int
main()
{
  auto result = [](auto f) {
    return [=](auto g) {
      return [=](auto&&... args) {
        return f(g(g), std::forward<decltype(args)>(args)...);
      };
    }([=](auto g) {
      return [=](auto&&... args) {
        return f(g(g), std::forward<decltype(args)>(args)...);
      };
    });
  }([](auto f, int n) -> int {
    return n < 2 ? n : (f(n - 1) + f(n - 2));
  })(10);
  std::cout << result << std::endl;
}

生成コード

C++に不慣れな人にとっては,上記記事のZコンビネータほど複雑であれば実行効率が気になるところかもしれない(ラムダを返すラムダが云々で何となく遅そうなイメージ)が,生成コードを見れば要らぬ心配であることがわかると思う. そして,上記のYコンビネータ,Zコンビネータから生成されるコードは記事の本編で紹介した関数やクラスを用いる方法と同一のコードが生成されていることもわかる. C++はよしなにラムダのインライン展開を行ってくれるのだ!

僕はいなむ神のYコンビネータとZコンビネータの実装に感銘を受け,VimプラグインであるShougo/neosnippet.vimスニペットファイルに下記のスニペットを追加した.

snippet ycombinator
alias ycomb
  [](auto f) {
    return [=](auto&&... args) {
      return f(f, std::forward<decltype(args)>(args)...);
    };
  }([${1:&}](auto ${2:f}, ${3:#:args...}) {
    ${0}
  })

snippet zcombinator
alias zcomb
  [](auto f) {
    return [=](auto g) {
      return [=](auto&&... args) {
        return f(g(g), std::forward<decltype(args)>(args)...);
      };
    }([=](auto g) {
      return [=](auto&&... args) {
        return f(g(g), std::forward<decltype(args)>(args)...);
      };
    });
  }([${1:&}](auto ${2:f}, ${3:#:args...}) {
    ${0};
  })

これでいつでも再帰するラムダを楽に書ける. なお,ラムダを受ける引数はユニヴァーサル参照等の参照系で受けると単にautoで受けるよりも数命令多く,消費スタック量もわずかに多いので,いなむ神の記事の通り,autoで受けたり,コピーキャプチャを行う方が良い.

※Zコンビネータに関して:g++とMSVCではコンパイルは通るが,clang++ではコンパイルが通らないようなので気を付けてほしい.

速度比較

実際に実行時間を計測してみることにしよう. 再帰に関して,C++では様々な方法で実現することが可能であるが,この記事では,以下の4つのような関数オブジェクトによる再帰手法も紹介している. これらも含めて,実行時間を計測することにする.

class Fibonacci01
{
public:
  constexpr int
  operator()(int n) const noexcept
  {
    return n < 2 ? n : (Fibonacci01{}(n - 1) + Fibonacci01{}(n - 2));
  }
};  // struct Fibonacci01


class Fibonacci02
{
public:
  constexpr int
  operator()(int n) const noexcept
  {
    return n < 2 ? n : ((*this)(n - 1) + (*this)(n - 2));
  }
};  // struct Fibonacci02


class Fibonacci03
{
public:
  constexpr int
  operator()(int n) const noexcept
  {
    return n < 2 ? n : (operator()(n - 1) + operator()(n - 2));
  }
};  // struct Fibonacci03


class Fibonacci04
{
public:
  constexpr int
  operator()(Fibonacci04 f, int n) const noexcept
  {
    return n < 2 ? n : (f(f, n - 1) + f(f, n - 2));
  }
};  // struct Fibonacci04

この内 Fibonacci01 は本記事で紹介した FixPoint クラスによる手法と生成コードは同一であるため除外する. また Fibonacci02Fibonacci03 の生成コードは同一であり,比較する意味はないため, Fibonacci02 は除外する. 他にも,本記事において同一のコードが生成されたもの同士は意味が無いので,ユニークになるように除外して計測する.

計測に際して,以下のコードを用意した.

記事中のコードではフィボナッチ数の型を一律 int としていたが,計測コードでは std::uint64_t とした. x64環境であれば,64bitレジスタに乗せるだけなので,32bitから64bitに変更しても大きな影響はない. Wandboxは実行時間制限があるため,fib(42) をそれぞれ1回計算するコードにしてあるが,手元では fib(45) をそれぞれ5回計算するのにかかった平均時間を計測した.

使用したコンパイラは g++ 7.3.0 であり,コンパイルオプションは下記の通りである.

$ g++ -O2 -march=native main.cpp -o main.out
$ ./main.out

計測結果は以下の表に示す通りである.

手法 実行時間
普通の関数 2832.08 ms
std::function<> 7481.06 ms
fix() 省略
FixPoint クラス 2475.44 ms
FixPoint クラス (参照受け) 2715.21 ms
Yコンビネータ 省略
Zコンビネータ 省略
Fibonacci01 省略
Fibonacci02 省略
Fibonacci03 2903.60 ms
Fibonacci04 省略

つまり,

FixPoint クラス(追記のZコンビネータも含む) < FixPoint クラス(本体のラムダにおいて参照で受ける) < 普通の関数 < Fibonacci03 < std::function

という結果になった.

表の結果からも分かるように std::function は特に遅い. 普通の関数で実装したよりも3倍の実行時間を必要とする. そして,本記事で紹介したようなラムダで再帰を行う手法は,通常の関数で再帰を行う手法よりも高速である.

ただし,ラムダや関数オブジェクトを引数で受けるときに参照で受けた場合,わずかに遅くなる. 実際,参照で受ける場合のコードを確認してみると,余分なlea命令が生成されていた. 関数オブジェクトによる再帰の記事にもあるように,ラムダや関数オブジェクトはコピーで受け渡す方,あるいは都度生成する方がよいのだろう.

追記 (2019/03/18)

速度比較がg++だけであったので,MSVCおよびclang++でも速度比較を行った. また,MSVCやclang++では再帰関数の呼出自体が省略されてしまうことがあったので,前述の計測コードに少し手を加え,以下のようにした.

このコードはkoturn/CppRecursiveLambdaにも同一のものを置いている. 手元で試してみたい方は上記のリポジトリを利用するとよいだろう.

前と同じく fib(45) をそれぞれ5回計算するのにかかった平均時間を計測し,結果は以下の表のようになった. g++では同一コードが生成されていたものもコンパイラが違えば別のコードを生成していたので,全ての再帰実装の結果を示す.

手法 g++ clang++ MSVC
普通の関数 3152.51 ms 6570.66 ms 5537.51 ms
std::function<> 8710.48 ms 8762.2 ms 9797.05 ms
fix() 2639.57 ms 6734.75 ms 4348.72 ms
FixPoint クラス 2638.15 ms 6568.66 ms 6718.6 ms
FixPoint クラス (参照受け) 2647.43 ms 7604.38 ms 6430.34 ms
Yコンビネータ 2640.64 ms 7414.25 ms 4308 ms
Zコンビネータ 2637.05 ms ※1 22183.6 ms
Fibonacci01 2937.63 ms 6181.37 ms 5448 ms
Fibonacci02 3268.61 ms 7696.45 ms 6460 ms
Fibonacci03 3259.33 ms 7299.16 ms 6422.45 ms
Fibonacci04 2619.63 ms 6284.43 ms 6661.06 ms

※1:コンパイルできないため,記録無し

g++とclang++はおおむね同じような結果を示しているが,MSVCは全く異なった結果を示している. 特にZコンビネータのパフォーマンスが std::function<> よりも悪いことに驚かされる. MSVCの場合,fix() や Yコンビネータが最も高速であり,次に普通の関数および Fibonacci01が続き,次に FixPoint クラスが続くようだ. このことから,MSVCはあまりラムダのインライン展開等が得意ではないのではないかと思われる.

std::function による再帰がどのコンパイラでも遅いのは予想通りであった.

また,clang++がg++の2倍に近い実行時間となっているのは,単に1回の再帰につき2回フィボナッチ関数を呼び出す部分の最適化ができていないためであるが,clang++の結果のみでいえばどの再帰スタイルでも同じ最適化になっているので,今回の再帰の書き方によるパフォーマンスの違いに関する結果とは別のものである.

MSVCを無視するなら,個人的には一番使い勝手のよい FixPoint クラスを推したいと思う.

まとめ

  • C++11でない限り,ラムダの再帰std::function を用いるべきでない(約3倍の実行時間になる)
  • ジェネリックラムダと operator() を用いた実装を使うべき

いくつかの実装を紹介したが,C++14のジェネリックラムダによる再帰はいずれも同一のコード生成がなされ,また無駄な処理がなく,普通の関数による再帰よりも高速に動作することがわかった.

参考文献