koturnの日記

普通の人です.ブログ上のコードはコピペ自由です.

C++のラムダで再帰する

はじめに

C++において,ラムダで再帰したいと考えることはたまにある. この記事ではラムダで再帰する手法をいくつか紹介する. 例として扱う再帰関数はフィボナッチ数列の関数(もっとも単純な実装)とする.

int
fib(int n) noexcept
{
  return n < 2 ? n : (fib(n - 1) + fib(n - 2));
}

生成コード

ラムダで再帰を行いたいと考えるのは変数のキャプチャをしつつ再帰をしたい場面であると思う. 例えば,メモ化再帰を行うと考えた場合,

  1. グローバル変数としてメモ化配列を用意する
  2. メモ化再帰用のクラスを用意する(メンバにメモ化配列を用意する)
  3. 再帰関数の引数にメモ化配列を与える

といった方法が考えられるが,グローバル変数や専用のクラス定義が必要であり,あまりやりたくない. 引数にメモ化配列を与えるとなると,全ての再帰呼び出し箇所の記述が冗長になってしまう. ラムダで再帰ができれば,ローカル変数にメモ化用の配列を用意してもキャプチャできるようになるため,スッキリすると思われる.

3つの手法

この記事ではラムダで再帰を実現する手法を3つ紹介する.

1. std::function による無名再帰

C++11でない限り,絶対に使うべきではない

実装としては以下の通り.

#include <functional>
#include <iostream>
#include <utility>


int
main()
{
  std::function<int(int)> fib = [&fib](int n) -> int {
    return n < 2 ? n : (fib(n - 1) + fib(n - 2));
  };
  auto result = fib(10);
  std::cout << result << std::endl;
}

生成コード

利点と欠点は以下の通り.

利点

  1. C++11でも使用可能

欠点

  1. std::function を用いるため,実行コストが大きい(速度)
  2. std::function を用いるため,再帰可能深度が小さい
  3. 無名にできない
  4. std::function とラムダそのものに引数等の返り値を書く必要があり冗長

なぜかC++関連の記事を見ていると,この手法を紹介している記事が目につくのだが,言うまでもなく std::function を用いているのがとにかく良くない. よく知られている通り,std::function は型消去を行うため,与えた関数をインライン展開できない,実行コストが大きいといった欠点を持つ. std::functionoperator() の実装,あるいは生成されるコードを見ればわかるように( std::function 自体の operator() はインライン展開されるので,g++で -S --verbose-asm のようなオプションを付けてアセンブリを吐かせ,呼び出し部分のコードを確認すれば見てとれる),

  1. 保持している関数が有効かどうか(関数を保持しているかどうか)をチェックするコードが operator() に含まれる
  2. 関数が有効でない場合の, std::bad_function_call 例外を投げるためのコードが含まれる

といった欠点がある.これらの理由により,1回の関数呼び出しに時間がかかるだけでなく,消費スタック量が大きく,再帰可能深度が小さくなる. そのため,C++11でない限りは使うべきではない. 強いて利点を挙げるならば,下準備に必要な関数およびクラスがないため,ゼロからコードを書く場合には楽という利点はあるかもしれない.

本記事の「ラムダで再帰」に限った話ではなく,std::function を使用するべきかどうかは慎重にならなければならない. autoやテンプレートの型パラメータを用いることによって,いかにラムダをラムダのまま保持するかを考えるべきだ.

2. ジェネリックラムダを用いる方法

実装としては以下の通り. 可変引数はとりあえず完全転送すればよい.

なお, nodiscard 属性を付けているのは,返り値を利用しない場面は考えられないためである.

#include <iostream>
#include <utility>


namespace
{
template <typename F>
#if defined(__has_cpp_attribute) && __has_cpp_attribute(nodiscard)
[[nodiscard]]
#elif defined(__GNUC__) && (__GNUC__ > 3 || __GNUC__ == 3 && __GNUC_MINOR__ >= 4)
__attribute__((warn_unused_result))
#elif defined(_MSC_VER) && _MSC_VER >= 1700 && defined(_Check_return_)
_Check_return_
#endif  // defined(__has_cpp_attribute) && __has_cpp_attribute(nodiscard)
inline constexpr decltype(auto)
fix(F&& f) noexcept
{
  return [f = std::forward<F>(f)](auto&&... args) {
    return f(f, std::forward<decltype(args)>(args)...);
  };
}
}  // namespace


int
main()
{
  auto result = fix([](auto f, int n) -> int {
    return n < 2 ? n : (f(f, n - 1) + f(f, n - 2));
  })(10);
  std::cout << result << std::endl;
}

生成コード

利点と欠点は以下の通り.

利点

  1. 実行コストは小さい
  2. 無名再帰可能
  3. C++17以降ではコンパイル時計算可能なラムダによる再帰関数を記述可能

欠点

  1. C++14以降
  2. 再帰関数の引数にその再帰関数自身を与える必要があるため,関数本体の記述が冗長.

この手法はこの記事のものである. ジェネリックラムダを用いる必要があるため,C++14以降でなければならない. 2018年の今となっては困ることはないと思うが,例えば競プロのジャッジサーバがC++11までしか対応していない場面では利用できない.

std::function を用いる方法と違い,ラムダをラムダのまま保持しているので,関数呼び出しのコストは小さい. C++17以降では,ラムダが constexpr に対応するため,コンパイル時計算が可能となる. これは上記コードの auto resultconstexpr auto result に書き換えて,C++17対応のコンパイラコンパイルすれば確認できる.

3. ジェネリックラムダと operator() を用いる方法

実装としては以下の通り.

#include <iostream>
#include <utility>


template <typename F>
class
#if defined(__has_cpp_attribute) && __has_cpp_attribute(nodiscard)
[[nodiscard]]
#endif  // defined(__has_cpp_attribute) && __has_cpp_attribute(nodiscard)
FixPoint final : private F
{
public:
  template <typename G>
  explicit constexpr FixPoint(G&& g) noexcept
    : F{std::forward<G>(g)}
  {}

  template <typename... Args>
  constexpr decltype(auto)
  operator()(Args&&... args) const
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ >= 9
    noexcept(noexcept(F::operator()(std::declval<FixPoint>(), std::declval<Args>()...)))
#endif  // !defined(__GNUC__) || defined(__clang__) || __GNUC__ >= 9
  {
    return F::operator()(*this, std::forward<Args>(args)...);
  }
};  // class FixPoint


#if defined(__cpp_deduction_guides)
template <typename F>
FixPoint(F&&)
  -> FixPoint<std::decay_t<F>>;
#endif  // defined(__cpp_deduction_guides)


namespace
{
template <typename F>
#if !defined(__has_cpp_attribute) || !__has_cpp_attribute(nodiscard)
#  if defined(__GNUC__) && (__GNUC__ > 3 || __GNUC__ == 3 && __GNUC_MINOR__ >= 4)
__attribute__((warn_unused_result))
#  elif defined(_MSC_VER) && _MSC_VER >= 1700 && defined(_Check_return_)
_Check_return_
#  endif  // defined(__GNUC__) && (__GNUC__ > 3 || __GNUC__ == 3 && __GNUC_MINOR__ >= 4)
#endif  // !defined(__has_cpp_attribute) || !__has_cpp_attribute(nodiscard)
inline constexpr decltype(auto)
makeFixPoint(F&& f) noexcept
{
  return FixPoint<std::decay_t<F>>{std::forward<std::decay_t<F>>(f)};
}
}  // namespace


int
main()
{
  auto result = makeFixPoint([](auto f, int n) -> int {
    return n < 2 ? n : (f(n - 1) + f(n - 2));
  })(10);
  std::cout << result << std::endl;
}

生成コード

利点

  1. 実行コストは小さい
  2. 無名再帰可能
  3. C++17以降ではコンパイル時計算可能なラムダによる再帰関数を記述可能

欠点

  1. C++14以降

手法としてはこれがベストである. 2つ目の手法異なり,再帰関数の引数にその再帰関数自身を与える必要もない.

類似する実装として,こういうものがあるが,これでは2つ目の手法と変わりない. ラムダの第一引数に FixPoint クラスのオブジェクトを与える(すなわち, *this を与える)ように, FixPoint::operator() を定義するのがポイントである.

ちなみに,実装をもっと簡略化し,コアとなる部分だけ抜き出すと以下のようになる.

#include <iostream>
#include <utility>


template <typename F>
class
FixPoint : private F
{
public:
  explicit constexpr FixPoint(F&& f) noexcept
    : F{std::forward<F>(f)}
  {}

  template <typename... Args>
  constexpr decltype(auto)
  operator()(Args&&... args) const
  {
    return F::operator()(*this, std::forward<Args>(args)...);
  }
};  // class FixPoint


namespace
{
template <typename F>
inline constexpr decltype(auto)
makeFixPoint(F&& f) noexcept
{
  return FixPoint<F>{std::forward<F>(f)};
}
}  // namespace


int
main()
{
  auto result = makeFixPoint([](auto f, int n) -> int {
    return n < 2 ? n : (f(n - 1) + f(n - 2));
  })(10);
  std::cout << result << std::endl;
}

ただし,この実装では

  • makeFixPointFixPoint のコンストラクタで左辺値は受け取れない
  • 戻り値を無視してもワーニングが出ない

という欠点がある. 左辺値が受け取れないため,例えば次のコードがエラーになる

int
main()
{
  auto body = [](auto f, int n) -> int {
    return n < 2 ? n : (f(n - 1) + f(n - 2));
  };
  auto result = makeFixPoint(body)(10);
  std::cout << result << std::endl;
}

ちゃんと std::decayかますことによって,ムーブすべきときにムーブし,コピーすべきときにコピーできるようにしておくべきである.

C++17以降では constexpr なラムダでコンパイル時計算可能であるが,C++14でも constexpr な関数オブジェクトを用意すれば一応コンパイル時計算は可能である. ラムダではないので,本記事の趣旨とは外れるが....

#include <iostream>
#include <utility>


template <typename F>
class
#if defined(__has_cpp_attribute) && __has_cpp_attribute(nodiscard)
[[nodiscard]]
#endif  // defined(__has_cpp_attribute) && __has_cpp_attribute(nodiscard)
FixPoint final : private F
{
public:
  template <typename G>
  explicit constexpr FixPoint(G&& g) noexcept
    : F{std::forward<G>(g)}
  {}

  template <typename... Args>
  constexpr decltype(auto)
  operator()(Args&&... args) const
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ >= 9
    noexcept(noexcept(F::operator()(std::declval<FixPoint>(), std::declval<Args>()...)))
#endif  // !defined(__GNUC__) || defined(__clang__) || __GNUC__ >= 9
  {
    return F::operator()(*this, std::forward<Args>(args)...);
  }
};  // class FixPoint


#if defined(__cpp_deduction_guides)
template <typename F>
FixPoint(F&&)
  -> FixPoint<std::decay_t<F>>;
#endif  // defined(__cpp_deduction_guides)


namespace
{
template <typename F>
#if !defined(__has_cpp_attribute) || !__has_cpp_attribute(nodiscard)
#  if defined(__GNUC__) && (__GNUC__ > 3 || __GNUC__ == 3 && __GNUC_MINOR__ >= 4)
__attribute__((warn_unused_result))
#  elif defined(_MSC_VER) && _MSC_VER >= 1700 && defined(_Check_return_)
_Check_return_
#  endif  // defined(__GNUC__) && (__GNUC__ > 3 || __GNUC__ == 3 && __GNUC_MINOR__ >= 4)
#endif  // !defined(__has_cpp_attribute) || !__has_cpp_attribute(nodiscard)
inline constexpr decltype(auto)
makeFixPoint(F&& f) noexcept
{
  return FixPoint<std::decay_t<F>>{std::forward<std::decay_t<F>>(f)};
}
}  // namespace


class Fib
{
public:
  template <typename F>
  constexpr int
  operator()(F&& f, int n) const noexcept
  {
    return n < 2 ? n : (f(n - 1) + f(n - 2));
  }
};  // class Fib


int
main()
{
  constexpr auto result = makeFixPoint(Fib{})(10);
  std::cout << result << std::endl;
}

生成コード

C++17以降ではクラステンプレートのテンプレートパラメータ推論が可能になったため,推論のためのmake関数も不要になる.

#include <iostream>
#include <utility>


template <typename F>
class
#if defined(__has_cpp_attribute) && __has_cpp_attribute(nodiscard)
[[nodiscard]]
#endif  // defined(__has_cpp_attribute) && __has_cpp_attribute(nodiscard)
FixPoint final : private F
{
public:
  template <typename G>
  explicit constexpr FixPoint(G&& g) noexcept
    : F{std::forward<G>(g)}
  {}

  template <typename... Args>
  constexpr decltype(auto)
  operator()(Args&&... args) const
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ >= 9
    noexcept(noexcept(F::operator()(std::declval<FixPoint>(), std::declval<Args>()...)))
#endif  // !defined(__GNUC__) || defined(__clang__) || __GNUC__ >= 9
  {
    return F::operator()(*this, std::forward<Args>(args)...);
  }
};  // class FixPoint


#if defined(__cpp_deduction_guides)
template <typename F>
FixPoint(F&&)
  -> FixPoint<std::decay_t<F>>;
#endif  // defined(__cpp_deduction_guides)


int
main()
{
  auto result = FixPoint{[](auto f, int n) -> int {
    return n < 2 ? n : (f(n - 1) + f(n - 2));
  }}(10);
  std::cout << result << std::endl;
}

おまけの話題

FixPoint の実装

3つ目の手法ではラムダを継承して FixPoint を実装したが,別にメンバーにラムダを持つ実装でもよい. 僕個人としては,ラムダを継承する方がスッキリしていると感じたのと,次の章のオーバーロード実装を考えると継承の方を紹介するのが自然と考えた.

実は先に紹介した継承を用いる方法はいなむ神に教えていただいた

#include <iostream>
#include <utility>


template <typename F>
class
#if defined(__has_cpp_attribute) && __has_cpp_attribute(nodiscard)
[[nodiscard]]
#endif  // defined(__has_cpp_attribute) && __has_cpp_attribute(nodiscard)
FixPoint
{
private:
  const F m_f;

public:
  template <typename G>
  explicit constexpr FixPoint(G&& g) noexcept
    : m_f{std::forward<G>(g)}
  {}

  template <typename... Args>
  constexpr decltype(auto)
  operator()(Args&&... args) const
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ >= 9
    noexcept(noexcept(m_f(std::declval<FixPoint>(), std::declval<Args>()...)))
#endif  // !defined(__GNUC__) || defined(__clang__) || __GNUC__ >= 9
  {
    return m_f(*this, std::forward<Args>(args)...);
  }
};  // class FixPoint


#if defined(__cpp_deduction_guides)
template <typename F>
FixPoint(F&&)
  -> FixPoint<std::decay_t<F>>;
#endif  // defined(__cpp_deduction_guides)


int
main()
{
  auto result = FixPoint{[](auto f, int n) -> int {
    return n < 2 ? n : (f(n - 1) + f(n - 2));
  }}(10);
  std::cout << result << std::endl;
}

再帰ラムダのオーバーロード

これもいなむ神に教えていただいたものである. C++17のusing宣言のパック展開を用いると,オーバーロードが可能となる.

#include <iostream>
#include <utility>


template <typename... Fs>
class
#if defined(__has_cpp_attribute) && __has_cpp_attribute(nodiscard)
[[nodiscard]]
#endif  // defined(__has_cpp_attribute) && __has_cpp_attribute(nodiscard)
FixPoint final : private Fs...
{
  using Fs::operator()...;

public:
  template <
    typename... Gs,
    std::enable_if_t<sizeof...(Fs) == sizeof...(Gs), std::nullptr_t> = nullptr
  >
  explicit constexpr FixPoint(Gs&&... gs) noexcept
    : Fs{std::forward<Gs>(gs)}...
  {}

  template <typename... Args>
  constexpr decltype(auto)
  operator()(Args&&... args) const
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ >= 9
    noexcept(noexcept(operator()(std::declval<FixPoint>(), std::declval<Args>()...)))
#endif  // !defined(__GNUC__) || defined(__clang__) || __GNUC__ >= 9
  {
    return operator()(*this, std::forward<Args>(args)...);
  }
};  // class FixPoint


#if defined(__cpp_deduction_guides)
template <
  typename... Fs,
  std::enable_if_t<sizeof...(Fs) != 0, std::nullptr_t> = nullptr
>
FixPoint(Fs&&...)
  -> FixPoint<std::decay_t<Fs>...>;
#endif  // defined(__cpp_deduction_guides)


int
main()
{
  auto fib = FixPoint{
    [](auto f, int n) -> int {
      return n < 2 ? n : (f(n - 1) + f(n - 2));
    },
    [](auto f, double n) -> double {
      return n < 2 ? n : (f(n - 1) + f(n - 2));
    }};
  std::cout << fib(10) << std::endl;
  std::cout << fib(10.0) << std::endl;
}

上記の例では double 型のオーバーロードを用意しているが,単にオーバーロード可能ということを示すためのものにすぎない.

なお,Deducation Guides,およびコンストラクタにおいてSFINAEしているのは,g++の場合,これが無ければコンパイルエラーになるためである. clangおよびMSVCでは無くてもコンパイルが通るのだが....

Boost.Hanaによる実装(追記(2020/05/23))

実はBoostにも FixPoint クラスと同様のものの実装がある.

#include <iostream>
#include <boost/hana/functional/fix.hpp>


int
main()
{
  auto result = boost::hana::fix([](auto f, int n) -> int {
    return n < 2 ? n : (f(n - 1) + f(n - 2));
  })(10);
  std::cout << result << std::endl;
}

生成コード

実装としては以下の形となっている.

#include <iostream>
#include <type_traits>
#include <utility>


namespace detail
{

template <template <typename ...> class T>
struct create
{
  template <typename ...X>
  constexpr T<typename std::decay<X>::type...>
  operator()(X&& ...x) const {
    return T<typename std::decay<X>::type...>{
      static_cast<X&&>(x)...
    };
  }
};  // struct create

}  // namespace detail


template <typename F>
struct fix_t;

constexpr detail::create<fix_t> fix{};

template <typename F>
struct fix_t
{
  F f;

  template <typename ...X>
  constexpr decltype(auto) operator()(X&& ...x) const&
  { return f(fix(f), static_cast<X&&>(x)...); }

  template <typename ...X>
  constexpr decltype(auto) operator()(X&& ...x) &
  { return f(fix(f), static_cast<X&&>(x)...); }

  template <typename ...X>
  constexpr decltype(auto) operator()(X&& ...x) &&
  { return std::move(f)(fix(f), static_cast<X&&>(x)...); }
};


int
main()
{
  auto result = fix([](auto f, int n) -> int {
    return n < 2 ? n : (f(n - 1) + f(n - 2));
  })(10);
  std::cout << result << std::endl;
}

生成コード

追記 (2018/06/10)

いなむ神にラムダのみで再帰する方法を提示していただいた. ラムダのみでYコンビネータとZコンビネータを実現しており,非常に面白い. この記事で散々取り扱っているフィボナッチ数列の関数を,上記記事のYコンビネータとZコンビネータを用いて実装すると,それぞれ以下のようになる.

ラムダオンリーのYコンビネータによるフィボナッチ数列の関数の実装

#include <iostream>
#include <utility>


int
main()
{
  auto result = [g=[](auto f, int n) -> int {
    return n < 2 ? n : (f(f, n - 1) + f(f, n - 2));
  }](auto&&... args) {
    return g(g, std::forward<decltype(args)>(args)...);
  }(10);
  std::cout << result << std::endl;
}

生成コード

なお,MSVCでコンパイルしたとき,上記の実装だとMSVCがポンコツなためICEとなってコンパイルが通らないので,以下の実装を用いる方が安全かもしれない. MSVCでは初期化キャプチャでラムダをキャプチャすると,ICEとなるようだ.

#include <iostream>
#include <utility>


int
main()
{
  auto ret = [](auto f) {
    return [=](auto&&... args) {
      return f(f, std::forward<decltype(args)>(args)...);
    };
  }([](auto f, int n) -> int {
    return n < 2 ? n : f(f, n - 1) + f(f, n - 2);
  })(10);
  std::cout << ret << std::endl;
}

f:id:koturn:20180621043652p:plain

ラムダオンリーのZコンビネータによるフィボナッチ数列の関数の実装

#include <iostream>
#include <utility>


int
main()
{
  auto result = [](auto f) {
    return [=](auto g) {
      return [=](auto&&... args) {
        return f(g(g), std::forward<decltype(args)>(args)...);
      };
    }([=](auto g) {
      return [=](auto&&... args) {
        return f(g(g), std::forward<decltype(args)>(args)...);
      };
    });
  }([](auto f, int n) -> int {
    return n < 2 ? n : (f(n - 1) + f(n - 2));
  })(10);
  std::cout << result << std::endl;
}

生成コード

C++に不慣れな人にとっては,上記記事のZコンビネータほど複雑であれば実行効率が気になるところかもしれない(ラムダを返すラムダが云々で何となく遅そうなイメージ)が,生成コードを見れば要らぬ心配であることがわかると思う. そして,上記のYコンビネータ,Zコンビネータから生成されるコードは記事の本編で紹介した関数やクラスを用いる方法と同一のコードが生成されていることもわかる. C++はよしなにラムダのインライン展開を行ってくれるのだ!

僕はいなむ神のYコンビネータとZコンビネータの実装に感銘を受け,VimプラグインであるShougo/neosnippet.vimスニペットファイルに下記のスニペットを追加した.

snippet ycombinator
alias ycomb
  [](auto f) {
    return [=](auto&&... args) {
      return f(f, std::forward<decltype(args)>(args)...);
    };
  }([${1:&}](auto ${2:f}, ${3:#:args...}) {
    ${0}
  })

snippet zcombinator
alias zcomb
  [](auto f) {
    return [=](auto g) {
      return [=](auto&&... args) {
        return f(g(g), std::forward<decltype(args)>(args)...);
      };
    }([=](auto g) {
      return [=](auto&&... args) {
        return f(g(g), std::forward<decltype(args)>(args)...);
      };
    });
  }([${1:&}](auto ${2:f}, ${3:#:args...}) {
    ${0};
  })

これでいつでも再帰するラムダを楽に書ける. なお,ラムダを受ける引数はユニヴァーサル参照等の参照系で受けると単にautoで受けるよりも数命令多く,消費スタック量もわずかに多いので,いなむ神の記事の通り,autoで受けたり,コピーキャプチャを行う方が良い.

※Zコンビネータに関して:g++とMSVCではコンパイルは通るが,clang++ではコンパイルが通らないようなので気を付けてほしい.

速度比較

実際に実行時間を計測してみることにしよう. 再帰に関して,C++では様々な方法で実現することが可能であるが,この記事では,以下の4つのような関数オブジェクトによる再帰手法も紹介している. これらも含めて,実行時間を計測することにする.

class Fibonacci01
{
public:
  constexpr int
  operator()(int n) const noexcept
  {
    return n < 2 ? n : (Fibonacci01{}(n - 1) + Fibonacci01{}(n - 2));
  }
};  // struct Fibonacci01


class Fibonacci02
{
public:
  constexpr int
  operator()(int n) const noexcept
  {
    return n < 2 ? n : ((*this)(n - 1) + (*this)(n - 2));
  }
};  // struct Fibonacci02


class Fibonacci03
{
public:
  constexpr int
  operator()(int n) const noexcept
  {
    return n < 2 ? n : (operator()(n - 1) + operator()(n - 2));
  }
};  // struct Fibonacci03


class Fibonacci04
{
public:
  constexpr int
  operator()(Fibonacci04 f, int n) const noexcept
  {
    return n < 2 ? n : (f(f, n - 1) + f(f, n - 2));
  }
};  // struct Fibonacci04

この内 Fibonacci01 は本記事で紹介した FixPoint クラスによる手法と生成コードは同一であるため除外する. また Fibonacci02Fibonacci03 の生成コードは同一であり,比較する意味はないため, Fibonacci02 は除外する. 他にも,本記事において同一のコードが生成されたもの同士は意味が無いので,ユニークになるように除外して計測する.

計測に際して,以下のコードを用意した.

記事中のコードではフィボナッチ数の型を一律 int としていたが,計測コードでは std::uint64_t とした. x64環境であれば,64bitレジスタに乗せるだけなので,32bitから64bitに変更しても大きな影響はない. Wandboxは実行時間制限があるため,fib(42) をそれぞれ1回計算するコードにしてあるが,手元では fib(45) をそれぞれ5回計算するのにかかった平均時間を計測した.

使用したコンパイラは g++ 7.3.0 であり,コンパイルオプションは下記の通りである.

$ g++ -O2 -march=native main.cpp -o main.out
$ ./main.out

計測結果は以下の表に示す通りである.

手法 実行時間
普通の関数 2832.08 ms
std::function<> 7481.06 ms
fix() 省略
FixPoint クラス 2475.44 ms
FixPoint クラス (参照受け) 2715.21 ms
Yコンビネータ 省略
Zコンビネータ 省略
Fibonacci01 省略
Fibonacci02 省略
Fibonacci03 2903.60 ms
Fibonacci04 省略

つまり,

FixPoint クラス(追記のZコンビネータも含む) < FixPoint クラス(本体のラムダにおいて参照で受ける) < 普通の関数 < Fibonacci03 < std::function

という結果になった.

表の結果からも分かるように std::function は特に遅い. 普通の関数で実装したよりも3倍の実行時間を必要とする. そして,本記事で紹介したようなラムダで再帰を行う手法は,通常の関数で再帰を行う手法よりも高速である.

ただし,ラムダや関数オブジェクトを引数で受けるときに参照で受けた場合,わずかに遅くなる. 実際,参照で受ける場合のコードを確認してみると,余分なlea命令が生成されていた. 関数オブジェクトによる再帰の記事にもあるように,ラムダや関数オブジェクトはコピーで受け渡す方,あるいは都度生成する方がよいのだろう.

ただし,これはフィボナッチ数列の計算という外部の変数のキャプチャを必要としない例だからであり,値渡しの引数で受け取ると通常はラムダ式のコピーが発生する. これが問題になるのは,キャプチャを行っているときである. ラムダ式のキャプチャは,関数オブジェクトのメンバと同等なものであるので,ラムダ式のコピーが1回行われると,そのラムダ式がキャプチャした変数全てのコピーも行われる. intdouble といったプリミティブな型の変数1つのコピーキャプチャや,任意の型の変数の参照キャプチャ1つであれば,コピーコストとしてはほぼ問題が無いと思われるが,それ以外の場合(std::vector 等のコピーに際してメモリアロケーションが関係するオブジェクトのコピーキャプチャや参照キャプチャであっても多数のキャプチャを行っている場合)は問題になる. 基本的には auto&&const auto& 等を使って,参照で受ける方がよいと思う.

追記 (2019/03/18,2020/05/23)

速度比較がg++だけであったので,MSVCおよびclang++でも速度比較を行った. また,MSVCやclang++では再帰関数の呼出自体が省略されてしまうことがあったので,前述の計測コードに少し手を加え,以下のようにした.

計測コードはkoturn/CppRecursiveLambdaにも同一のものを置いている. 手元で試してみたい方は上記のリポジトリを利用するとよいだろう.

前と同じく fib(45) をそれぞれ5回計算するのにかかった平均時間を計測し,結果は以下の表のようになった. g++では同一コードが生成されていたものもコンパイラが違えば別のコードを生成していたので,全ての再帰実装の結果を示す.

手法 g++ clang++ MSVC
普通の関数 3042.32 ms 5011.04 ms 4427.97 ms
std::function<> 5965.21 ms 5869.21 ms 6658.59 ms
fix() 2474.56 ms 4900.02 ms 5001.03 ms
FixPoint クラス 2471.09 ms 4899.59 ms 4970.68 ms
FixPoint クラス (参照受け) 2661.03 ms 4785.61 ms 4923.05 ms
Boost.Hanaの fix 実装 2469.58 ms 4902.19 ms 4737.15 ms
Boost.Hanaの fix 実装(参照受け) 2711.37 ms 4790.04 ms 4960.24 ms
Yコンビネータ 2387.47 ms 5271.59 ms 4997.06 ms
Zコンビネータ 2468.44 ms ※1 15567.03 ms
Fibonacci01 2382.15 ms 4461.81 ms 4230.46 ms
Fibonacci02 3496.07 ms 4784.06 ms 4511.31 ms
Fibonacci03 3491.11 ms 4784.08 ms 4510.35 ms
Fibonacci04 2376.32 ms 4669.02 ms 4983.88 ms

※1:コンパイルできないため,記録無し

g++とclang++はおおむね同じような結果を示しているが,MSVCは全く異なった結果を示している. 特にZコンビネータのパフォーマンスが std::function<> よりも悪いことに驚かされる. MSVCの場合,fix() や Yコンビネータが最も高速であり,次に普通の関数および Fibonacci01が続き,次に FixPoint クラスが続くようだ. このことから,MSVCはあまりラムダのインライン展開等が得意ではないのではないかと思われる.

std::function による再帰がどのコンパイラでも遅いのは予想通りであった.

また,clang++がg++の2倍に近い実行時間となっているのは,単に1回の再帰につき2回フィボナッチ関数を呼び出す部分の最適化ができていないためであるが,clang++の結果のみでいえばどの再帰スタイルでも同じ最適化になっているので,今回の再帰の書き方によるパフォーマンスの違いに関する結果とは別のものである.

MSVCを無視するなら,個人的には一番使い勝手のよい FixPoint クラスを推したいと思う.

まとめ

  • C++11でない限り,ラムダの再帰std::function を用いるべきでない(約3倍の実行時間になる)
  • ジェネリックラムダと operator() を用いた実装を使うべき

いくつかの実装を紹介したが,C++14のジェネリックラムダによる再帰はいずれも同一のコード生成がなされ,また無駄な処理がなく,普通の関数による再帰よりも高速に動作することがわかった.

競プロの人向け

AtCoderの提出コードで,このブログに書いているような FixPoint の実装を見かけることがあった. 実際は記事中ほど丁寧に実装する必要はなく,競プロのテンプレートとしては以下のコードでよい.

ヘルパ関数の makeFixPoint という名前は長すぎるので,fix にする.

競プロのコードごときに,わざわざ丁寧に単一引数のコンストラクタに explicit を付けて暗黙の変換を避ける必要はない. これにより,fix 関数に暗黙の型変換を利用することができるようになり,文字数を多少削ることが可能になる.

FixPoint クラスでは public: を書くのが面倒なので,struct にして省略するとよい. また,noexcept が無いからといって,単純なコードでは最適化に影響はないはずだし,constexpr が無ければコンパイル時計算はできなくなるが,競プロで再帰関数のコンパイル時計算が必要な場面も無いと思われる.

#include <utility>


template <typename F>
struct FixPoint : F
{
  template <typename G>
  FixPoint(G&& g)
    : F{std::forward<G>(g)}
  {}

  template <typename... Args>
  decltype(auto)
  operator()(Args&&... args) const
  {
    return F::operator()(*this, std::forward<Args>(args)...);
  }
};


#if defined(__cpp_deduction_guides)
template <typename F>
FixPoint(F&&)
  -> FixPoint<std::decay_t<F>>;
#endif  // defined(__cpp_deduction_guides)


// C++17以降でクラステンンプレートの型推論を使うなら以下は不要
template <typename F>
inline FixPoint<std::decay_t<F>>
fix(F&& f)
{
  return std::forward<std::decay_t<F>>(f);
}


#include <iostream>


int
main()
{
  auto result = fix([](auto&& f, int n) -> int {
    return n < 2 ? n : (f(n - 1) + f(n - 2));
  })(10);
  std::cout << result << std::endl;
}

using namespace std 派の人は適宜 std:: を除去してもよい. C++14の範囲でしか書かない,すなわちC++17のクラステンプレートのテンプレート引数推論を使わない場合,推論補助やコンストラクタにテンプレートは必要ないので,以下のようにするとよい.

#include <utility>


template <typename F>
struct FixPoint : F
{
  FixPoint(F&& f)
    : F{std::forward<F>(f)}
  {}

  template <typename... Args>
  decltype(auto)
  operator()(Args&&... args) const
  {
    return F::operator()(*this, std::forward<Args>(args)...);
  }
};


template <typename F>
inline FixPoint<std::decay_t<F>>
fix(F&& f)
{
  return std::forward<std::decay_t<F>>(f);
}


#include <iostream>


int
main()
{
  auto result = fix([](auto&& f, int n) -> int {
    return n < 2 ? n : (f(n - 1) + f(n - 2));
  })(10);
  std::cout << result << std::endl;
}

記事中でも述べているようにBoost.Hanaに FixPoint と同等の実装があるので,そもそもAtCoderのようなBoostが使える環境では FixPoint クラスの実装すら必要ない. (ローカルでBoostを用意するのが面倒な人もいるかもしれないが)

#include <boost/hana/functional/fix.hpp>
using boost::hana::fix;


#include <iostream>


int
main()
{
  auto result = fix([](auto&& f, int n) -> int {
    return n < 2 ? n : (f(n - 1) + f(n - 2));
  })(10);
  std::cout << result << std::endl;
}

参考文献