题解 | #小红的智能语音分类器#

小红的智能语音分类器

https://www.nowcoder.com/practice/93a42c433f264776bf06c35bb87842b9

题目链接

小红的智能语音分类器

题目描述

小红正在开发一个基于智能音箱的语音意图识别系统。系统将语音信号转化为三维特征向量,并使用 K最近邻(KNN) 算法进行分类。 对于一个待识别的目标特征向量,分类逻辑如下:

  1. 在样本库中寻找与目标向量欧氏距离最近的 个已知样本。
  2. 统计这 个样本的意图类别标签,取出现频率最高的标签作为预测结果。
  3. 如果存在多个标签出现次数相同且均为最高,则输出其中数值最小的那个标签。

题目保证在第 近的边界上不存在距离相等的样本。

解题思路

本题是标准的 KNN 分类算法实现。

  1. 欧氏距离计算: 目标向量为 ,样本向量为 。 欧氏距离 。 为了简化计算,可以直接使用距离的平方进行比较:

  2. 寻找最近邻: 计算目标向量到样本库中所有 个样本的距离平方,并将距离与样本标签关联存储。 对所有样本按距离平方进行升序排序,取前 个样本。

  3. 投票统计: 遍历这 个最近邻样本,使用映射表(Map)统计每个标签出现的次数。 记录当前出现次数的最大值 和对应的预测标签 。 遍历统计结果:

    • 若当前标签频次 ,更新
    • 若当前标签频次 ,且当前标签数值 ,更新
  4. 复杂度分析

    • 时间复杂度:。计算所有距离需 ,排序需 ,统计 个邻居需
    • 空间复杂度:。用于存储样本数据。

代码

#include <iostream>
#include <vector>
#include <algorithm>
#include <map>
#include <cmath>

using namespace std;

// 定义样本结构体
struct Sample {
    double x, y, z;
    int label;
    double dist_sq;
};

// 排序比较函数:按距离平方升序
bool compareSamples(const Sample& a, const Sample& b) {
    return a.dist_sq < b.dist_sq;
}

int main() {
    int n, k;
    cin >> n >> k;

    vector<Sample> library(n);
    for (int i = 0; i < n; ++i) {
        cin >> library[i].x >> library[i].y >> library[i].z >> library[i].label;
    }

    double tx, ty, tz;
    cin >> tx >> ty >> tz;

    // 计算到目标点的距离平方
    for (int i = 0; i < n; ++i) {
        double dx = library[i].x - tx;
        double dy = library[i].y - ty;
        double dz = library[i].z - tz;
        library[i].dist_sq = dx * dx + dy * dy + dz * dz;
    }

    // 排序找到最近的 K 个
    sort(library.begin(), library.end(), compareSamples);

    // 统计前 K 个样本的标签频次
    map<int, int> counts;
    for (int i = 0; i < k; ++i) {
        counts[library[i].label]++;
    }

    int ans = -1;
    int max_freq = -1;

    // 寻找频次最高且数值最小的标签
    for (auto const& [label, freq] : counts) {
        if (freq > max_freq) {
            max_freq = freq;
            ans = label;
        } else if (freq == max_freq) {
            if (ans == -1 || label < ans) {
                ans = label;
            }
        }
    }

    cout << ans << endl;

    return 0;
}
import java.util.*;

public class Main {
    static class Sample {
        double distSq;
        int label;

        Sample(double distSq, int label) {
            this.distSq = distSq;
            this.label = label;
        }
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int k = sc.nextInt();

        double[][] coords = new double[n][3];
        int[] labels = new int[n];

        for (int i = 0; i < n; i++) {
            coords[i][0] = sc.nextDouble();
            coords[i][1] = sc.nextDouble();
            coords[i][2] = sc.nextDouble();
            labels[i] = sc.nextInt();
        }

        double tx = sc.nextDouble();
        double ty = sc.nextDouble();
        double tz = sc.nextDouble();

        Sample[] samples = new Sample[n];
        for (int i = 0; i < n; i++) {
            double dx = coords[i][0] - tx;
            double dy = coords[i][1] - ty;
            double dz = coords[i][2] - tz;
            double d2 = dx * dx + dy * dy + dz * dz;
            samples[i] = new Sample(d2, labels[i]);
        }

        // 按距离升序排序
        Arrays.sort(samples, (a, b) -> Double.compare(a.distSq, b.distSq));

        // 统计前 K 个邻居的标签频次
        Map<Integer, Integer> freqMap = new HashMap<>();
        for (int i = 0; i < k; i++) {
            int label = samples[i].label;
            freqMap.put(label, freqMap.getOrDefault(label, 0) + 1);
        }

        int maxFreq = -1;
        int resultLabel = Integer.MAX_VALUE;

        for (Map.Entry<Integer, Integer> entry : freqMap.entrySet()) {
            int label = entry.getKey();
            int freq = entry.getValue();
            if (freq > maxFreq) {
                maxFreq = freq;
                resultLabel = label;
            } else if (freq == maxFreq) {
                if (label < resultLabel) {
                    resultLabel = label;
                }
            }
        }

        System.out.println(resultLabel);
    }
}
def solve():
    # 读取 N 和 K
    line1 = input().split()
    n, k = int(line1[0]), int(line1[1])
    
    samples = []
    for _ in range(n):
        parts = list(map(float, input().split()))
        # 前三个是坐标,最后一个是标签
        samples.append((parts[0], parts[1], parts[2], int(parts[3])))
        
    # 读取目标向量
    target = list(map(float, input().split()))
    tx, ty, tz = target[0], target[1], target[2]
    
    # 计算距离平方并排序
    dist_info = []
    for i in range(n):
        sx, sy, sz, label = samples[i]
        d2 = (sx - tx)**2 + (sy - ty)**2 + (sz - tz)**2
        dist_info.append((d2, label))
        
    # 按距离升序排列
    dist_info.sort(key=lambda x: x[0])
    
    # 取前 K 个最近邻并统计标签频次
    freq_map = {}
    for i in range(k):
        label = dist_info[i][1]
        freq_map[label] = freq_map.get(label, 0) + 1
        
    # 寻找出现次数最多且标签值最小的结果
    max_freq = -1
    ans = -1
    
    for label, freq in freq_map.items():
        if freq > max_freq:
            max_freq = freq
            ans = label
        elif freq == max_freq:
            if ans == -1 or label < ans:
                ans = label
                
    print(ans)

solve()

算法及复杂度

  • 算法:K最近邻(KNN)分类。通过计算所有样本到目标点的欧氏距离(或距离平方),排序选取最近的 个样本进行投票。
  • 时间复杂度:。计算 个点的距离为 ,排序为 ,统计为
  • 空间复杂度:。需要存储 个样本的信息。
全部评论

相关推荐

Rac000n:淘天-客户运营部-AI研发工程师,智能客服方向,暑期实习招聘,欢迎联系
点赞 评论 收藏
分享
评论
点赞
收藏
分享

创作者周榜

更多
牛客网
牛客网在线编程
牛客网题解
牛客企业服务