题解 | 最优划分

最优划分

https://www.nowcoder.com/practice/f4aae7b07cd3403996b920aa2de5268c

34/35 个用例通过。我觉得思路应该没有问题，剩下的一个用例可能是超时：对每个 i 都向左逐个元素扫描求 gcd，只有 gcd 变成 1 时才提前结束，最坏情况（例如所有元素都相同）下复杂度为 O(k·n²)。后面再回头看一下。

import math


class SparseTable:
    """Static range-maximum table: O(n log n) build, O(1) per query.

    ``st[i][j]`` holds ``max(data[i : i + 2**j])`` for every ``i`` with
    ``i + 2**j <= n``; ``log[x]`` holds ``floor(log2(x))`` for ``1 <= x <= n``.
    An empty ``data`` is accepted; only empty-range queries make sense on it.
    """

    def __init__(self, data):
        n = len(data)
        self.n = n
        # Level j covers windows of length 2**j; levels 0 .. k-1 all fit in n.
        self.k = n.bit_length()
        self.st = [[0] * self.k for _ in range(n)]
        # Zero fill already gives log[1] == 0; assigning log[1] explicitly
        # would raise IndexError when data is empty (list of length 1).
        self.log = [0] * (n + 1)

        for i in range(n):
            self.st[i][0] = data[i]

        for j in range(1, self.k):
            half = 1 << (j - 1)
            # Windows [i, i + 2**j) that lie completely inside the data.
            for i in range(n - (1 << j) + 1):
                self.st[i][j] = max(self.st[i][j - 1], self.st[i + half][j - 1])

        for i in range(2, n + 1):
            self.log[i] = self.log[i // 2] + 1

    def query(self, l, r):
        """Return ``max(data[l : r + 1])`` for ``0 <= l <= r < n``.

        An empty range (``l > r``) returns the sentinel ``-(10 ** 18)``.
        """
        if l > r:
            return -(10 ** 18)
        j = self.log[r - l + 1]
        # Two (possibly overlapping) windows of length 2**j cover [l, r].
        return max(self.st[l][j], self.st[r - (1 << j) + 1][j])


def _best_partition_sum(values, k):
    """Return the best total of segment gcds for splitting ``values`` into
    exactly ``k`` non-empty contiguous segments.

    Returns ``-(10 ** 18)`` when no such split exists (``k > len(values)``,
    or ``k == 0`` with non-empty ``values``); empty ``values`` with ``k == 0``
    scores 0. A single-element segment contributes the element itself.

    For a fixed right end ``i``, gcd(values[m..i]) takes only
    O(log max|value|) distinct values as ``m`` moves left, so candidate split
    points are grouped into that many blocks, each remembering the best
    previous-layer score. Total cost is O(k * n * log max|value|) time and
    O(n) extra memory, with no per-``i`` rescan of the prefix.
    """
    unreachable = -(10 ** 18)
    n = len(values)
    if k > n:
        # Every segment needs at least one element.
        return unreachable

    # prev[i]: best score for the first i elements using (j - 1) segments.
    prev = [unreachable] * (n + 1)
    prev[0] = 0
    for j in range(1, k + 1):
        cur = [unreachable] * (n + 1)
        # Each block is (g, best): a maximal run of 1-based start positions m
        # of the last segment values[m-1 : i] that all have gcd g, where best
        # is the largest prev[m - 1] over that run. Ordered from m == i outward.
        blocks = []
        for i in range(1, n + 1):
            x = values[i - 1]
            extended = [(x, prev[i - 1])]  # the segment consisting of x alone
            for g, best in blocks:
                g = math.gcd(g, x)
                last_g, last_best = extended[-1]
                # Each block's gcd divides the one before it, so equal gcds
                # are always adjacent and merging with the last one suffices.
                if g == last_g:
                    if best > last_best:
                        extended[-1] = (g, best)
                else:
                    extended.append((g, best))
            blocks = extended
            if i >= j:
                # Skip split points whose prefix cannot form j - 1 segments.
                cur[i] = max(
                    (best + g for g, best in blocks if best != unreachable),
                    default=unreachable,
                )
        prev = cur
    return prev[n]


def main():
    """Read ``n k`` and then ``n`` integers from stdin, and print the maximum
    sum of segment gcds over splits into exactly ``k`` contiguous segments.

    Prints ``-(10 ** 18)`` when no such split exists; prints nothing when
    stdin is empty.
    """
    import sys

    data = sys.stdin.read().split()
    if not data:
        return
    n = int(data[0])
    k = int(data[1])
    a = list(map(int, data[2 : 2 + n]))
    print(_best_partition_sum(a, k))


if __name__ == "__main__":
    # Script entry point: main() reads the input from stdin and prints the answer.
    main()

全部评论

相关推荐

评论
点赞
收藏
分享

创作者周榜

更多
牛客网
牛客网在线编程
牛客网题解
牛客企业服务