题解 | #小红的 AI 配送聚类优化#

小红的 AI 配送聚类优化

https://www.nowcoder.com/practice/fff828ed40764c94b5ff6fb745dcc1f4

题目链接

小红的 AI 配送聚类优化

题目描述

小红正在开发配送机器人的路径规划系统。系统需要先将 个包裹坐标通过 K-Means 算法聚类为 个核心服务点,然后机器人按顺序访问这些点。

  1. 初始化
    • ,每个包裹直接作为服务点。
    • 否则,按包裹到原点 的欧几里得距离从小到大排序(距离相同保持原序),取前 个作为初始中心,编号
  2. 聚类迭代(最多 50 次):
    • 分配:每个包裹分配给最近的中心。距离相等时分配给编号最小的中心。
    • 更新:中心位置更新为所分配包裹坐标的平均值。若未分配到包裹,坐标保持不变。
    • 终止:若某次更新中所有中心移动的距离之和 ,或达到 50 次,则停止。
  3. 路径规划
    • 将最终的 个服务点按到原点 的距离升序排列。
    • 机器人路径:(0,0)$。
    • 计算总长度 (km)和总耗时(秒),结果向下取整。耗时计算公式:

解题思路

本题是一个完整的 K-Means 算法模拟结合路径几何计算。

  1. K-Means 模拟

    • 使用结构体或类存储坐标
    • 初始化阶段注意排序的稳定性。
    • 迭代过程中,每一轮需要记录旧中心坐标以计算移动距离之和。
    • 计算距离时,分配阶段可以使用距离平方以提高效率,但终止条件必须使用欧几里得距离。
  2. 路径计算

    • 聚类结束后,对 个中心点按到原点的距离进行排序。
    • 依次计算相邻点之间的距离:
    • 注意路径是闭环,包含从原点出发和返回原点的两段。
  3. 数值处理

    • 使用 double 保证精度。
    • 最终耗时使用 floor 或强制类型转换为 long long 实现向下取整。

代码

#include <iostream>
#include <vector>
#include <cmath>
#include <algorithm>

using namespace std;

struct Point {
    double x, y;
    double dist_to_origin() const {
        return sqrt(x * x + y * y);
    }
};

double get_dist(Point a, Point b) {
    return sqrt((a.x - b.x) * (a.x - b.x) + (a.y - b.y) * (a.y - b.y));
}

int main() {
    int k, n;
    double speed;
    cin >> k >> n >> speed;

    vector<Point> packages(n);
    for (int i = 0; i < n; ++i) {
        cin >> packages[i].x >> packages[i].y;
    }

    vector<Point> centers;
    if (k >= n) {
        centers = packages;
    } else {
        vector<pair<Point, int>> sorted_p(n);
        for (int i = 0; i < n; ++i) sorted_p[i] = {packages[i], i};
        stable_sort(sorted_p.begin(), sorted_p.end(), [](const pair<Point, int>& a, const pair<Point, int>& b) {
            return a.first.dist_to_origin() < b.first.dist_to_origin();
        });
        for (int i = 0; i < k; ++i) centers.push_back(sorted_p[i].first);

        for (int iter = 0; iter < 50; ++iter) {
            vector<vector<Point>> clusters(k);
            for (int i = 0; i < n; ++i) {
                int best_idx = 0;
                double min_d2 = 1e18;
                for (int j = 0; j < k; ++j) {
                    double d2 = (packages[i].x - centers[j].x) * (packages[i].x - centers[j].x) +
                                (packages[i].y - centers[j].y) * (packages[i].y - centers[j].y);
                    if (d2 < min_d2 - 1e-11) {
                        min_d2 = d2;
                        best_idx = j;
                    }
                }
                clusters[best_idx].push_back(packages[i]);
            }

            double total_move = 0;
            for (int i = 0; i < k; ++i) {
                if (clusters[i].empty()) continue;
                double sx = 0, sy = 0;
                for (auto& p : clusters[i]) { sx += p.x; sy += p.y; }
                Point next_c = {sx / clusters[i].size(), sy / clusters[i].size()};
                total_move += get_dist(centers[i], next_c);
                centers[i] = next_c;
            }

            if (total_move < 1e-6) break;
        }
    }

    sort(centers.begin(), centers.end(), [](const Point& a, const Point& b) {
        return a.dist_to_origin() < b.dist_to_origin();
    });

    double total_len = 0;
    Point cur = {0, 0};
    for (int i = 0; i < centers.size(); ++i) {
        total_len += get_dist(cur, centers[i]);
        cur = centers[i];
    }
    total_len += get_dist(cur, {0, 0});

    double total_seconds = total_len / speed * 3600.0;
    cout << (long long)(total_seconds + 1e-9) << endl;

    return 0;
}
import java.util.*;

public class Main {
    static class Point {
        double x, y;
        Point(double x, double y) { this.x = x; this.y = y; }
        double distToOrigin() { return Math.sqrt(x * x + y * y); }
    }

    static double getDist(Point a, Point b) {
        return Math.sqrt(Math.pow(a.x - b.x, 2) + Math.pow(a.y - b.y, 2));
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in).useLocale(Locale.US);
        int k = sc.nextInt();
        int n = sc.nextInt();
        double speed = sc.nextDouble();

        Point[] packages = new Point[n];
        for (int i = 0; i < n; i++) packages[i] = new Point(sc.nextDouble(), sc.nextDouble());

        List<Point> centers = new ArrayList<>();
        if (k >= n) {
            centers.addAll(Arrays.asList(packages));
        } else {
            Point[] sorted = packages.clone();
            Arrays.sort(sorted, Comparator.comparingDouble(Point::distToOrigin));
            for (int i = 0; i < k; i++) centers.add(sorted[i]);

            for (int iter = 0; iter < 50; iter++) {
                List<Point>[] clusters = new ArrayList[k];
                for (int i = 0; i < k; i++) clusters[i] = new ArrayList<>();

                for (Point p : packages) {
                    int bestIdx = 0;
                    double minD2 = Double.MAX_VALUE;
                    for (int j = 0; j < k; j++) {
                        double d2 = Math.pow(p.x - centers.get(j).x, 2) + Math.pow(p.y - centers.get(j).y, 2);
                        if (d2 < minD2 - 1e-11) {
                            minD2 = d2;
                            bestIdx = j;
                        }
                    }
                    clusters[bestIdx].add(p);
                }

                double totalMove = 0;
                for (int i = 0; i < k; i++) {
                    if (clusters[i].isEmpty()) continue;
                    double sx = 0, sy = 0;
                    for (Point p : clusters[i]) { sx += p.x; sy += p.y; }
                    Point nextC = new Point(sx / clusters[i].size(), sy / clusters[i].size());
                    totalMove += getDist(centers.get(i), nextC);
                    centers.set(i, nextC);
                }
                if (totalMove < 1e-6) break;
            }
        }

        centers.sort(Comparator.comparingDouble(Point::distToOrigin));
        double totalLen = 0;
        Point cur = new Point(0, 0);
        for (Point p : centers) {
            totalLen += getDist(cur, p);
            cur = p;
        }
        totalLen += getDist(cur, new Point(0, 0));

        System.out.println((long)(totalLen / speed * 3600 + 1e-9));
    }
}
import math

def solve():
    line1 = input().split()
    k, n, speed = int(line1[0]), int(line1[1]), float(line1[2])
    
    packages = []
    for _ in range(n):
        packages.append(list(map(float, input().split())))
        
    def get_dist(p1, p2):
        return math.sqrt((p1[0]-p2[0])**2 + (p1[1]-p2[1])**2)

    if k >= n:
        centers = [p[:] for p in packages]
    else:
        # 初始化:按到原点距离排序,稳定排序
        indexed_p = []
        for i in range(n):
            d = math.sqrt(packages[i][0]**2 + packages[i][1]**2)
            indexed_p.append((d, i, packages[i]))
        
        indexed_p.sort()
        centers = [item[2][:] for item in indexed_p[:k]]
        
        for _ in range(50):
            clusters = [[] for _ in range(k)]
            for p in packages:
                best_idx = 0
                min_d2 = float('inf')
                for j in range(k):
                    d2 = (p[0]-centers[j][0])**2 + (p[1]-centers[j][1])**2
                    if d2 < min_d2 - 1e-11:
                        min_d2 = d2
                        best_idx = j
                clusters[best_idx].append(p)
            
            total_move = 0
            for j in range(k):
                if not clusters[j]:
                    continue
                new_x = sum(p[0] for p in clusters[j]) / len(clusters[j])
                new_y = sum(p[1] for p in clusters[j]) / len(clusters[j])
                move = math.sqrt((new_x - centers[j][0])**2 + (new_y - centers[j][1])**2)
                total_move += move
                centers[j] = [new_x, new_y]
            
            if total_move < 1e-6:
                break
                
    # 路径规划
    centers.sort(key=lambda p: math.sqrt(p[0]**2 + p[1]**2))
    
    total_len = 0
    cur = [0.0, 0.0]
    for p in centers:
        total_len += get_dist(cur, p)
        cur = p
    total_len += get_dist(cur, [0.0, 0.0])
    
    ans = int(total_len / speed * 3600 + 1e-9)
    print(ans)

solve()

算法及复杂度

  • 算法:K-Means 聚类 + 几何路径计算。
  • 时间复杂度:。其中 为迭代次数(最大 50), 为包裹数, 为服务点数。
  • 空间复杂度:。用于存储包裹坐标及其所属簇。
全部评论

相关推荐

在打卡的大老虎很想潜...:你在找实习,没啥实习经历,技术栈放前面,项目多就分两页写,太紧凑了,项目你最多写两个,讲清楚就行,项目背景。用到的技术栈、亮点、难点如何解决,人工智能进面太难了,需求少。你可以加最新大模型的东西
点赞 评论 收藏
分享
评论
1
收藏
分享

创作者周榜

更多
牛客网
牛客网在线编程
牛客网题解
牛客企业服务