首页 > 试题广场 >

最优二叉树II

[编程题]最优二叉树II
  • 热度指数:7372 时间限制:C/C++ 1秒,其他语言2秒 空间限制:C/C++ 256M,其他语言512M
  • 算法知识视频讲解

小团有一个由N节点组成的二叉树,每个节点有一个权值。定义二叉树每条边的开销为其两端节点权值的乘积,二叉树的总开销即每条边的开销之和。小团按照二叉树的中序遍历依次记录下每个节点的权值,即他记录下了N个数,第i个数表示位于中序遍历第i个位置的节点的权值。之后由于某种原因,小团遗忘了二叉树的具体结构。在所有可能的二叉树中,总开销最小的二叉树被称为最优二叉树。现在,小团请小美求出最优二叉树的总开销。


输入描述:

第一行输入一个整数N(1<=N<=300),表示二叉树的节点数。

第二行输入N个由空格隔开的整数,表示按中序遍历记录下的各个节点的权值,所有权值均为不超过1000的正整数。



输出描述:

输出一个整数,表示最优二叉树的总开销。

示例1

输入

5
7 6 5 1 3

输出

45

说明

最优二叉树如图所示,总开销为7*1+6*5+5*1+1*3=45。


树形dp。
首先要明白中序遍历的特点:选取其中一个节点,其左边的节点都是其左子树上的节点,其右边的节点都是其右子树上的节点。
动态规划三步走:明确下标意义,寻找递推公式,dp数组初始化。
首先是dp数组的下标意义。
我用了两个二维数组,ldp[i][j]表示以以node[j+1]为根节点、node[i]到node[j]作为左子树节点的最优二叉树的权值;rdp[i][j]表示以以node[i-1]为根节点、node[i]到node[j]作为右子树节点的最优二叉树的权值。
其次是递推公式。
最后是初始化
其实也用不着初始化,在递推公式里就能完成。
代码如下:
#include <iostream>

using namespace std;

int ldp[302][302]{};
int rdp[302][302]{};
int node[302]{};

void f(int a, int b) {
	int x, y;
    ldp[a][b] = rdp[a][b] = 100000000;
	for (int i = a; i <= b; ++i) {
		x = ldp[a][i - 1] + node[i] * node[b + 1] + rdp[i + 1][b];
		y = ldp[a][i - 1] + node[i] * node[a - 1] + rdp[i + 1][b];
		if (x < ldp[a][b])ldp[a][b] = x;
		if (y < rdp[a][b])rdp[a][b] = y;
	}
}

int main() {
	int N;
	cin >> N;
	for (int i = 1; i <= N; ++i) {
		cin >> node[i];
	}
	for (int i = 0; i < N; ++i) {
		for (int j = 1; j <= N-i; ++j) {
			f(j, j + i);
		}
	}
	cout << ldp[1][N];
	return 0;
}


发表于 2021-08-11 20:26:38 回复(3)
树形dp

#include"bits/stdc++.h"

using namespace std;

const int N = 310;
int w[N];
int f[N][N][N];
int n;

int dp(int l, int r, int p) {
    if(l > r) return 0;
    if(f[l][r][p] != -1) return f[l][r][p];
    int ret = 2e9;
    for(int i = l; i<=r ; i++) {
        int left = dp(l,i-1,i);
        int right = dp(i+1,r,i);
        ret = min(ret,left + right + w[i]*w[p]);
    }
    f[l][r][p] = ret;
    return ret;
}

int main() {
    cin >> n;
    memset(f,-1,sizeof f);
    for(int i = 1; i<=n ; i++) cin >> w[i];
    cout << dp(1,n,0);
    return 0;
}


发表于 2021-03-02 10:46:34 回复(7)
分治重构二叉树+记忆化搜索(不保存计算过的最小开销会超时),n个节点编号为1~n,mem[i][j][k]表示以k为根节点,i~j为子树节点的树的最小开销。从节点1到n轮流当根节点,通过递归分治构建左右子树计算整体的开销。
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.io.IOException;

public class Main {
    static int[][][] mem;
    static int[] weight;
    static int n;
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        n = Integer.parseInt(br.readLine().trim());
        String[] strW = br.readLine().trim().split(" ");
        weight = new int[n];
        for(int i = 0; i < n; i++) weight[i] = Integer.parseInt(strW[i]);
        // mem[l][r][k]表示以weight[l:r]为子节点,以weight[k]为根节点的树开销
        mem = new int[n][n][n];
        for(int i = 0; i < n; i++){
            for(int j = 0; j < n; j++)
                for(int k = 0; k < n; k++) mem[i][j][k] = -1;
        }
        System.out.println(recur(0, n - 1, -1));
    }
    
    private static int recur(int left, int right, int root) {
        if(left > right) return 0;
        if(root >= 0 && mem[left][right][root] != -1) return mem[left][right][root];
        int cost = Integer.MAX_VALUE;
        // [left,right]中的元素轮流做根节点构建二叉树
        int leftCost = 0, rightCost = 0;
        for(int i = left; i <= right; i++){
            leftCost = recur(left, i - 1, i);      // 左子树开销
            rightCost = recur(i + 1, right, i);    // 右子树开销
            // root=-1时表示初始根节点还没有确定,不会有根节点连接左右子树的边
            cost = Math.min(cost, leftCost + rightCost + weight[i]*(root != -1? weight[root]: 0));
        }
        if(root >= 0) mem[left][right][root] = cost;
        return cost;
    }
}


编辑于 2021-03-08 12:13:34 回复(2)
import java.util.Scanner;
import java.util.Arrays;
public class Main {
    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        int n = scanner.nextInt();
        int[] arr = new int[n];
        for (int i = 0; i < n; i++) {
            arr[i] = scanner.nextInt();
        }
        int[][][] dp = new int[n][n][n];
        for (int i = 0; i < n - 1; i++) {
            dp[i][i + 1][i] = arr[i] * arr[i + 1];
            dp[i][i + 1][i + 1] = dp[i][i + 1][i];
        }
        for (int k = 2; k < n; k++) {
            for (int i = 0; i < n - k; i++) {
                for (int m = i; m <= i + k; m++) {
                    int left = i == m ? 0 : dp[i][m - 1][i] + arr[i] * arr[m];
                    for (int l = i + 1; l < m; l++) {
                        left = Math.min(left, dp[i][m - 1][l] + arr[l] * arr[m]);
                    }
                    int right = m == i + k ? 0 : dp[m + 1][i + k][i + k] + arr[i + k] * arr[m];
                    for (int r = m + 1; r < i + k; r++) {
                        right = Math.min(right, dp[m + 1][i + k][r] + arr[r] * arr[m]);
                    }
                    dp[i][i + k][m] = left + right;
                }
            }
        }
        int res = dp[0][n - 1][0];
        for (int i = 1; i < n; i++) {
            res = Math.min(res, dp[0][n - 1][i]);
        }
        System.out.println(res);
    }
}

dp[i][j][k]代表,arr[i-j]中,以k为根的最优子结构
发表于 2021-03-05 16:38:28 回复(3)
美团不当人系列...妥妥的hard  二叉树+dp, 还是三维dp, 树还要考虑虚拟节点, 而且树有三层结构parent==>root==>left+right


const readline=require('readline');
const rl=readline.createInterface({
    input:process.stdin,
    output:process.stdout
})
const arr=[];
rl.on('line',function(line){
    arr.push(line);
    if(arr.length===2){
        let len=Number(arr[0]);
        let w=arr[1].split(' ').map(Number);
        w.unshift(0);

        //1.dp[low][high][root]表示以root为根节点,
        //其左/右子树节点范围在data[low,high]内的最小开销(左或者右,只有一边)
        //2.len+1是因为需要有一个虚拟的根节点
        let dp=Array.from({length:len+1},
            ()=>Array.from({length:len+1},()=>Array(len+1).fill(-1)));
        const dfs=(low,high,root)=>{
            if(low>high) return 0;
            //如果访问过则直接返回
            if(dp[low][high][root]!=-1) return dp[low][high][root];
            //在low~high每个位置都有可能作为根节点
            let cost=Infinity;
            for(let i=low;i<=high;i++){
                let left=dfs(low,i-1,i);
                let right=dfs(i+1,high,i);
                cost=Math.min(cost,left+right+w[i]*w[root]);
            }
            return dp[low][high][root]=cost;
        }
        console.log(dfs(1,len,0));

        //next
        arr.length=0;
    }
})



发表于 2021-04-06 15:44:16 回复(0)
import java.util.*;
public class Main {
    static int n;
    static int[] array;
    static int[][][] memo;
    public static void main(String[] args){
        Scanner sc=new Scanner(System.in);
        n =sc.nextInt();
        memo=new int[n+1][n+1][n+1];
                //数组用来储存树
        array =new int[n+1];
        for(int i=1;i<=n;i++){
            array[i]=sc.nextInt();
        }
                //初始化记忆集,方便后面查看从i到j,以k为节点的树是否被遍历过
        for (int i = 0; i < n + 1; i++) {
            for (int j = 0; j < n + 1; j++) {
                for (int k = 0; k < n + 1; k++) {
                    memo[i][j][k]=-1;
                }
            }
        }
               //定义函数,返回值为是以root为根节点,从数组第start到end个数作为其中一颗子树的最小开销,
               //这个定义很重要,方便后续遍历,我们要求的就是dfs(1,array.length-1,0),为什么是0,因为可以假设
               //有个值为0的root,并且0×任何数都得零,不影响最终结果,即使你子树中的根是什么值都和我无关         
                int res=dfs(1,array.length-1,0);
        System.out.println(res);
    }
    public static int dfs(int start,int end,int root){
        if(start>end){
            return 0;
        }
                 //首先查看记忆集里有没有之前遍历过的最优子树,有就返回
        if (memo[start][end][root]!=-1){
            return memo[start][end][root];
        }
                //没有就开始深度遍历,从start到end一个个作为根算出哪个作为根结果是最小的
        int min =Integer.MAX_VALUE;
        for(int i=start;i<=end;i++){
                //根据此函数定义可以得出左子树的最优二叉树总开销
            int left=dfs(start,i-1,i);
                //同理可得右子树的总开销
            int right=dfs(i+1,end,i);
                //最后就是左子树+右子树+当前根*父根,这也是为什么一开始父根设为0的原因。我们要求最小值,就要遍历所有当前根
            min=Math.min(min,left+right+array[root]*array[i]);

        }
               //最后把结果存到记忆集中,避免重复遍历
        memo[start][end][root]=min;
        return min;
    }
}


通过全部用例
运行时间476ms
占用内存121132KB
想了好久,一开始想复杂了,dfs(start,end,root)返回的是以root为根节点,从数组第start到end个数作为其中一颗子树的最小开销
编辑于 2022-08-31 20:51:36 回复(2)
mt不当人系列
区间DP 比较好想,但是怎么转移会难想一点
如果是普通的(左区间,右区间,区间断点) 的枚举方式,会发现需要再加一层循环枚举哪个子树是更优的。
为了减少这个循环,定义两个 dp数组:dp_l[i][j], dp_r[i][j] 分别代表 区间 [i,j] 以 a[i] 为根的最优解 和 以 a[j] 为根的最优解,这样就简单了不少。
转移方程(dp_l):min(dp_l[i][j], dp_r[i + 1][k] + dp_l[k][j] + a[i] * a[k])
转移方程(dp_r):min(dp_r[i][j], dp_r[i][k] + dp_l[k][j - 1] + a[j] * a[k])

枚举区间断点的时候,直接找 dp_r[i][k] 和 dp_l[k][j] 就好了,以 a[k] 为根 在区间 [i, j] 中组成最优的树的值就是 dp_r[i][k] + dp_l[k][j]
所以最后在区间 [1,n] 中,找到最小的 dp_r[1][k] + dp_l[k][n] 就行了

发表于 2021-08-16 16:23:52 回复(1)
from functools import lru_cache


@lru_cache(1000*1000)
def dfs(start, end, father):
    ret = 2**31
    if end - start <= 0: return 0
    for i in range(start, end):
        ret1 = dfs(start, i, nums[i])
        ret2 = dfs(i + 1, end, nums[i])
        ret = min(ret, ret1 + ret2 + nums[i] * father)
    return ret

N = int(input())
nums = [int(i) for i in input().split()]
print(dfs(0, N, 0))

发表于 2021-04-03 16:36:07 回复(0)
import java.util.*;
public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        while (sc.hasNext()) {
            int n = sc.nextInt();
            int[] nums = new int[n + 1];
            for (int i = 0; i < n; i++) {
                nums[i] = sc.nextInt();
            }
            Main main = new Main();
            int ans = Integer.MAX_VALUE;
            int[][][] memo = new int[n + 1][n + 1][n + 1];
            for (int i = 0; i < n; i++) {
                ans = Math.min(ans,main.dfs(nums,0,i - 1,i,memo) + main.dfs(nums,i + 1,n - 1,i,memo));
            }
            System.out.println(ans);
        }
    }
    int dfs(int[] nums, int l, int r, int father, int[][][] memo) {
        if (l > r) return 0;
        if (l == r) return nums[l] * nums[father];
        if (memo[l][r][father] != 0) return memo[l][r][father];
        int res = Integer.MAX_VALUE;
        for (int i = l; i <= r; i++) {
            int left = dfs(nums, l, i - 1, i, memo);
            int right = dfs(nums, i + 1, r, i, memo);
            res = Math.min(res, left + right + nums[i] * nums[father]);
        }
        memo[l][r][father] = res;
        return res;
    }
}

编辑于 2021-03-23 14:40:52 回复(1)

我完全翻译的楼上的代码, 但是过不了!! 超时了, 是python的问题吗

N = 6
w = [0]*N
f = [[[-1 for _ in range(N)] for _ in range(N)] for _ in range(N)]
def dp(l, r, p):
    if l>r:
        return 0
    if f[l][r][p] != -1:
        return f[l][r][p]
    ret = float('inf')
    for i in range(l, r+1):
        left = dp(l, i-1, i)
        right = dp(i+1, r, i)
        ret = min(ret, left+right+w[i]*w[p])
    f[l][r][p] = ret
    return ret
if __name__ == '__main__':
    n = int(input())
    tmp = list(map(int, input().split(" ")))
    for i in range(1, n+1):
        w[i] = tmp[i-1]
    tmp = dp(1, n, 0)
    print(tmp)
发表于 2021-03-20 23:58:55 回复(2)
dp,超时。。。
dp[i][j][k]表示节点i到节点j范围内的节点并以k为root的子树的最小开销。
递归关系:
dp[i][j][k] = min([dp[i][k-1][p]+l[p]*l[k] for p in range(i, k)]) + min([dp[k+1][j][q]+l[q]*l[k] for q in range(k+1, j+1)])

n = int(input())
l = list(map(int, input().split(' ')))

dp = [[[0]*n for _ in range(n)] for _ in range(n)]
for jiange in range(1, n):
    for i in range(0, n-jiange):
        j = i + jiange
        for k in range(i, i+jiange+1):
            if jiange == 1:
                dp[i][j][k] = l[i]*l[j]
            else:
                if k == i:
                    dp[i][j][k] = min([dp[k+1][j][q]+l[q]*l[k] for q in range(k+1, j+1)])
                elif k == j:
                    dp[i][j][k] = min([dp[i][k-1][p]+l[p]*l[k] for p in range(i, k)])
                else:
                    aa = min([dp[i][k-1][p]+l[p]*l[k] for p in range(i, k)])
                    bb = min([dp[k+1][j][q]+l[q]*l[k] for q in range(k+1, j+1)])
                    dp[i][j][k] = aa + bb
                
print(min(dp[0][n-1]))

发表于 2021-03-13 09:42:08 回复(2)
仿照第一个评论的java版本代码 ,用于学习参考
import java.util.Scanner;
// 注意类名必须为 Main, 不要有任何 package xxx 信息
public class Main {
    static int[][] leftdp = new int[302][302];
    static int[][] rightdp = new int[302][302];
    static  int[] node = new int[302];
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        //System.out.println("输入一个整数表示节点个数:");
        int n = sc.nextInt();
        for (int i = 1; i <= n; i++) {
            //System.out.println("输入整数表示第" + i + "个节点的权重:");
            node[i] = sc.nextInt();
        }
        for (int i = 0; i < n; i++) {
            for (int j = 1; j <= n - i; j++) {
                f(j, j + i);
            }
        }
        System.out.print(leftdp[1][n]);
    }
    public static void f(int a, int b) {
        // System.out.println(leftdp[a][b]);
        leftdp[a][b] = rightdp[a][b]  = 1000000000;
        int x, y;
        for (int i = a; i <= b; i++) {
            x = leftdp[a][i - 1] + node[i] * node[b + 1] + rightdp[i + 1][b];
            y = leftdp[a][i - 1] + node[i] * node[a - 1] + rightdp[i + 1][b];
            if (x < leftdp[a][b]) leftdp[a][b] = x;
            if (y < rightdp[a][b]) rightdp[a][b] = y;
        }
    }
}

发表于 2024-03-17 12:40:35 回复(0)
树形 dp,第三个维度大小为 2 即可:
0 表示当前 i-j 是上一层的左子树,根节点为 j + 1;
1 表示当前 i-j 是上一层的右子树,根节点为 i - 1
import java.util.Scanner;

// 注意类名必须为 Main, 不要有任何 package xxx 信息
public class Main {
    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        int n = in.nextInt();
        int[] a = new int[n + 2];
        for (int i = 1; i <= n; i++) {
            a[i] = in.nextInt();
        }
        int[][][] f = new int[n + 2][n + 2][2];
        for (int i = n; i >= 1; i--) {
            f[i][i][0] = a[i] * a[i + 1];
            f[i][i][1] = a[i] * a[i - 1];
            for (int j = i + 1; j <= n; j++) {
                f[i][j][0] = Integer.MAX_VALUE;
                f[i][j][1] = Integer.MAX_VALUE; 
                for (int k = i; k <= j; k++) {
                    f[i][j][0] = Math.min(f[i][j][0], a[k] * a[j + 1] + f[i][k - 1][0] + f[k + 1][j][1]);
                    f[i][j][1] = Math.min(f[i][j][1], a[k] * a[i - 1] + f[i][k - 1][0] + f[k + 1][j][1]);
                }
            }
        }
        System.out.println(f[1][n][0]);
    }
}


编辑于 2024-03-01 13:23:08 回复(0)
先写递归,超时了再加dp,dp[i][j][k],我的k其实只能取i-1或j,表示[i,j]区间内以k=j为根节点最优解,或者[i-1,j)以k=i-1为根节点的最优解。之所以k只能取这两个值是因为我先写的递归,递归的时候遍历根节点然后分割为左子树和右子树,所以根节点要么是左子树的下一个点要么是右子树上一个点。
当然,可以优化为dp[i][j][2]
import java.util.*;

// 注意类名必须为 Main, 不要有任何 package xxx 信息
public class Main {
    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        int n = in.nextInt();
        int[] nums = new int[n];
        for (int i = 0; i < n; i++) {
            nums[i] = in.nextInt();
        }
        int[][][] dp = new int[n + 1][n + 1][n + 1];
        System.out.println(cost(nums, dp));
    }
    public static int cost(int[] nums, int[][][] dp) {
        int n = nums.length;
        if (n == 1) return 0;
        if (n == 2) return nums[0] * nums[1];
        int ans = Integer.MAX_VALUE;
        for (int i = 0; i < n; i++) {
//            System.out.println(nums[i]);
            ans = Math.min(ans, bt(nums, 0, i, i, dp) + bt(nums, i + 1, n, i, dp));
        }
        return ans;
    }
    public static int bt(int[] nums, int l, int r, int root, int[][][] dp) {
        if (l == r) return 0;
        if (r == l + 1) {
            if (dp[l][r][root] == 0) dp[l][r][root] = nums[root] * nums[l];
            return dp[l][r][root];
        }
        if (r == l + 2) {
            if (dp[l][r][root] == 0) dp[l][r][root] = nums[l] * nums[l + 1] + Math.min(
                            nums[l], nums[l + 1]) * nums[root];
            return dp[l][r][root];
        }
        if (dp[l][r][root] != 0) return dp[l][r][root];
        int ans = Integer.MAX_VALUE;
        for (int i = l; i < r; i++) {
            if (dp[l][i][i] == 0) dp[l][i][i] = bt(nums, l, i, i, dp);
            if (dp[i + 1][r][i] == 0) dp[i + 1][r][i] = bt(nums, i + 1, r, i, dp);
            ans = Math.min(ans, dp[l][i][i] + dp[i + 1][r][i] + nums[i] * nums[root]);
        }
        dp[l][r][root] = ans;
        return ans;
    }
}


发表于 2023-03-31 15:07:13 回复(0)
整了半天就通过了一个用例
发表于 2023-03-04 16:13:52 回复(2)
#include <bits/stdc++.h>
using namespace std;
const int N = 3e2+5;
int n,m,x,y;
int a[N], dp[N][N][N];
int dfs(int l,int r, int fa){
    if(l > r)return 0;
    if(dp[l][r][fa] != -1)return dp[l][r][fa];
    int ans = 1e9;
    for(int i = l; i <= r; i++){
        int left = dfs(l, i-1, i);
        int right = dfs(i+1, r, i);
        ans = min(ans, left + right + a[i] * a[fa]);
    }
    dp[l][r][fa] = ans;
    return ans;
}
int main(){
    cin>>n;
    for(int i = 1; i <= n; i++){
        cin>>a[i];
    }
    memset(dp, -1, sizeof(dp));
    dfs(1,n,0);
    cout<<dp[1][n][0]<<'\n';
    return 0;
} 

发表于 2022-08-12 21:07:39 回复(0)
递归+备忘录
import java.io.*;
import java.util.*;
public class Main {
static String[] tree;
    static int[][][] dp;
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
        int size = Integer.parseInt(br.readLine().trim());
        tree = br.readLine().trim().split(" ");
        int min = Integer.MAX_VALUE;
        dp = new int[size][size][size];
        for(int i = 0; i < size; i++){
            for(int j = 0; j < size; j++)
                for(int k = 0; k < size; k++) dp[i][j][k] = Integer.MAX_VALUE;
        }
        for (int i = 0; i < size; i++) {
            int leftTree = build(0,i-1,i);
            int rightTree = build(i+1, size-1,i);
            min = Math.min(min, leftTree + rightTree);
        }
        bw.write(min+"\n");
        bw.flush();
    }

    public static int build(int left,int right,int parent){
        if (left > right){
            return 0;
        }
        if (dp[left][right][parent] != Integer.MAX_VALUE){
            return dp[left][right][parent];
        }
        for (int i = left; i <= right; i++) {
            int leftTree = build(left,i-1,i);
            int rightTree = build(i+1, right,i);
            dp[left][right][parent] = Math.min(dp[left][right][parent], Integer.parseInt(tree[parent]) * Integer.parseInt(tree[i]) + leftTree + rightTree);
        }
        return dp[left][right][parent];
    }     
}
发表于 2022-03-21 19:38:18 回复(0)
区间dp
import java.util.*;
  public class Main{
      public static void main(String args[]){
             Scanner sc=new Scanner(System.in);
            int n=sc.nextInt();
            int arr[]=new int[n];
            for(int i=0;i<n;i++){
                   arr[i]=sc.nextInt();
            }
           int dp[][][]=new int[n][n][n];
      
          for(int i=1;i<=n;i++){
               //区间长度
              for(int j=0;j<=n-i;j++){
                      //区间起始点
                    for(int k=j;k<j+i;k++){
                            //分割点
                        int lm=Integer.MAX_VALUE,rm=Integer.MAX_VALUE,max=Integer.MAX_VALUE;
                        for(int m=j;m<k;m++){
                              lm=Math.min(lm,dp[j][k-1][m]+arr[k]*arr[m]);
                        }
                        for(int f=k+1;f<j+i;f++){
                            rm=Math.min(rm,dp[k+1][j+i-1][f]+arr[k]*arr[f]);
                        }
                        if(lm==max&&rm==max){
                               dp[j][j+i-1][k]=0;
                        }
                        else if(lm==max){
                            dp[j][j+i-1][k]=rm;
                        }
                        else if(rm==max){
                              dp[j][j+i-1][k]=lm;
                        }
                        else{
                        dp[j][j+i-1][k]=lm+rm;
                        }
                    }
              }
          }
           int ans=Integer.MAX_VALUE;
          for(int i=0;i<n;i++){
              ans=Math.min(ans,dp[0][n-1][i]);
          }
          System.out.println(ans);
      }
  }


发表于 2022-03-05 22:24:20 回复(0)
记忆化递归
#include<iostream>
#include<fstream>
#include<vector>
#include<unordered_map>
#include<unordered_set>
#include<set>
#include<map>
#include<set>
#include<queue>
#include<stack>
#include<deque>
#include<cmath>
#include<algorithm>
#include<string>
#include<cstring>
using namespace std;
int dp[310][310][310];
int dfs(vector<int>&res,int pos,int l,int r)
{
    if(l>r) return 0;
    if(dp[pos][l][r]!=-1) return dp[pos][l][r];
    int ans=pow(10,9);
    for(int i=l;i<=r;i++)
    {
        ans=min(ans,res[pos]*res[i]+dfs(res,i,l,i-1)+dfs(res,i,i+1,r)); 
    }
    dp[pos][l][r]=ans;
    return ans;
}
int main()
{
    memset(dp,-1,sizeof dp);
    int n;cin>>n;
    vector<int>res(n);
    for(int i=0;i<n;i++) cin>>res[i];
    int ans=pow(10,9);
    for(int i=0;i<n;i++)
    {
        ans=min(ans,dfs(res,i,0,i-1)+dfs(res,i,i+1,n-1));
    }
    cout<<ans;
    return 0;
     
}


编辑于 2021-07-01 21:45:14 回复(0)
树形dp加记忆化搜索,超时
import java.util.*;
import java.io.*;

public class Main{
    public static Info[][] dp;
    public static void main(String[] args) throws IOException{
        BufferedReader br=new BufferedReader(new InputStreamReader(System.in));
        int cnt=Integer.parseInt(br.readLine());
        int[] nums=new int[cnt];
        String[] temp=br.readLine().split(" ");
        for(int i=0;i<cnt;i++) nums[i]=Integer.parseInt(temp[i]);
        dp=new Info[cnt][cnt];
        List<int[]> l=process(nums,0,cnt-1).res;
        int res=Integer.MAX_VALUE;
        for(int[] cur:l){
            res=Math.min(res,cur[0]);
        }
        System.out.print(res);
    }
    public static Info process(int[] nums,int left,int right){
        if(left>right) return null;
        List<int[]> l=new ArrayList<>();
        for(int i=left;i<=right;i++){
            int mul1=0;
            Info leftInfo;
            Info rightInfo;
            if(i-1<0||dp[left][i-1]==null) leftInfo=process(nums,left,i-1);
            else leftInfo=dp[left][i-1];
            if(i+1>=nums.length||dp[i+1][right]==null) rightInfo=process(nums,i+1,right);
            else rightInfo=dp[i+1][right];
            if(leftInfo!=null){
                int mul2=Integer.MAX_VALUE;
                for(int[] cur:leftInfo.res){
                    mul2=Math.min(mul2,cur[0]+nums[cur[1]]*nums[i]);
                }
                mul1+=mul2;
            }
            if(rightInfo!=null){
                int mul2=Integer.MAX_VALUE;
                for(int[] cur:rightInfo.res){
                    mul2=Math.min(mul2,cur[0]+nums[cur[1]]*nums[i]);
                }
                mul1+=mul2;
            }
            l.add(new int[]{mul1,i});
        }
        Info result=new Info(l);
        dp[left][right]=result;
        return result;
    }
}
class Info{
    List<int[]> res;
    public Info(List<int[]> res){
        this.res=res;
    }
}


发表于 2021-06-19 00:54:38 回复(0)