题解 | #【模板】Trie 字典树#
【模板】Trie 字典树
https://www.nowcoder.com/practice/feed1cd7546a4901965751b9fbf5f8a1
题目链接
题目描述
给定 个模式串
与
次查询。每次给定一个字符串
,需要统计使得
为
的前缀的模式串个数(区分大小写,字母范围为大小写英文字母)。
解题思路
- 用 Trie 逐字符插入每个
,在每个节点维护通过该节点的计数
(即有多少模式串以该节点路径为前缀)。
- 回答查询
:沿 Trie 匹配
。若中途不存在边,则答案为
;否则返回终止节点的
。
- 字母表大小为
(先小写后大写或反之均可),用数组下标映射字符以获得线性时间。
代码
#include <bits/stdc++.h>
using namespace std;
static inline int idx(char c) {
if ('a' <= c && c <= 'z') return c - 'a';
return 26 + (c - 'A');
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int m, q;
if (!(cin >> m >> q)) return 0;
const int SIG = 52;
vector<array<int, SIG>> nxt(1);
vector<int> cnt(1, 0);
nxt[0].fill(-1);
// 插入所有模式串
for (int i = 0; i < m; ++i) {
string s; cin >> s;
int u = 0;
for (char ch : s) {
int c = idx(ch);
if (nxt[u][c] == -1) {
nxt[u][c] = (int)nxt.size();
nxt.push_back({}); nxt.back().fill(-1);
cnt.push_back(0);
}
u = nxt[u][c];
cnt[u] += 1; // 通过该前缀的模式串计数+1
}
}
// 查询
while (q--) {
string t; cin >> t;
int u = 0; bool ok = true;
for (char ch : t) {
int c = idx(ch);
if (nxt[u][c] == -1) { ok = false; break; }
u = nxt[u][c];
}
cout << (ok ? cnt[u] : 0) << '\n';
}
return 0;
}
import java.util.*;
public class Main {
static int idx(char c) {
if ('a' <= c && c <= 'z') return c - 'a';
return 26 + (c - 'A');
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int m = sc.nextInt();
int q = sc.nextInt();
final int SIG = 52;
ArrayList<int[]> nxt = new ArrayList<>();
ArrayList<Integer> cnt = new ArrayList<>();
nxt.add(new int[SIG]);
Arrays.fill(nxt.get(0), -1);
cnt.add(0);
for (int i = 0; i < m; i++) {
String s = sc.next();
int u = 0;
for (int iCh = 0; iCh < s.length(); iCh++) {
int c = idx(s.charAt(iCh));
if (nxt.get(u)[c] == -1) {
int[] node = new int[SIG];
Arrays.fill(node, -1);
nxt.get(u)[c] = nxt.size();
nxt.add(node);
cnt.add(0);
}
u = nxt.get(u)[c];
cnt.set(u, cnt.get(u) + 1);
}
}
while (q-- > 0) {
String t = sc.next();
int u = 0; boolean ok = true;
for (int iCh = 0; iCh < t.length(); iCh++) {
int c = idx(t.charAt(iCh));
int v = nxt.get(u)[c];
if (v == -1) { ok = false; break; }
u = v;
}
System.out.println(ok ? cnt.get(u) : 0);
}
}
}
import sys
def idx(c: str) -> int:
if 'a' <= c <= 'z':
return ord(c) - ord('a')
return 26 + (ord(c) - ord('A'))
data = sys.stdin.read().strip().split()
it = iter(data)
m = int(next(it)); q = int(next(it))
SIG = 52
nxt = [[-1] * SIG]
cnt = [0]
for _ in range(m):
s = next(it)
u = 0
for ch in s:
c = idx(ch)
if nxt[u][c] == -1:
nxt[u][c] = len(nxt)
nxt.append([-1] * SIG)
cnt.append(0)
u = nxt[u][c]
cnt[u] += 1
out_lines = []
for _ in range(q):
t = next(it)
u = 0
ok = True
for ch in t:
c = idx(ch)
v = nxt[u][c]
if v == -1:
ok = False
break
u = v
out_lines.append(str(cnt[u] if ok else 0))
sys.stdout.write("\n".join(out_lines))
算法及复杂度
- 算法:Trie 统计节点通过次数,查询时走到对应节点返回计数。
- 时间复杂度:构建
,每次查询
,总计线性于总长度。
- 空间复杂度:
。