cplusplusonly's memo

Atcoder: https://atcoder.jp/users/cplusplusonly

剰余上での割り算は累乗のほうが GCD より 1000 倍ぐらい早い

剰余上での割り算

競技プログラミングをやっていると、ときどき剰余上での割り算を行う必要が出てきます。
例えば エクサウィザーズ 2019 E - Black or White では xz=y (mod (10^9+7)) となる z を出力する必要があります。

このとき zmod (10^9+7) 上で y/x を実行した値となり、この値は  inv(x)mod (10^9+7) 上での  x の逆元とすると y*inv(x) (mod (10^9+7)) です。つまり剰余上での割り算を行うには逆元をかける必要があるので、逆元を高速に計算する方法が求められます。

剰余上での逆元の計算手法

剰余上での逆元を求めるには以下の 2 種類の方法があります。それぞれ制限がありますが、競技プログラミングでは素数での剰余以外を考える必要がないはずなので、どちらの方法でも正しく逆元が求まります。

最小公倍数 (GCD) で求める

 (a, b) について GCD を計算すると  r = ai + bj となる  (r, i, j) が得られます。 (p, x) が互いに素であるとき、  x p の剰余上での逆元を求めるには、GCD を計算して、 r=1 から  1 = pi + xj となり、 pi (mod(p))=0 なので  xj(mod(p))=1 j が求める逆元となります。

コードは以下の通りです。

tuple<int64_t, int64_t, int64_t> gcd(int64_t x, int64_t y)
{
	if (y == 0) return make_tuple(x, 1, 0);

	tuple<int64_t, int64_t, int64_t> ret = gcd(y, x % y);

	return make_tuple(get<0>(ret), get<2>(ret), get<1>(ret) - (x / y)*get<2>(ret));
}

int64_t gcd_remdiv(int64_t val, int64_t rem)
{
	int64_t ret = get<2>(gcd(rem, val));
	while (ret < 0) ret += rem;

	return ret;
}
フェルマーの小定理を使って累乗で求める

フェルマーの小定理より  p素数のとき、すべての  x (x < p) について  x^{p-1} (mod(p)) = 1 となるので  x*x^{p-2} (mod(p)) = 1 より x^{p-2} (mod(p)) が逆元になります。これは累乗を O(log(p)) で計算するアルゴリズムより高速に求まります。

コードは以下の通りです。

int64_t pow_rem(int64_t val, int64_t mul, int64_t rem)
{
	if (mul == 1) return val;

	int64_t ret = pow_rem(val, mul / 2, rem);
	ret *= ret;
	ret %= rem;

	if (mul & 1) {
		ret *= val;
		ret %= rem;
	}

	return ret;
}

int64_t pow_remdiv(int64_t val, int64_t rem)
{
	return pow_rem(val, rem - 2, rem);
}

逆元の計算手法の計算時間

ここで問題となるのは GCD と 累乗のどちらが早いかですが、ランダムな値に対して  10^9 + 7 の剰余上での逆元を求める計算を 1000万回繰り返すテストをしてみたところ、計算時間は以下のようになりました。(Visual Studio 2018 にて最適化オプション O4 付きでビルド)
GCD は値によって計算時間がかわるのですが累乗に比べて 1000 倍ぐらい遅く、剰余上での割り算の計算は累乗による計算を使うほうが圧倒的に高速のようです。

GCD 累乗
1500 msec - 2500 msec 2 msec


実験コードは以下になります。

#include <iostream>
#include <tuple>
#include <random>
#include <chrono>
#include <ctime>
#include <cstdint>

using namespace std;

tuple<int64_t, int64_t, int64_t> gcd(int64_t x, int64_t y)
{
	if (y == 0) return make_tuple(x, 1, 0);

	tuple<int64_t, int64_t, int64_t> ret = gcd(y, x % y);

	return make_tuple(get<0>(ret), get<2>(ret), get<1>(ret) - (x / y)*get<2>(ret));
}

int64_t gcd_remdiv(int64_t val, int64_t rem)
{
	int64_t ret = get<2>(gcd(rem, val));
	while (ret < 0) ret += rem;

	return ret;
}

int64_t pow_rem(int64_t val, int64_t mul, int64_t rem)
{
	if (mul == 1) return val;

	int64_t ret = pow_rem(val, mul / 2, rem);
	ret *= ret;
	ret %= rem;

	if (mul & 1) {
		ret *= val;
		ret %= rem;
	}

	return ret;
}

int64_t pow_remdiv(int64_t val, int64_t rem)
{
	return pow_rem(val, rem - 2, rem);
}

int main()
{
	int64_t rem = 1000000007;

	mt19937 mt;
	mt.seed(time(nullptr));

	int64_t ret = 1;

	for (int i = 0; i < 10000; i++) {
		uint32_t val = mt();
		if (rem <= val || val == 0) {
			i--;
			continue;
		}

		if (pow_remdiv(val, rem) * val % rem != 1) {
			cerr << "error" << endl;
		}
		if (gcd_remdiv(val, rem) * val % rem != 1) {
			cerr << "error" << endl;
		}

		auto start = chrono::high_resolution_clock::now();
		
		for (int j = 0; j < 10000000; j++) {
			ret += pow_remdiv(val, rem);
		}

		auto end = chrono::high_resolution_clock::now();
		auto dur_pow = end - start;

		start = chrono::high_resolution_clock::now();

		for (int j = 0; j < 10000000; j++) {
			ret += gcd_remdiv(val, rem);
		}

		end = chrono::high_resolution_clock::now();
		auto dur_gcd = end - start;

		auto pow_msec = chrono::duration_cast<chrono::milliseconds>(dur_pow).count();
		auto gcd_msec = chrono::duration_cast<chrono::milliseconds>(dur_gcd).count();

		cout << "pow msec : " << pow_msec << " gcd msec : " << gcd_msec << endl;
	}

	return ret;
}