首页 > 试题广场 >

小A的线段(easy version)

[编程题]小A的线段(easy version)
  • 热度指数:700 时间限制:C/C++ 1秒,其他语言2秒 空间限制:C/C++ 256M,其他语言512M
  • 算法知识视频讲解
\hspace{15pt}在坐标轴的整数点 1\sim n 上给出 m 条闭区间线段,第 i 条线段用其端点 [\,st_i,\,ed_i\,] 描述。

\hspace{15pt}现在要从这 m 条线段中选择若干条,使得每个整数点至少两条所选线段覆盖。求满足条件的选择方案数量;两种方案视为不同,当且仅当存在某条线段在两方案中的"选/不选"状态不同。

\hspace{15pt}答案对 P=998\,244\,353 取模。

输入描述:
\hspace{15pt}第一行输入整数 n,m\;(2\leqq n\leqq 10^5,\ 1\leqq m\leqq 10)
\hspace{15pt}随后 m 行,每行两个整数 st_i,ed_i1\leqq st_i<ed_i\leqq n) 描述一条线段。


输出描述:
\hspace{15pt}输出满足条件的方案数对 998244353 取模的结果。
示例1

输入

5 4
4 5
1 5
3 5
1 4

输出

3
状态压缩 (m很小) + 差分 (TreeMap版)
import java.util.*;

// m较小, 状态压缩枚举所有可能 + 差分检查
public class Main {
    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        int n = in.nextInt(), m = in.nextInt();
        int[][] line = new int[m][2];
        for (int i = 0; i < m; i++) {
            line[i][0] = in.nextInt();
            line[i][1] = in.nextInt();
        }

        int res = 0;
        for (int i = 1; i < (1 << m); i++) {
            TreeMap<Integer, Integer> map = new TreeMap<>();
            for (int x = i, k = 0; x > 0; x >>= 1, k++) {
                if ((x & 1) == 1) { // 选:差分 [line[k][0],line[k][1]]区间都+1
                    map.merge(line[k][0], 1, Integer::sum);
                    map.merge(line[k][1] + 1, -1, Integer::sum);
                }
            }
            if (check(map, n)) {
                res++;
            }
        }
        System.out.print(res);
    }

    // 检查[1,n]每个点能否被覆盖至少2次
    private static boolean check(TreeMap<Integer, Integer> map, int n) {
        if (map.firstKey() != 1 || map.lastKey() != n + 1)// 未覆盖[1,n]
            return false;
        int sum = 0;
        for (Map.Entry<Integer, Integer> e : map.entrySet()) {
            sum += e.getValue();
            if (e.getKey() <= n && sum < 2) // 中间次数 < 2
                return false;
        }
        return true;
    }
}


发表于 2025-10-09 17:45:26 回复(0)
使用回溯全排列+贪心,首先对于最多10条线段按左端点排序,然后通过简单回溯获得全排列共最多1024种情况。对于每一种情况,维护两条结果线段,初始两条线段都是[0,0],对于每一个排列中的线段,我如果当前线段无法并入其中的一条,那么他右边的肯定也无法并入,那么肯定这种排列是错的,提前减枝。否则就将这条线段并入右边更短的线段。如果提前满足两条线段都覆盖1~n了,就count++。如果线段用完了还是没有覆盖完,就失败了。最后输出count即可。
import
java.util.*;

public class Main {
    static int count=0;
    public static boolean canCover(int[][] arr,List<Integer> path,int n){
        if(path.size()==0) return false;
        int left1=0;
        int right1=0;
        int left2=0;
        int right2=0;
        for(int num:path){
            int left = arr[num][0];
            int right = arr[num][1];

            //如果当前数字已经无法让right更小的那个满足了,那么直接false
            if(left>right1+1||left>right2+1) return false;
            //当前的两段都可以融入,让right更短的融入
            if(right1<=right2){
                //融入第一段
                right1=Math.max(right1,right);
            }
            else{
                //融入第二段
                right2=Math.max(right2,right);
            }
            if(right1>=n&&right2>=n){
                return true;
            }

        }
        return false;
    }
    public static void backtrack(int[][] arr,int index,List<Integer> path,int n){
        if(index==arr.length){
            //满了
            if(canCover(arr,path,n)){
                count++;
            }
            return;
        }
        //获取全排列
        path.add(index);
        backtrack(arr,index+1,path,n);
        path.remove(path.size()-1);
        backtrack(arr,index+1,path,n);
    }
    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        int n = in.nextInt();
        int m = in.nextInt();
        int[][] arr = new int[m][2];
        for(int i =0;i<m;i++){
            for(int j=0;j<=1;j++){
                arr[i][j]=in.nextInt();
            }
        }
        Arrays.sort(arr,(a,b)->a[0]-b[0]);
        backtrack(arr,0,new ArrayList<>(),n);
        System.out.println(count);
       
    }
}

编辑于 2025-09-05 02:26:54 回复(1)
n, m = map(int, input().split())
line = [[0, 0]]*m
select = [True]*m
for i in range(m):
    sti, edi = map(int, input().split())
    line[i] = [sti, edi]

def check(line):
    cover = [0]*n
    for j in range(len(line)):
        temp = [0]*(line[j][0]-1) + [1]*(line[j][1]-line[j][0]+1) + [0]*(n-line[j][1])
        cover = [temp[i]+cover[i] for i in range(len(temp))]
    return cover

def check_update(state, j, line):
    temp = [0]*(line[j][0]-1) + [1]*(line[j][1]-line[j][0]+1) + [0]*(n-line[j][1])
    cover = [-temp[i]+state[i] for i in range(len(temp))]
    return cover

cnt = [0]
def dfs(i, state, line):
    if i == m:
        cnt[0] = (cnt[0] + 1)%998244353
        return

    if min(state) < 2:
        return
   
    state_ = check_update(state, i, line)
    if min(state_) >= 2:
        dfs(i+1, state_, line)
   
    dfs(i+1, state, line)
   

cover = check(line)
dfs(0, cover, line)
print(cnt[0]%998244353)
发表于 2025-08-14 23:55:38 回复(1)