红魔咖啡馆

头发越掉越多,头发越掉越少

0%

【算法】矩阵快速幂

矩阵快速幂

解决的问题

  • 固定关系的一维k阶递推表达式

  • 固定关系的k维一阶递推表达式

时间复杂度\(O(\log n\times k^3)\)

矩阵乘法

前提:对于二维矩阵,两个行列式相乘需要满足一个行列式的列数=另一个行列式的行数

算法: \[ \begin{vmatrix} a&b \\ c&d\end{vmatrix}\times \begin{vmatrix} e&f \\ g&h\end{vmatrix}=\begin{vmatrix} ae+bg&af+bh \\ ce+dg&cg+dh\end{vmatrix} \] Code:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
vector<vector<int>> martix_mul(vector<vector<int>> a,vector<vector<int>> b ){
  int n = a.size();
  int m = b[0].size();
  int k = a[0].size();
  vector<vector<int>> ans(n, vector<int>(m));
  for(int i = 0;i<n;i++){
    for (int j = 0;j<m;j++){
      for (int c = 0;c<k;c++){
        ans[i][j] +=a[i][c]*b[c][j];
      }
    }
  }
  return ans;
}

矩阵快速幂

前提:需要是正方形矩阵

单位矩阵:类似乘法中的1,即主对角线上为1,其余位置为0的矩阵,与矩阵相乘不会改变原矩阵

Code:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
vector<vector<int>> mqpow(vector<vector<int>> m, int p){
  int n = m.size();
  vector<vector<int>> ans(n, vector<int>(n));
  for (int i = 0; i<n;i++){
    ans[i][i]=1;
  }
  while(p){
    if(p&1)ans = martix_mul(ans, m);
    m = martix_mul(m,m);
    p>>=1;
  }
  return ans;
  
}

例:求斐波那契数列

使用矩阵可以加速斐波那契数列的递推,我们设向量\(v_1\)

初始时令\(v_1=\begin{pmatrix} 0 \\ 1 \end{pmatrix}\),即数列前两项,斐波那契的变化可以看作这个向量一直向右移动,更新为新的两项

这个转移可以表示为\(\begin{pmatrix} a&b\\c&d\end{pmatrix}\begin{pmatrix} 0\\1\end{pmatrix}=\begin{pmatrix} 1\\1\end{pmatrix}\)

为了求解左乘的转置矩阵,需要再写一项\(\begin{pmatrix} a&b\\c&d\end{pmatrix}\begin{pmatrix} 1\\1\end{pmatrix}=\begin{pmatrix} 1\\2\end{pmatrix}\)

这两个矩阵转化为方程组联立可以解得矩阵为\(\begin{pmatrix}0&1\\1&1 \end{pmatrix}\),设为A,这个A是不变的

即每次递推只需要将当前状态向量相乘该转移矩阵即可实现

因此如果要求解转移n次后的结果,就可以表示为\(A^{n-1}v_1\)

这样我们可以将线性递推转化为求解矩阵快速幂,时间复杂度降为\(O(\log n)\)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40

vector<vector<ll>> mul(vector<vector<ll>> a, vector<vector<ll>> b, ll mod){
  int n = a.size();
  int m = b[0].size();
  int k = a[0].size();
  vector<vector<ll>> ans(n,vector<ll>(m));
  for (int i = 0; i<n;i++){
    for(int j = 0; j<m;j++){
      for (int c = 0;c<k;c++){
        ans[i][j]= (ans[i][j]+a[i][c]*b[c][j])%mod;
      }
    }
  }
  return ans;
}
vector<vector<ll>> mqpow(vector<vector<ll>> m, ll p, ll mod){
  ll n = m.size();
  vector<vector<ll>> ans(n,vector<ll>(n));
  for (int i = 0;i<n;i++){
    ans[i][i]=1;
  }
  while(p){
    if(p&1) ans = mul(ans, m, mod);
    m = mul(m,m,mod);
    p>>=1;
  }
  return ans;
}
void solve() {
  ll n;
  cin >> n;
  if(n==1){
    cout << 1 ;
    return ;
  }
  vector<vector<ll>> unit = {{1,1},{1,0}};
  vector<vector<ll>> start = {{1,0}};
  vector<vector<ll>> ans = mul(start, mqpow(unit,n-1, mod), mod);
  cout << ans[0][0];
}

使用矩阵乘法来加速递推

动态规划问题中,若递推的数据规模过大,显然线性递推会超时,此时我们需要考虑使用矩阵乘法加速递推

原理和上面的斐波那契数列推导类似

使用矩阵乘法时,重点是推导转移的形式

即推导出\(f(x),f(x-1),f(x-2)..(如果必要)\)的递推式,他们的系数即为对应的转移矩阵

例1: 矩阵加速

https://www.luogu.com.cn/problem/P1939

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76

/**
 * 矩阵优化递推
 * 注意到给的递推式,可以写成如下:
 * f(x) = 1*f(x-1)+0*f(x-2)+1*f(x-3)
 * f(x-1) = 1*f(x-1)+0*f(x-2)+0*f(x-3)
 * f(x-2) = 0*f(x-1)+1*f(x-2)+0*f(x-3)
 * 这里需要用f(x-2)来辅助从x转移到x-3
 * 因此递推矩阵可以写作
 * 1 0 1
 * 1 0 0
 * 0 1 0
 * 
 * n<=3都是初始条件的1,从n=4开始转移,因此要转移n-3次,快速幂实现
 */
matrix mul(matrix &a, matrix &b, ll mod)
{
  ll n = a.size();
  ll m = b[0].size();
  ll k = a[0].size();
  matrix ans(n, vector<ll>(m));
  for (int i = 0; i<n; i++)
  {
    for (int j = 0; j<m; j++)
    {
      for (int c = 0; c<k; c++ )
      {
        ans[i][j] = (ans[i][j]+a[i][c]%mod*b[c][j]%mod)%mod;
      }
    }
  }
  return ans;

}
matrix mqpow(matrix a, ll p, ll m)
{
  ll n = a.size();
  matrix ans(n, vector<ll>(n));
  for (int i = 0; i<n; i++)
  {
    ans[i][i] = 1;
  }
  while(p)
  {
    if (p&1) ans = mul(ans, a, m);
    p>>=1;
    a = mul(a,a,m);

  }
  return ans;
}
void solve() {
  ll n;
  cin >> n;
  if (n<=3)
  {
    cout << 1 <<endl;
    return ;
  }
  matrix init = {
    {1,0,1},
    {1,0,0},
    {0,1,0}
  };

  matrix ans = mqpow(init, n-3, mod);
  ll cnt = 0;
  for (int i = 0; i<3; i++)
  {
    cnt = (cnt+ans[0][i])%mod;
  }

  cout << cnt%mod <<endl;


}