首页 > 试题广场 >

均衡版 KMeans 分群与新用户归类

[编程题]均衡版 KMeans 分群与新用户归类
  • 热度指数:171 时间限制:C/C++ 1秒,其他语言2秒 空间限制:C/C++ 256M,其他语言512M
  • 算法知识视频讲解

某电商平台需要把 N 位老客户按其 M 维非负整数特征划分为 K 个群组(2 ≤ K ≤ min(20, N))。为避免资源倾斜,要求每个群组的容量严格均衡:每组人数为 N//K 或 N//K+1,多出来的人数依次补给“中心编号较小”的群组。你需要实现一个“按顺序分配 + 均衡容量 + 中心取整”的 KMeans 变体,并用最终的中心点将一个新客户归到最近的中心。

算法规则

1) 初始中心为输入的前 K 位客户的特征。
2) 每一轮分配按客户输入顺序从 1 到 N 顺序处理。对每个客户:

  • 计算其到每个中心的欧氏距离(可用“平方和”比较,无需开方)。
  • 在“尚未满员”的中心里选择距离最小者;若有距离并列,取中心编号更小的。
  • 每个中心的容量固定:前 N%K 个中心容量为 N//K+1,其余为 N//K。
    3) 一轮分配完成后,更新每个中心为该组所有成员的逐维均值向下取整(floor)。
    4) 若“本轮的分配结果和中心”与上一轮完全一致,则停止。
    5) 输出时先将最终中心按字典序(先比第 1 维,再比第 2 维,依此类推)升序排序;随后给定新客户特征,计算他到“已排序中心”的距离,归到最近的中心;若有并列,选择字典序最小的中心。输出该中心在“排序后列表”中的序号(从 1 开始)。

输入描述:
  • 第 1 行:N M K
  • 第 2 ~ N+1 行:每行 M 个非负整数,表示一位老客户的特征
  • 第 N+2 行:M 个非负整数,表示新客户的特征


输出描述:
  • 先输出 K 行:排序后的 K 个中心(每行 M 个整数)
  • 再输出 1 行:新客户所在中心在排序后列表中的序号(从 1 开始)
示例1

输入

4 1 2
0
10
9
11
8

输出

5
9
2

说明

1.按“容量均衡 + 顺序分配”规则,4 个点分到两组容量各 2:{0,11} 与 {10,9}。  
2.组中心为各组均值下取整:{5,9},再次分配不变,收敛。  
3.新点 8 到中心 5、9 的距离分别为 9 和 1,选 9,排序后位次为 2。


备注:
  • 可用“平方距离”比较代替真实欧氏距离,顺序不变。
  • 均衡容量固定不变:前 (N mod K) 个中心容量为 N//K+1,其余为 N//K。
  • 所有并列均按“编号小/字典序小”打破平局。
好折磨的大模拟。。
#include "bits/stdc++.h"
using namespace std;

int n,m,k,preLimit,otherLimit,n_mod_k;

struct guest{
    vector<int> features;
};

struct group{
    int centerID, id, limit = 0;
    vector<int> guestID;
    vector<int> center, preCenter;
    // 按center排序
    bool operator<(const group& o) const{
        return center < o.center;
    }
};

vector<guest> guests;
vector<group> groups;
vector<vector<int>> centers; // 存储中心

long long calcDis(vector<int> vector1, vector<int> &vector2) {
    long long res = 0;
    for(int i = 0; i < vector1.size(); i++){
        res += (vector1[i] - vector2[i]) * (vector1[i] - vector2[i]);
    }
    return res;
}

int main(){
    cin >> n >> m >> k;
    guests.resize(n);
    groups.resize(k);
    n_mod_k = n % k;
    if (n_mod_k == 0){
        preLimit = n / k;
        otherLimit = n / k;
    }else{
        preLimit = n / k + 1;
        otherLimit = n / k;
    }
    // 初试化
    for(int i = 0; i < n; i++){
        guests[i].features.resize(m);
        for(int j = 0; j < m; j++){
            cin >> guests[i].features[j]; // 保存特征
        }
        if (i < k) {
            groups[i].centerID = i;
            groups[i].guestID.push_back(i);
            groups[i].id = i;
            groups[i].preCenter = guests[i].features;
            groups[i].center = guests[i].features;
            //guests[i].centerID = i;
            if (i+1 <= n_mod_k) groups[i].limit = preLimit;
            else groups[i].limit = otherLimit;
            centers.push_back(guests[i].features);
        }else { // 为当前顾客找个中心组
            long long dis = 1e18;
            int tempGID = -1, preCenterID = -1;
            for(int j = 0; j < k; j++){
                group curG = groups[j];
                // 人数已满
                if (curG.guestID.size() >= curG.limit) continue;
                vector<int> v = guests[curG.centerID].features;
                long long temp = calcDis(guests[i].features, v);
                if (temp < dis){
                    dis = temp;
                    tempGID = curG.id;
                    preCenterID = groups[j].centerID;
                }else if(temp == dis){
                    if (preCenterID > groups[j].centerID){
                        tempGID = curG.id;
                        preCenterID = groups[j].centerID;
                    }
                }
            }
            // 将顾客加入当前组
            groups[tempGID].guestID.push_back(i);
        }
    }
    // 更新每组的中心
    while(true) {
        // 清空组中的顾客
        for(int i = 0; i < k; i++){
            groups[i].guestID.clear();
        }
        int count = 0;
        // 按顺序给顾客重新分配组
        for(int ii = 0; ii < n; ii++){
            vector<int> curGuest = guests[ii].features;
            long long dis = 1e18;
            int tempGID = -1, preCenterID = -1;
            for(int j = 0; j < k; j++){
                // 人数已满
                if (groups[j].guestID.size() >= groups[j].limit) continue;
                long long temp = calcDis(groups[j].center, curGuest);
                if (temp < dis){
                    dis = temp;
                    tempGID = groups[j].id;
                    preCenterID = groups[j].centerID;
                }else if(temp == dis){
                    if (preCenterID > groups[j].centerID){
                        tempGID = groups[j].id;
                        preCenterID = groups[j].centerID;
                    }
                }
            }
            // 将顾客加入当前组
            groups[tempGID].guestID.push_back(ii);
        }

        // 更新中心维度
        for (int i = 0; i < k; i++) {
            //group curG = groups[i];
            vector<int> vID = groups[i].guestID;
            // 更新m维特征
            for(int j = 0; j < m; j++){
                double tmp = 0.0;
                int sz = vID.size();
                for (int mm = 0; mm < sz; ++mm) {
                    vector<int> fe = guests[vID[mm]].features;
                    tmp += fe[j];
                }
                groups[i].center[j] = floor(tmp / sz);
            }
            if (!equal(groups[i].center.begin(),groups[i].center.end(),groups[i].preCenter.begin(),groups[i].preCenter.end())){
                groups[i].preCenter = groups[i].center;
            }else{
                count++;
            }
        }
        // 更新停止
        if (count == k) break;
    }

    // 将最终中心按字典序排序
    sort(groups.begin(),groups.end());
    // 读取最后一个新顾客的特征
    guest newGuest;
    newGuest.features.resize(m);
    for(int i = 0; i < m; i++){
        cin >> newGuest.features[i];
    }
    long long dis = 1e18;
    int tempID = -1;
    for(int j = 0; j < k; j++){
        group curG = groups[j];
        vector<int> v = newGuest.features;
        long long temp = calcDis(curG.center, v);
        if (temp < dis){
            dis = temp;
            tempID = j;
        }
    }
    // 输出中心
    for(int i = 0; i < k; i++){
        group curG = groups[i];
        vector<int> cen = curG.center;
        for(int j = 0; j < m; j++){
            cout << cen[j];
            if (j < m-1) cout << " ";
        }
        cout << "\n";
    }
    // 输出序号
    tempID++;
    cout << tempID << "\n";
    return 0;
}

发表于 2026-01-11 17:47:37 回复(0)