yutasの競技プログラミング勉強帖

競技プログラミングの問題についての解説記事を主に書いています。

F - Hop Sugoroku / AtCoder Beginner Contest 335(Sponsored by Mynavi)

問題

一列に並んだ  n 個のマス  1, \, 2, \, \cdots , \, n と長さ  n の数列  a = ({a}_{1}, \, {a}_{2}, \, \cdots , \, {a}_{n}) がある。

最初、マス  1 は黒く、他の  n - 1 個のマスは白く塗られており、1つのコマがマス  1 に置かれている。

以下の操作を  0 回以上好きな回数繰り返す。

  • コマがマス  i にあるとき、ある正整数  x を決めて、コマをマス  i + {a}_{i} \times x に移動させる。
    • ただし、  i + {a}_{i} \times x \gt n となるような移動はできない。
  • その後、マス  i + {a}_{i} \times x を黒く塗る。

操作を終えた時点で黒く塗られたマスの集合として考えられるものの数を  998244353 で割った余りを求めよ。

入力

まず最初の1行目に、整数  n が与えられる。

次の1行には、整数  {a}_{1}, \, {a}_{2}, \, \cdots , \, {a}_{n} が順番に与えられる。

条件

  • 実行時間制限: 2.5s
  • メモリ制限: 1024MB
  •  1 \leq n \leq 2 \times {10}^{5}
  •  1 \leq {a}_{i} \leq 2 \times {10}^{5}

出力

答えを1行で出力すること。

解法

 \mathrm{ans}_{i} を、「マス  i が黒く塗られており、それ以降のマスはすべて白く塗られている」ものの総数として定義します。

マス  i から移動できるマスとしては  i + {a}_{i}, \, i + 2 {a}_{i}, \, \cdots であることから、 \mathrm{ans}_{i} を用いて、 \begin{gather} \mathrm{ans}_{i + {a}_{i}} \leftarrow \, \mathrm{ans}_{i + {a}_{i}} + \mathrm{ans}_{i} \\ \mathrm{ans}_{i + 2 {a}_{i} } \leftarrow \, \mathrm{ans}_{i + 2 {a}_{i}} + \mathrm{ans}_{i} \\ \vdots \end{gather} と更新することができます。

この方法の場合、ある  i に対して最大で  k 個分の値の更新が行われると仮定すると、 \begin{align} i + k {a}_{i} & \leq n \\ \therefore \; k & \leq \frac{n - i}{ {a}_{i} } \end{align} であることから、 {a}_{i} の値が大きければ、実行時間制限に間に合うような更新となります。 一方で、例えば数列  {a} のすべての要素が  1 であったとすると、全体で  O( {n}^{2}) の計算量が必要になり、実行時間制限には間に合いません。

ここで、マス  i によってマス  j が黒く塗られる条件について考えます。(ただし、  i \lt j とします。) このとき、ある整数  x を用いて \begin{align} i + x {a}_{i} = j \end{align} が成立することから、 j - i の値は  {a}_{i} の倍数であると言えます。

従って、  i \equiv j \; (\mathrm{mod} \, {a}_{i}) が成立することから、マス  i によってマス  j が黒く塗られる条件は「  j {a}_{i} で割ったときの余りと、  i {a}_{i} で割ったときの余りが等しい」ということになります。

ここで、  \mathrm{dp}_{i, \, j} を、「  i で割ったときの余りが  j となるすべてのマス  x の場合の数の総数」が記録されているものとして定義します。

この  \mathrm{dp} \mathrm{ans}_{1}, \, \cdots , \, \mathrm{ans}_{i - 1} の分まで更新されているとします。 このとき、マス  i での  \mathrm{ans}_{i} の値は、すべての  j に対して、  \mathrm{ans}_{i} \leftarrow \mathrm{ans}_{i} + \mathrm{dp}_{j, \, (i \,  \% \, j)} と更新していくことによって、  \mathrm{ans}_{i} が得られていくと言えます。

これは、  \mathrm{dp}_{i, \, j} i の値が小さければ、実行時間制限に間に合うような更新となります。  \mathrm{dp}_{i, \, j} i の値は、数列  a の要素の値によって定まるため、数列  a の要素がすべて小さいときに、この方法は実行時間制限に間に合うと言えます。 一方で、  a の要素の中に1つでも最大値の  {10}^{5} のような値が入っている場合、全体の計算量としては  O({10}^{5} n) 程度になってしまい、この方法は実行時間制限に間に合いません。

ここで、以上の2つの方法を順に方法1, 方法2とすると、

  • 方法1:  {a}_{i} の値が大きいときには使用可能
  • 方法2:  {a}_{i} の値が小さいときには使用可能

ということがそれぞれ言えます。

従って、ある値  b に対して、  {a}_{i} \geq b のときには方法1を、  {a}_{i} \leq b のときには方法2を使用するということを考えます。

このとき、すべて方法1を使用するときの全体の計算量は  \displaystyle O \left( \frac{{n}^{2}}{b} \right) となり、すべて方法2を使用するときの全体の計算量は  O(bn) となります。 よって、この方法1, 2を切り替えることによる方法での全体の計算量は  \displaystyle O \left( \frac{{n}^{2}}{b} + bn \right) となります。

この値について、相加平均と相乗平均の大小関係より、 \begin{align} \frac{ {n}^{2} }{b} + bn \geq 2 \sqrt{ \frac{ {n}^{2} }{b} \cdot bn } = 2 n \sqrt{n} \end{align} ということが言えます。 なお、等号が成立するのは \begin{gather} \frac{ {n}^{2} }{b} = bn \\ \therefore \; b = \sqrt{n} \end{gather} のときとなります。

以上から、  b = \sqrt{n} と定めて、方法1, 2を切り替える方法によって、全体で  O(2 n \sqrt{n}) 程度の計算量になると言え、この方法によって実行時間制限に間に合いながら  \mathrm{ans}_{1}, \, \cdots , \,  \mathrm{ans}_{n} の値を導出できるということが言えます。

 \mathrm{ans}_{i} は「マス  i が黒く塗られており、それ以降のマスはすべて白く塗られている場合の数」を表しているので、最終的に出力する値は、 \begin{align} \mathrm{ans}_{1} + \mathrm{ans}_{2} + \cdots + \mathrm{ans}_{n} \end{align} を  998244353 で割ったときの余りである、ということが言えます。 よって、この値を出力することで、この問題は解くことができます。

ソースコード

main() 関数の中に、答えを出力する部分を直接実装しました。

const ll MOD2 = 998244353;

int n;
ll a[int(2e5 + 5)];
ll ans[int(2e5 + 5)], dp[1005][1005];

int main() {
  // 値の入力
  cin >> n;
  for (int i = 1; i <= n; i++) {
    cin >> a[i];
  }

  int b = sqrt(n); // 境界となる値 b を std::sqrt を用いて定める
  ans[1] = 1; // 初期条件として、 ans[1] = 1 を追加しておく

  for (int i = 1; i <= n; i++) {
    // ans[i] の値の更新
    for (int j = 1; j <= b; j++) {
      ans[i] += dp[j][i % j]; //  j = 1, ..., b に対して、「 i を j で割ったときの余り」の部分から ans[i] を更新する
      ans[i] %= MOD2; // 998244353 で割ったときの余りを求めるため、この値で mod を取る
    }

    // dp や i 以降の ans の値の更新
    if (a[i] <= b) {
      // a[i] が小さい値のときは、方法2を使う
      dp[a[i]][i % a[i]] += ans[i]; // 「i を a[i] で割ったときの余り」の部分を更新する
      dp[a[i]][i % a[i]] %= MOD2;
    } else {
      // a[i] が大きい値のときは、方法1を使う
      for (int j = 1; i + a[i] * j <= n; j++) {
        ans[i + a[i] * j] += ans[i];
        ans[i + a[i] * j] %= MOD2;
      }
    }
  }

  ll sum = 0; // 最終的に出力する総和
  for (int i = 1; i <= n; i++) {
    sum += ans[i]; // ans[i] を sum に加えていく
    sum %= MOD2;
  }
  cout << sum << endl; // 値の出力

  return 0;
}

感想

計算量を削減する方法が自力では思いつかず、公式解説 *1 頼りになってしまいました。

ちなみに、記事中の  b の値について、すべてのテストケースに対して  b = 1000 として提出した場合 *2 でもACが得られました。 ただ、この場合の実行時間は754msであり、  b = \sqrt{n} とした場合は377msであったので、当てずっぽうで値を定めてしまうと実行時間制限オーバーとなってしまう可能性が出てきてしまいそうですね。