题解 | #小红的智能语音分类器#
小红的智能语音分类器
https://www.nowcoder.com/practice/93a42c433f264776bf06c35bb87842b9
题目链接
题目描述
小红正在开发一个基于智能音箱的语音意图识别系统。系统将语音信号转化为三维特征向量,并使用 K最近邻(KNN) 算法进行分类。 对于一个待识别的目标特征向量,分类逻辑如下:
- 在样本库中寻找与目标向量欧氏距离最近的
个已知样本。
- 统计这
个样本的意图类别标签,取出现频率最高的标签作为预测结果。
- 如果存在多个标签出现次数相同且均为最高,则输出其中数值最小的那个标签。
题目保证在第 近的边界上不存在距离相等的样本。
解题思路
本题是标准的 KNN 分类算法实现。
-
欧氏距离计算: 目标向量为
,样本向量为
。 欧氏距离
。 为了简化计算,可以直接使用距离的平方进行比较:
。
-
寻找最近邻: 计算目标向量到样本库中所有
个样本的距离平方,并将距离与样本标签关联存储。 对所有样本按距离平方进行升序排序,取前
个样本。
-
投票统计: 遍历这
个最近邻样本,使用映射表(Map)统计每个标签出现的次数。 记录当前出现次数的最大值
和对应的预测标签
。 遍历统计结果:
- 若当前标签频次
,更新
和
。
- 若当前标签频次
,且当前标签数值
,更新
。
- 若当前标签频次
-
复杂度分析:
- 时间复杂度:
。计算所有距离需
,排序需
,统计
个邻居需
。
- 空间复杂度:
。用于存储样本数据。
- 时间复杂度:
代码
#include <iostream>
#include <vector>
#include <algorithm>
#include <map>
#include <cmath>
using namespace std;
// 定义样本结构体
struct Sample {
double x, y, z;
int label;
double dist_sq;
};
// 排序比较函数:按距离平方升序
bool compareSamples(const Sample& a, const Sample& b) {
return a.dist_sq < b.dist_sq;
}
int main() {
int n, k;
cin >> n >> k;
vector<Sample> library(n);
for (int i = 0; i < n; ++i) {
cin >> library[i].x >> library[i].y >> library[i].z >> library[i].label;
}
double tx, ty, tz;
cin >> tx >> ty >> tz;
// 计算到目标点的距离平方
for (int i = 0; i < n; ++i) {
double dx = library[i].x - tx;
double dy = library[i].y - ty;
double dz = library[i].z - tz;
library[i].dist_sq = dx * dx + dy * dy + dz * dz;
}
// 排序找到最近的 K 个
sort(library.begin(), library.end(), compareSamples);
// 统计前 K 个样本的标签频次
map<int, int> counts;
for (int i = 0; i < k; ++i) {
counts[library[i].label]++;
}
int ans = -1;
int max_freq = -1;
// 寻找频次最高且数值最小的标签
for (auto const& [label, freq] : counts) {
if (freq > max_freq) {
max_freq = freq;
ans = label;
} else if (freq == max_freq) {
if (ans == -1 || label < ans) {
ans = label;
}
}
}
cout << ans << endl;
return 0;
}
import java.util.*;
public class Main {
static class Sample {
double distSq;
int label;
Sample(double distSq, int label) {
this.distSq = distSq;
this.label = label;
}
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
int k = sc.nextInt();
double[][] coords = new double[n][3];
int[] labels = new int[n];
for (int i = 0; i < n; i++) {
coords[i][0] = sc.nextDouble();
coords[i][1] = sc.nextDouble();
coords[i][2] = sc.nextDouble();
labels[i] = sc.nextInt();
}
double tx = sc.nextDouble();
double ty = sc.nextDouble();
double tz = sc.nextDouble();
Sample[] samples = new Sample[n];
for (int i = 0; i < n; i++) {
double dx = coords[i][0] - tx;
double dy = coords[i][1] - ty;
double dz = coords[i][2] - tz;
double d2 = dx * dx + dy * dy + dz * dz;
samples[i] = new Sample(d2, labels[i]);
}
// 按距离升序排序
Arrays.sort(samples, (a, b) -> Double.compare(a.distSq, b.distSq));
// 统计前 K 个邻居的标签频次
Map<Integer, Integer> freqMap = new HashMap<>();
for (int i = 0; i < k; i++) {
int label = samples[i].label;
freqMap.put(label, freqMap.getOrDefault(label, 0) + 1);
}
int maxFreq = -1;
int resultLabel = Integer.MAX_VALUE;
for (Map.Entry<Integer, Integer> entry : freqMap.entrySet()) {
int label = entry.getKey();
int freq = entry.getValue();
if (freq > maxFreq) {
maxFreq = freq;
resultLabel = label;
} else if (freq == maxFreq) {
if (label < resultLabel) {
resultLabel = label;
}
}
}
System.out.println(resultLabel);
}
}
def solve():
# 读取 N 和 K
line1 = input().split()
n, k = int(line1[0]), int(line1[1])
samples = []
for _ in range(n):
parts = list(map(float, input().split()))
# 前三个是坐标,最后一个是标签
samples.append((parts[0], parts[1], parts[2], int(parts[3])))
# 读取目标向量
target = list(map(float, input().split()))
tx, ty, tz = target[0], target[1], target[2]
# 计算距离平方并排序
dist_info = []
for i in range(n):
sx, sy, sz, label = samples[i]
d2 = (sx - tx)**2 + (sy - ty)**2 + (sz - tz)**2
dist_info.append((d2, label))
# 按距离升序排列
dist_info.sort(key=lambda x: x[0])
# 取前 K 个最近邻并统计标签频次
freq_map = {}
for i in range(k):
label = dist_info[i][1]
freq_map[label] = freq_map.get(label, 0) + 1
# 寻找出现次数最多且标签值最小的结果
max_freq = -1
ans = -1
for label, freq in freq_map.items():
if freq > max_freq:
max_freq = freq
ans = label
elif freq == max_freq:
if ans == -1 or label < ans:
ans = label
print(ans)
solve()
算法及复杂度
- 算法:K最近邻(KNN)分类。通过计算所有样本到目标点的欧氏距离(或距离平方),排序选取最近的
个样本进行投票。
- 时间复杂度:
。计算
个点的距离为
,排序为
,统计为
。
- 空间复杂度:
。需要存储
个样本的信息。
查看13道真题和解析