首页 > 试题广场 >

牛牛的背包问题

[编程题]牛牛的背包问题
  • 热度指数:6171 时间限制:C/C++ 1秒,其他语言2秒 空间限制:C/C++ 32M,其他语言64M
  • 算法知识视频讲解
牛牛准备参加学校组织的春游, 出发前牛牛准备往背包里装入一些零食, 牛牛的背包容量为w。
牛牛家里一共有n袋零食, 第i袋零食体积为v[i]。
牛牛想知道在总体积不超过背包容量的情况下,他一共有多少种零食放法(总体积为0也算一种放法)。

输入描述:
输入包括两行
第一行为两个正整数n和w(1 <= n <= 30, 1 <= w <= 2 * 10^9),表示零食的数量和背包的容量。
第二行n个正整数v[i](0 <= v[i] <= 10^9),表示每袋零食的体积。


输出描述:
输出一个正整数, 表示牛牛一共有多少种零食放法。
示例1

输入

3 10
1 2 4

输出

8

说明

三种零食总体积小于10,于是每种零食有放入和不放入两种情况,一共有2*2*2 = 8种情况。
整体思路来自左程云老师的算法体系课。根据数据量猜解法。
自己使用了改写有序表简化了分治后的运算逻辑。但改写有序表的代码量也不少。-_-||

由于零食体积过大,使用动态规划肯定超时,但零食数量较小。分治后左右分别动态规划不会超时。

采用分治思想。
将arr分为左右两部分。
左右分别暴力枚举所有选择方法,
使用map收集所有方法的背包容量。(每种方法的背包容量都是不同的,超出w的就不要收集了)

定义ans收集答案。
遍历左侧结果集 
    当前背包容量为left 
    从右侧结果集种查出小于等于w-left的方法数funCount
    ans累加上funCount
返回ans
import java.util.Scanner;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int N = sc.nextInt();
        int bag = sc.nextInt();
        int[] arr = new int[N];
        for (int i = 0; i < arr.length; i++) {
            arr[i] = sc.nextInt();
        }
        long ways = ways(arr, bag);
        System.out.println(ways);
        sc.close();
    }

    /**
     * 分治统计有多少种零食的拿法。
     */
    private static long ways(int[] arr, int bag) {
        int N = arr.length;
        int mid = arr.length-1 >> 1;
        SBTSet<Long> leftSet = new SBTSet<>();
        process(arr, bag, 0, mid, 0L,leftSet);
        SBTSet<Long> rightSet = new SBTSet<>();
        process(arr, bag, mid+1, arr.length-1, 0L, rightSet);

        int ans = 0;
        for (int i = 0; i < leftSet.size(); i++) {
            Long leftWeight = leftSet.getIndexKey(i);
            ans += rightSet.lessEqualsCount(bag - leftWeight);
        }
        return ans;
    }

    /**
     * 当前来到i位置,背包内放入了preSum体积的零食,
     * 将i..end的所有选择后的体积放入set,但要求总体积不能超过bag
     * 如果当前来到end+1位置,set添加preSum后返回。
     * 尝试不拿当前位置零食,递归调用(i+1, preSum)
     * 如果arr[i]+preSum<=bag, 递归调用(i+1, preSum+arr[i])
     */
    private static void process(int[] arr, int bag, int i, int end, Long preSum, SBTSet<Long> set) {
        if(i == end+1){
            set.put(preSum);
            return;
        }
        process(arr, bag, i+1,end, preSum, set);
        if(arr[i]+preSum<=bag){
            process(arr, bag, i+1,end, preSum+arr[i], set);
        }
    }

    public static class SBTNode<K extends Comparable<K>>{
        public K key;
        public int all;
        public int size;
        public SBTNode l;
        public SBTNode r;
        public SBTNode(K k){
            key = k;
            size = 1;
            all = 1;
        }
    }

    public static class SBTSet<K extends Comparable<K>>{
        private SBTNode<K> root;

        public int size(){
            return getAll(root);
        }

        /**
         * 放入一个数据。
         */
        public void put(K key){
            root = add(root, key);
        }

        /**
         * 在cur树上新增节点,值为key。
         */
        private SBTNode<K> add(SBTNode<K> cur, K key) {
            if(cur==null) return new SBTNode<>(key);
            cur.all++;
            if(key.compareTo(cur.key)==0) return cur;
            if(key.compareTo(cur.key)<0){
                cur.l = add(cur.l, key);
            }else {
                cur.r = add(cur.r, key);
            }
            cur.size = getSize(cur.l)+getSize(cur.r)+1;
            return maintain(cur);
        }

        /**
         * 调整cur节点的平衡性并返回新节点。
         *
         */
        private SBTNode<K> maintain(SBTNode<K> cur) {
            int ls = getSize(cur.l);
            int lls = 0;
            int lrs = 0;
            if(cur.l != null){
                lls = getSize(cur.l.l);
                lrs = getSize(cur.l.r);
            }
            int rs = getSize(cur.r);
            int rrs = 0;
            int rls = 0;
            if(cur.r != null){
                rrs = getSize(cur.r.r);
                rls = getSize(cur.r.l);
            }

            if(lls>rs){
                cur = rightRotate(cur);
                cur.r = maintain(cur.r);
                cur = maintain(cur);
            }else if(lrs>rs){
                cur.l = leftRotate(cur.l);
                cur = rightRotate(cur);
                cur.l = maintain(cur.l);
                cur.r= maintain(cur.r);
                cur = maintain(cur);
            }else if(rrs>ls){
                cur = leftRotate(cur);
                cur.l = maintain(cur.l);
                cur = maintain(cur);
            }else if(rls>ls){
                cur.r = rightRotate(cur.r);
                cur = leftRotate(cur);
                cur.l = maintain(cur.l);
                cur.r = maintain(cur.r);
                cur = maintain(cur);
            }

            return cur;
        }

        /**
         * 左旋
         */
        private SBTNode<K> leftRotate(SBTNode<K> cur) {
            int same = cur.all - getAll(cur.l) - getAll(cur.r);
            SBTNode<K> right = cur.r;
            cur.r = right.l;
            right.l = cur;
            right.size = cur.size;
            cur.size = getSize(cur.l)+getSize(cur.r)+1;
            right.all = cur.all;
            cur.all = getAll(cur.l)+getAll(cur.r)+same;
            return right;
        }

        private int getAll(SBTNode<K> cur) {
            return cur==null?0:cur.all;
        }

        /**
         * 右旋
         */
        private SBTNode<K> rightRotate(SBTNode<K> cur) {
            int same = cur.all - getAll(cur.l) - getAll(cur.r);
            SBTNode left = cur.l;
            cur.l = left.r;
            left.r = cur;
            left.size = cur.size;
            cur.size = getSize(cur.l)+getSize(cur.r)+1;
            left.all = cur.all;
            cur.all = getAll(cur.l)+getAll(cur.r)+same;
            return left;
        }

        private int getSize(SBTNode cur) {
            return cur==null?0:cur.size;
        }

        /**
         * 返回小于等于指定值的个数。
         * 从root开始向下滑。
         * 等于返回答案。
         * 向右滑收集答案。
         * 向左滑不收集。
         */
        public int lessEqualsCount(K key){
            SBTNode<K> cur = root;
            int ans = 0;
            while (cur!=null){
                if(key.compareTo(cur.key)==0) return ans+cur.all-getAll(cur.r);
                if(key.compareTo(cur.key)<0){
                    cur = cur.l;
                }else {
                    ans += cur.all-getAll(cur.r);
                    cur = cur.r;
                }
            }
            return ans;
        }

        /**
         * 返回排序后位于指定位置的key
         * 检查key是否超出集合范围。
         * 从root向下滑。
         * 如果小于左树all范围,来到左树找。
         * 如果大于等于cur.all-右树.all,去右树找。
         * 否则,返回当前值。
         */
        public K getIndexKey(int index){
            if(index<0 || index>=getAll(root)) throw new RuntimeException("out of range");
            return getIndex(root, index).key;
        }

        /**
         * 返回cur树上的第k个节点。
         */
        private SBTNode<K> getIndex(SBTNode<K> cur, int index){
            if(index<getAll(cur.l)){
                return getIndex(cur.l, index);
            }else if(index >= cur.all-getAll(cur.r)){
                return getIndex(cur.r, index-(cur.all-getAll(cur.r)));
            }else {
                return cur;
            }
        }
    }
}


发表于 2022-03-09 16:51:30 回复(0)
import java.util.Map;
import java.util.Scanner;
import java.util.TreeMap;

public class Main {

        // 根据数据量的规模,选择用分治来处理
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int N = sc.nextInt();
        int bag = sc.nextInt();
        int[] arr = new int[N];
        for (int i = 0; i < arr.length; i++) {
            arr[i] = sc.nextInt();
        }
        long ways = ways(arr, bag);
        System.out.println(ways);
        sc.close();
    }

    public static long ways(int[] arr, int bag) {
        // process过程定义好,那么在当前函数中就应该这样调用:
        // arr  30
        // process(arr, 0, 14, 0, bag, map)
        // process(arr, 15, 29, 0, bag, map)
        if (arr == null || arr.length == 0) {
            return 0;
        }
        if (arr.length == 1) {
            return arr[0] <= bag ? 2 : 1;
        }
        // 分治算法
        int mid = (arr.length - 1) >> 1;
        TreeMap<Long, Long> lmap = new TreeMap<>();
        // 只从左侧选取得到的符合题意的方法数
        long ways = process(arr, 0, mid, 0, bag, lmap);
        TreeMap<Long, Long> rmap = new TreeMap<>();
        // 只从右侧选取得到的符合题意的方法数(和左侧的加起来)
        ways += process(arr, mid + 1, arr.length - 1, 0, bag, rmap);
        // 两张表lmap和rmap分别已经收集好左右两部分所有符合条件sum及其方法数
        // 此时把右侧表建立一个前缀和表出来
        TreeMap<Long, Long> rpre = new TreeMap<>();
        long pre = 0;
        for (Map.Entry<Long, Long> entry : rmap.entrySet()) {
            pre += entry.getValue();
            rpre.put(entry.getKey(), pre);
        }
        // 右侧的前缀和表已经建好
        // 接下来左侧严格按照表中每一个数据来和右侧符合条件的数据相结合
        for (Map.Entry<Long, Long> entry : lmap.entrySet()) {
            long lweight = entry.getKey();
            long lways = entry.getValue();
            // 左侧当前的零食大小是lweight,右侧在不超过bag的情况下最大的装载量是floor
            Long floor = rpre.floorKey(bag - lweight);// floorKey()方法意即找到小于等于该数且离他最近的那个数
            // 右侧只要零食大小不超过floor,都视为符合条件,找出他们的方法数
            if (floor != null) {
                long rways = rpre.get(floor);
                ways += lways * rways;
            }
        }
        return ways + 1;
    }

    // 从index出发,到end结束
    // 之前的选择,已经形成的累加和为sum
    // 零食[index...end]自由选择,出来的所有累加和不能超过bag,每一种累加和对应的方法数,填在map里
    // 最后不能什么货都没选(最后直接在总方法数上加1即可表示这种情况)
    // 举例:
    // [3,3,3,3]  bag=6
    //  0 1 2 3
    //  - - - -   0 -> (0 : 1)
    //  - - - $   3 -> (0 : 1) (3, 1)
    //  - - $ -   3 -> (0 : 1) (3, 2)
    public static long process(int[] arr, int index, int end, long sum, int bag, TreeMap<Long, Long> map) {
        if (sum > bag) {
            return 0;
        }
        // sum <= bag
        if (index > end) {
            if (sum != 0) {
                if (!map.containsKey(sum)) {
                    map.put(sum, 1L);
                } else {
                    map.put(sum, map.get(sum) + 1);
                }
                return 1;
            } else {
                // sum==0 说明什么都没选,这种情况不计数
                return 0;
            }
        }
        // sum < bag && index <= end(还有货)
        // 1) 不要当前index位置的货
        long ways = process(arr, index + 1, end, sum, bag, map);

        // 2) 要当前index位置的货
        ways += process(arr, index + 1, end, sum + arr[index], bag, map);
        return ways;
    }


}

发表于 2022-01-12 11:31:33 回复(0)
package test;

import java.util.Scanner;
import java.io.*;
public class test{
    static int n=0;
    static long[] v=new long[40];
    static long ans=0;
    static long w;
     
         public static void dfs( int t ,long sum) {
        
                if(sum>w)
                {
                    return ;
                }
                /*    else
                {
                    ans++;
                }*/
            /* if(sum<w&&sum+v[t]>w)
                {
                    ans++;
                    return ;
                }*/
                if(t==n)
                {
                    ans++;
                    return ;
                }
                
                dfs(t+1,(sum+v[t]));
                
                dfs(t+1,sum);
            
                
            }             
    
    public static void main (String[] args){
        Scanner scanner = new Scanner(System.in);
         n = scanner.nextInt();
         w = scanner.nextLong();
        long sum=0;
        
        for(int j=0;j<n;j++)
        {
            v[j]= scanner.nextLong();
            sum+=v[j];
        }
        if(sum<=w)
        {
            ans=1<<n;
        }
        else
        {
            dfs(0,0);
        }
        System.out.print(ans);
    }
    
}
发表于 2018-06-13 16:00:16 回复(0)