题解 | #【模板】矩阵快速幂#

【模板】矩阵快速幂

https://www.nowcoder.com/practice/9aa4e720ca7c47f4a0abf1e13918ab0c

题目链接

【模板】矩阵快速幂

题目描述

给定一个 的整数方阵 以及一个非负整数 ,请计算矩阵 。当 时,约定 单位矩阵 。所有计算结果对 取模。

解题思路

本题要求计算一个矩阵的 次幂。如果 非常大,直接进行 次矩阵乘法()会超时。这是一个典型的可以使用快速幂 (Binary Exponentiation) 思想来解决的问题。

常规的快速幂用于计算一个数的幂(例如 ),其核心思想是将指数 进行二进制拆分,从而将时间复杂度从 优化到 。同样的思想也适用于矩阵乘法,因为矩阵乘法满足结合律(即 ),这是使用快速幂算法的前提。

算法步骤

  1. 定义矩阵乘法: 首先,我们需要一个函数来计算两个 矩阵的乘积。设 ,则矩阵 中的每个元素 由以下公式计算得出: 在计算过程中,每次乘法和加法的结果都需要对 取模,以防止溢出。特别需要注意,中间结果可能为负数,取模时需要确保结果落在 区间内。此操作的时间复杂度为

  2. 矩阵快速幂算法: 我们将整数快速幂的算法应用于矩阵:

    • 初始化一个结果矩阵 单位矩阵(主对角线为1,其余为0)。
    • 初始化一个基底矩阵 为输入的矩阵
    • 对指数 进行循环,直到 变为 0:
      • 如果 的当前二进制最低位为 1(即 ),则将 乘以
      • 自乘:
      • 右移一位(即 )。
    • 循环结束后, 矩阵即为最终结果

整个算法需要进行 次矩阵乘法,每次乘法复杂度为

代码

#include <iostream>
#include <vector>

using namespace std;

const int MOD = 1e9 + 7;

// 定义矩阵类型
using Matrix = vector<vector<long long>>;

// 矩阵乘法
Matrix multiply(const Matrix& a, const Matrix& b, int n) {
    Matrix c(n, vector<long long>(n, 0));
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) {
            for (int l = 0; l < n; ++l) {
                long long product = a[i][l] * b[l][j];
                c[i][j] = (c[i][j] + product % MOD + MOD) % MOD;
            }
        }
    }
    return c;
}

// 矩阵快速幂
Matrix matrix_pow(Matrix base, long long exp, int n) {
    Matrix res(n, vector<long long>(n, 0));
    // 初始化为单位矩阵
    for (int i = 0; i < n; ++i) {
        res[i][i] = 1;
    }

    while (exp > 0) {
        if (exp % 2 == 1) {
            res = multiply(res, base, n);
        }
        base = multiply(base, base, n);
        exp /= 2;
    }
    return res;
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n;
    long long k;
    cin >> n >> k;

    Matrix a(n, vector<long long>(n));
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) {
            cin >> a[i][j];
        }
    }

    Matrix result = matrix_pow(a, k, n);

    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) {
            cout << result[i][j] << (j == n - 1 ? "" : " ");
        }
        cout << "\n";
    }

    return 0;
}
import java.util.Scanner;

public class Main {
    static final int MOD = 1_000_000_007;

    // 矩阵乘法
    public static long[][] multiply(long[][] a, long[][] b, int n) {
        long[][] c = new long[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                for (int l = 0; l < n; l++) {
                    long product = a[i][l] * b[l][j];
                    c[i][j] = (c[i][j] + product % MOD + MOD) % MOD;
                }
            }
        }
        return c;
    }

    // 矩阵快速幂
    public static long[][] matrixPow(long[][] base, long exp, int n) {
        long[][] res = new long[n][n];
        // 初始化为单位矩阵
        for (int i = 0; i < n; i++) {
            res[i][i] = 1;
        }

        while (exp > 0) {
            if (exp % 2 == 1) {
                res = multiply(res, base, n);
            }
            base = multiply(base, base, n);
            exp /= 2;
        }
        return res;
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        long k = sc.nextLong();

        long[][] a = new long[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                a[i][j] = sc.nextLong();
            }
        }

        long[][] result = matrixPow(a, k, n);

        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                System.out.print(result[i][j] + (j == n - 1 ? "" : " "));
            }
            System.out.println();
        }
    }
}
import sys

MOD = 10**9 + 7

def multiply(a, b, n):
    c = [[0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            for l in range(n):
                c[i][j] = (c[i][j] + a[i][l] * b[l][j]) % MOD
    return c

def matrix_pow(base, exp, n):
    res = [[0] * n for _ in range(n)]
    # 初始化为单位矩阵
    for i in range(n):
        res[i][i] = 1

    while exp > 0:
        if exp % 2 == 1:
            res = multiply(res, base, n)
        base = multiply(base, base, n)
        exp //= 2
    return res

def main():
    try:
        input = sys.stdin.readline
        n, k = map(int, input().split())
        
        a = []
        for _ in range(n):
            a.append(list(map(int, input().split())))

        result = matrix_pow(a, k, n)

        for i in range(n):
            sys.stdout.write(" ".join(map(str, result[i])) + '\n')

    except (IOError, ValueError):
        return

main()

算法及复杂度

  • 算法:矩阵快速幂 (Matrix Exponentiation by Squaring)
  • 时间复杂度:。其中 是单次矩阵乘法的复杂度, 是快速幂算法所需的乘法次数。
  • 空间复杂度:,用于存储矩阵。
全部评论

相关推荐

03-13 00:04
已编辑
吉林大学 Java
约面的挺突然。。狠下心接了1.自我介绍2.讲讲JAVA的反射3.可以继续讲讲AOP,动态代理[&nbsp;因为讲反射不小心吟唱到了例如AOP的动态代理,但是这块记忆的非常不熟,结果磕磕绊绊&nbsp;]4.项目我看你写了AOP和注解,具体怎么实现滑动窗口限流的[&nbsp;梦到什么说什么,吟唱八股发散千万不要散到自己不熟悉的区域&nbsp;]5.也讲讲为什么另一个项目选择令牌桶,具体流程6.&nbsp;OK,讲讲&nbsp;Redis&nbsp;的数据类型?还有吗?就了解这五种嘛[&nbsp;把5个的基础类型从应用对比到历届底层全都吟唱了一遍。一句还有吗直接没力气了,简历就写了理解5种,别的我是真一点没看TT&nbsp;]7.讲讲Redission分布式锁实现8.这个指数退避怎么实现的9.在这里有考虑去保障幂等性嘛10.这里为什么使用指数退避呢?&nbsp;什么时候用均匀重传[已经晕过去了说不了解,刚说了后就意识到,估计应该说指数退避能缓解压力防止下游服务器雪崩之类的]11.ok,那讲讲JMM12.讲讲RocketMQ如何保证的不丢消息13.讲讲RocketMQ延迟消息原理14.讲讲项目Redis实现会话记忆这一块15.如果ai调用function&nbsp;calling出现幻觉,有考虑怎么解决吗?[&nbsp;不了解,面试官说什么接口幂等化,高危操作人工防护,没在听,感觉人已经飞升了TT&nbsp;]16.mcp了解嘛?和function&nbsp;calling有什么区别[&nbsp;依旧不了解,只能说了个前者规范架构抽象解耦,后者耦合高只能算个工具调用]17.AI生成代码的代码质量怎么保障,那平时如何review的呢18.算法。lc215&nbsp;&nbsp;数组中最大第k个元素19.打算考研还是本科就业20.反问1️⃣有哪里不足,有哪些需要提高的部分。[主要说知识广度不够,多刷算法,让我别太紧张]2️⃣部门业务会做什么人生第二次面试。感觉大厂面试官的气场压力很大应该凉了不过这次面试非常锻炼心态,多面试,多面试。
Luxlord:面经太硬核了
点赞 评论 收藏
分享
KKorz:是这样的,还会定期默写抽查
点赞 评论 收藏
分享
评论
点赞
收藏
分享

创作者周榜

更多
牛客网
牛客网在线编程
牛客网题解
牛客企业服务