koturnの日記

普通の人です.ブログ上のコードはコピペ自由です.

CodeIQのスクエア・カルテット問題を解いた

はじめに

CodeIQで@riverplus氏による「 スクエア・カルテット」問題 という,初等整数論を絡めた面白いプログラミングの問題があったので,それについての記事を書いた. 高校生の数学を思い出す良い問題であり,楽しかった.

問題

2つの自然数の組 $(a, b)$ が与えられたとき,自然数 $x, y$ に関する次の方程式を考えます.

\begin{equation} x^2 + a^2 = y^2 + b^2 \label{eq:given-equation} \end{equation}

例えば, $(a, b) = (3, 10)$ のとき,方程式(\ref{eq:given-equation})の解は $(x, y) = (10, 3), (46, 45)$ の2組です.

自然数の組 $(a, b)$ に対し,方程式(\ref{eq:given-equation})の全ての解の $x + y$ の和を $F(a, b)$ と定義します. 例えば $F(3, 10) = 10 + 3 + 46 + 45 = 104$ です. 同様に, $F(10, 50) = 3500$ , $F(20, 100) = 15022$ となることが確かめられます.

標準入力から,半角空白区切りで 2つの自然数 $a, b$ ( $1 \leq a < b \leq 10^5$ )が与えられます. 標準出力に $F(a, b)$ の値を出力するプログラムを書いてください.

考え方

ここでは,$0 \notin \mathbb{N}$ とする.

与式(\ref{eq:given-equation})を変形して,

\begin{equation} (x + y)(x - y) = b^2 - a^2 = n \label{eq:converted} \end{equation}

とおく($a, b \in \mathbb{N}$ かつ $a > b$ より $n \in \mathbb{N}$). よって, $x + y \in \mathbb{N}$ , $x - y \in \mathbb{N}$ . 式(\ref{eq:converted})より,$x$, $y$ に関する連立方程式

\begin{eqnarray} \begin{cases} x + y = p & \\ x - y = q & \end{cases} \label{eq:xypq} \end{eqnarray}

を得る($p, q \in \mathbb{N}$ かつ $n = pq$). $i$ 番目の解 $(x_i, y_i)$ に $p_i, q_i$ が対応すると考えると,出力すべき値は,

\begin{equation} F(a, b) = \sum_i (x_i + y_i) = \sum_i p_i = \sum_i \dfrac{n}{q_i} \end{equation}

である. 連立方程式(\ref{eq:xypq})を解くと,

\begin{equation} (x, y) = \left( \dfrac{p + q}{2}, \dfrac{p - q}{2} \right) \end{equation}

を得る. $x \in \mathbb{N}$ なので,

\begin{equation} (p + q) \bmod 2 = 0 \label{eq:pq-constrain} \end{equation}

また,$x, y \in \mathbb{N}$ なので,

\begin{equation} x - y = q < p = x + y \end{equation}

$p, q$ は $n$ を2つの自然数積に分解したもの,すなわち,$n$ の約数のペアであることを踏まえると,

\begin{equation} (1 \leq) \:\: q < \sqrt{n} < p \:\: (\leq n) \end{equation}

以上より, $\sqrt{n}$ より小さく(「以下」ではない),かつ条件(\ref{eq:pq-constrain})を満たす $n$ の約数 $q_i$ を全て見つけ出し,対になる約数 $p_i = \dfrac{n}{q_i}$ の総和 $\sum_i p_i$ を計算して,出力すると良い.

おまけ

$a < b$ という制約が無く,

  1. $a = b$ , すなわち $n = 0$
  2. $a > b$ , すなわち $n < 0$

である場合も,思考実験として考えてみる.

$n = 0$ の場合, $(x, y)$ は $x = y$ なる任意の自然数

$n < 0$ の場合,$x + y > 0$ なので,

\begin{equation} x - y = q < 0 \end{equation}

となり, $q$ がマイナス符号を担当する($p \in \mathbb{N}, q \in \mathbb{Z}$). 連立方程式(\ref{eq:converted})を

\begin{eqnarray} \begin{cases} y + x = p & \\ y - x = -q & \end{cases} \end{eqnarray}

と変形し, $y \rightarrow x'$ , $x \rightarrow y'$ , $p \rightarrow p'$ , $-q \rightarrow q' (> 0)$ と置き直すことで, $n > 0$ のときと同様に処理できる.

\begin{eqnarray} \begin{cases} x' + y' = p' & \\ x' - y' = q' & \end{cases} \end{eqnarray}

$-n (> 0)$ (マイナス符号が無い場合)と比較して, $(x, y)$ の組み合わせは逆転しているが,出力すべき値はその和

\begin{equation} F(a, b) = \sum_i (y'_i + x'_i) = \sum_i p'_i = \sum_i \dfrac{-n}{q'_i} \end{equation}

なので, $n$ と $-n$ のときの $F(a, b)$ は等しい. すなわち, $F(a, b) = F(b, a)$ .

まとめ

  1. $n = b^2 - a^2$ ( $n = | b^2 - a^2 |$ ) とし, $\sqrt{n}$ より小さく,$\left( q + \dfrac{n}{q} \right) \bmod 2 = 0$ となる $n$ の約数 $q_i$ を全て求める
  2. $\sum_i \dfrac{n}{q_i}$ を出力する

本番入力値と考察

以下の6ケースが本番での入力値であった.

10 26
11 389
123 456
35672 61243
71200 82321
19126 98765

1つ目のケース $(a, b) = (10, 26)$ は, $b^2 - a^2 = 576 = 24^2$ となり,コーナーケースであった($p = q = 24$ ,すなわち $(x, y) = (24, 0)$ を含めてしまうのは誤り). また,問題文の例にあった $(a, b) = (10, 50)$ は, $b^2 - a^2 = 576 = 2400$ となり,$2400$ は $\lfloor \sqrt{2400} \rfloor = 48$ を約数に持つので,(ある意味,前述のものと対になる)コーナーケースであった.

この2つのケースを考えると, $n = b^2 - a^2$ の約数を単純に $1, 2, \ldots, \lfloor \sqrt{n} \rfloor - 1$ から見つけ出すのは,本番ケースに限るならばうまくいくが誤りである. 解決策としては,$n$ が2乗数であるかどうかを判定し,範囲を調製しなければならないが,コードで書くと汚くなる上に面倒である. そこで,$1, 2, \ldots, \lfloor \sqrt{n - \epsilon} \rfloor$ ($\epsilon$ は十分に小さな正の実数)から約数を探索するようにすると単純に処理できるはずだ.

解答例

bashJavaC++で解答での解答例を紹介する. $ \epsilon = 1.0 \times 10^{-10}$ とした.

ちなみに,入力は各ケースにつき1行のみだったので,それぞれの解答例のように,whileでEOFまで読み込みを行う必要はない.

bashでの解答

あえて,bashで解くというのも面白い.

#!/bin/bash -eu

declare -i a b
while read a b; do
  declare -i n=$((b ** 2 - a ** 2))
  declare -i qMax=`echo "sqrt($n - 0.00000000001)" | bc | sed 's/\.[0-9]*$//g'`
  declare -i q answer=0
  for q in `seq 1 $qMax`; do
    (( n % q == 0 )) && (( (q + n / q) % 2 == 0 )) && (( answer += n / q ))
  done
  echo $answer
done

bashだと,下手に組むとTLEになるので,やや難易度は高かった(1秒の壁は大きい). expr コマンドは時間がかかるので,基本的にbashの算術式で計算し,平方根などの算術式では計算できないものは bc コマンドに投げて計算するだけだ.

なお,以下のようなbashの算術式のfor文

for ((i = 0; i < 100; i++)) {
  # 処理
}

だと時間がかかるので, seq コマンドで $1$ から $\lfloor \sqrt{n - \epsilon} \rfloor$ までの連続する整数のリストを生成し,通常のシェルのforを用いるとよい. また,ifのパースは時間がかかると予想できるので,短絡評価を利用し,bashの算術式を繋げるとよいだろう.

Javaでの解答

まともな言語,例えばJavaでは以下のように率直に書けばよいだろう.

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.Arrays;
import java.util.stream.LongStream;

public class Main {
    public static void main(String[] args) throws Exception {
        try (BufferedReader br = new BufferedReader(new InputStreamReader(System.in))) {
            br.lines()
                .map(line -> Arrays.stream(line.split(" "))
                        .mapToLong(Long::parseLong)
                        .map(x -> x * x)
                        .toArray())
                .map(inputs -> inputs[1] - inputs[0])
                .map(n -> LongStream.rangeClosed(1, (long) Math.sqrt(n - 1e-10))
                        .filter(q -> n % q == 0 && ((q + n / q) & 0x01) == 0)
                        .map(q -> n / q)
                        .sum())
                .forEach(System.out::println);
        }
    }
}

Javaでは2の累乗の定数の乗算,除算,剰余はビット演算に置き換えた方が速い. C/C++ならば,2の累乗の定数の乗算,除算,剰余は最適化オプションを付与しなくてもビット演算に置き換えられるが,JavaコンパイラJITコンパイルのためにバイトコードを最適化しないようになっているため,2の累乗の定数の乗算,除算,剰余は,コンパイル後のバイトコードにおいても乗算,除算,剰余のままである. これは生成されるバイトコードを見れば一目瞭然である. しかし,JITコンパイルにより最適化されるかといえばそうではなく,手動でビット演算に置き換えた方が高速に動作する.

なお, $b^2 < 10^{10}$ なので,64bit整数型を用いないとオーバーフローする. この点は気をつけないといけない.

C++での解答

みんな大好きC++で書くと,以下のようになる. for文で逐次的に処理するので,連続する整数の配列やリストを生成する必要がなく, $\lfloor \sqrt{n - \epsilon} \rfloor$ の議論が不要になる. 大半の人はこういった解答をしていると思う. 試していないが,ループ毎に $q^2$ を計算するので, $\lfloor \sqrt{n - \epsilon} \rfloor$ を計算しておく場合と比較すると遅い気がするのだが,実際はどうなのだろうか?

#include <cstdlib>
#include <iostream>

typedef long long  llint;


int
main()
{
  std::cin.tie(0);
  std::ios::sync_with_stdio(false);

  llint a, b;
  while (std::cin >> a >> b) {
    llint n = b * b - a * a;
    llint answer = 0;
    for (llint q = 1; q * q < n; q++) {
      if (n % q != 0) continue;
      llint p = n / q;
      if ((p + q) % 2 == 1) continue;
      answer += p;
    }
    std::cout << answer << std::endl;
  }

  return EXIT_SUCCESS;
}