题解 | #【模板】Trie 字典树#

【模板】Trie 字典树

https://www.nowcoder.com/practice/feed1cd7546a4901965751b9fbf5f8a1

题目链接

题目描述

给定 个模式串 次查询。每次给定一个字符串 ,需要统计使得 的前缀的模式串个数(区分大小写,字母范围为大小写英文字母)。

解题思路

  • 用 Trie 逐字符插入每个 ,在每个节点维护通过该节点的计数 (即有多少模式串以该节点路径为前缀)。
  • 回答查询 :沿 Trie 匹配 。若中途不存在边,则答案为 ;否则返回终止节点的
  • 字母表大小为 (先小写后大写或反之均可),用数组下标映射字符以获得线性时间。

代码

#include <bits/stdc++.h>
using namespace std;

static inline int idx(char c) {
    if ('a' <= c && c <= 'z') return c - 'a';
    return 26 + (c - 'A');
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int m, q;
    if (!(cin >> m >> q)) return 0;
    const int SIG = 52;
    vector<array<int, SIG>> nxt(1);
    vector<int> cnt(1, 0);
    nxt[0].fill(-1);

    // 插入所有模式串
    for (int i = 0; i < m; ++i) {
        string s; cin >> s;
        int u = 0;
        for (char ch : s) {
            int c = idx(ch);
            if (nxt[u][c] == -1) {
                nxt[u][c] = (int)nxt.size();
                nxt.push_back({}); nxt.back().fill(-1);
                cnt.push_back(0);
            }
            u = nxt[u][c];
            cnt[u] += 1; // 通过该前缀的模式串计数+1
        }
    }

    // 查询
    while (q--) {
        string t; cin >> t;
        int u = 0; bool ok = true;
        for (char ch : t) {
            int c = idx(ch);
            if (nxt[u][c] == -1) { ok = false; break; }
            u = nxt[u][c];
        }
        cout << (ok ? cnt[u] : 0) << '\n';
    }
    return 0;
}
import java.util.*;

public class Main {
    static int idx(char c) {
        if ('a' <= c && c <= 'z') return c - 'a';
        return 26 + (c - 'A');
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int m = sc.nextInt();
        int q = sc.nextInt();
        final int SIG = 52;
        ArrayList<int[]> nxt = new ArrayList<>();
        ArrayList<Integer> cnt = new ArrayList<>();
        nxt.add(new int[SIG]);
        Arrays.fill(nxt.get(0), -1);
        cnt.add(0);

        for (int i = 0; i < m; i++) {
            String s = sc.next();
            int u = 0;
            for (int iCh = 0; iCh < s.length(); iCh++) {
                int c = idx(s.charAt(iCh));
                if (nxt.get(u)[c] == -1) {
                    int[] node = new int[SIG];
                    Arrays.fill(node, -1);
                    nxt.get(u)[c] = nxt.size();
                    nxt.add(node);
                    cnt.add(0);
                }
                u = nxt.get(u)[c];
                cnt.set(u, cnt.get(u) + 1);
            }
        }

        while (q-- > 0) {
            String t = sc.next();
            int u = 0; boolean ok = true;
            for (int iCh = 0; iCh < t.length(); iCh++) {
                int c = idx(t.charAt(iCh));
                int v = nxt.get(u)[c];
                if (v == -1) { ok = false; break; }
                u = v;
            }
            System.out.println(ok ? cnt.get(u) : 0);
        }
    }
}
import sys

def idx(c: str) -> int:
    if 'a' <= c <= 'z':
        return ord(c) - ord('a')
    return 26 + (ord(c) - ord('A'))

data = sys.stdin.read().strip().split()
it = iter(data)
m = int(next(it)); q = int(next(it))
SIG = 52
nxt = [[-1] * SIG]
cnt = [0]

for _ in range(m):
    s = next(it)
    u = 0
    for ch in s:
        c = idx(ch)
        if nxt[u][c] == -1:
            nxt[u][c] = len(nxt)
            nxt.append([-1] * SIG)
            cnt.append(0)
        u = nxt[u][c]
        cnt[u] += 1

out_lines = []
for _ in range(q):
    t = next(it)
    u = 0
    ok = True
    for ch in t:
        c = idx(ch)
        v = nxt[u][c]
        if v == -1:
            ok = False
            break
        u = v
    out_lines.append(str(cnt[u] if ok else 0))

sys.stdout.write("\n".join(out_lines))

算法及复杂度

  • 算法:Trie 统计节点通过次数,查询时走到对应节点返回计数。
  • 时间复杂度:构建 ,每次查询 ,总计线性于总长度。
  • 空间复杂度
全部评论

相关推荐

还排名这么靠前?
赛博小蟑螂:虽然时间长,但是他工资低啊
投递浪潮等公司10个岗位
点赞 评论 收藏
分享
人间雪:简历最好只要一页,除非你牛逼到一页都写不下了
点赞 评论 收藏
分享
程序员小白条:学历和简历问题,你想走开发,现在很难的啦,尤其后端方向很难走,前端、测开,都会好很多,另外要等8月底和9月初去投日常
点赞 评论 收藏
分享
评论
点赞
收藏
分享

创作者周榜

更多
牛客网
牛客网在线编程
牛客网题解
牛客企业服务