首页 > 试题广场 >

K-Means聚类下的Anchor优化输出

[编程题]K-Means聚类下的Anchor优化输出
  • 热度指数:137 时间限制:C/C++ 1秒,其他语言2秒 空间限制:C/C++ 256M,其他语言512M
  • 算法知识视频讲解
在目标检测任务中,常需为候选框选择一组代表性的 Anchor 尺寸。现给定 N 个矩形框的宽和高,使用基于 IOU 距离的 k-means 聚类得到 K 个 Anchor。初始化时直接取前 K 个框作为初始中心;每轮迭代将每个样本分配给距离最近的中心;随后将每个簇内样本的宽、高分别取均值并向下取整作为新中心。若达到最大迭代次数 T,或新旧中心之间的总“位移”小于 1e-4(用 d=1−IOU 作为中心间距离,并对 K 个中心求和),则停止。最终按 Anchor 面积(宽×高)从大到小输出 K 个中心。

说明与约束

1.距离度量:d = 1 − IOU,其中 IOU = 交集面积 / 并集面积,交集面积 = min(w1,w2) × min(h1,h2),并集面积 = w1×h1 + w2×h2 − 交集面积。
2.所有距离与 IOU 的计算均用浮点;每轮更新后的中心宽、高先取均值再向下取整为整数。
3.若某簇在某轮为空,则该簇中心保持不变。
4.输出前按面积从大到小排序;若面积相同,可按宽、再按高降序作为次序规则。

输入描述:
第一行:N K T(以空格分隔)  
接下来 N 行:每行两个整数 w h,表示一个检测框的宽与高。


输出描述:
输出 K 行:每行两个整数,依次为一个 Anchor 的宽与高,按面积从大到小排序。
示例1

输入

9 3 10
100 50
30 20
10 10
102 49
98 52
29 21
31 19
11 9
9 11

输出

100 50
30 20
10 10

说明

初始中心为 (100,50)、(30,20)、(10,10)。  
分配后每个簇的均值向下取整仍为 (100,50)、(30,20)、(10,10),迭代收敛。  
按面积排序的结果如上。

备注:
本题由牛友@Charles 整理上传
def distance(box1, box2):
    w1, h1 = box1
    w2, h2 = box2
    S1 = min(w1, w2) * min(h1, h2)
    S2 = w1 * h1 + w2 * h2 - S1

    d = 1 - S1 / S2

    return d

class KMeans:
    def __init__(self, N, K, boxes):
        self.N = N
        self.K = K
        self.boxes = boxes

        self.group = [0 for _ in range(N)]
        self.count = [0 for _ in range(K)]

        self.centers = list(boxes[:K])
        return

    def k_means(self, T, max_error=1.0e-4):
        step = 0
        error = 1.0

        while step < T and error > max_error:
            self.create_group()
            centers_cur = self.cal_center()
            error = self.cal_error(centers_cur)
            self.centers = centers_cur
            step += 1
        return

    def cal_error(self, centers_cur):
        error = 0.0
        for ii in range(self.K):
            error += distance(self.centers[ii], centers_cur[ii])
        return error

    def create_group(self):
        count = [0 for _ in range(self.K)]
        for ii in range(self.N):
            d = float('inf')
            for jj in range(self.K):
                d1 = distance(self.boxes[ii], self.centers[jj])
                if d1 < d:
                    d = d1
                    self.group[ii] = jj
            count[self.group[ii]] += 1
        self.count = count
        return

    def cal_center(self):
        centers = self.centers
        for ii in range(self.K):
            if self.count[ii] > 0:
                w = 0
                h = 0
                for jj in range(self.N):
                    if self.group[jj] == ii:
                        w += self.boxes[jj][0]
                        h += self.boxes[jj][1]
                w = w / self.count[ii]
                h = h / self.count[ii]
                centers[ii] = (int(w), int(h))
        return centers

    def print_centers(self):
        Swh = []
        for ii in range(self.K):
            w = self.centers[ii][0]
            h = self.centers[ii][1]
            S = w * h
            Swh.append((S, w, h))
        Swh.sort(reverse=True)
        for ii in range(self.K):
            print(Swh[ii][1], Swh[ii][2])


def read_in():
    N, K, T = map(int, input().split())
    boxes = []
    for _ in range(N):
        boxes.append(tuple(map(int, input().split())))

    boxes = tuple(boxes)
    return N, K, T, boxes

if __name__ == '__main__':
    readin = read_in()
    km = KMeans(readin[0], readin[1], readin[3])
    km.k_means(readin[2])
    km.print_centers()

发表于 2025-10-09 14:43:03 回复(0)
import sys

import numpy as np

def cal_d(A,C):
    h1,w1=A
    h2,w2=C
    inter = min(w1,w2) * min(h1,h2)
    over = w1*h1 + w2*h2 - inter
    return 1- inter/over 

# A=[2,3]
# C=[2,4]
# print(cal_d(A,C))
def sorte_centers_by_areas(center_list):
    areas = [c[0]*c[1] for c in center_list]

    areas = np.array(areas)
    center_list = np.array(center_list)
    new_center_list = []

    while max(areas)>0:
        max_item = center_list[np.where(areas==max(areas))]
        areas[np.where(areas==max(areas))]=0
        
        max_item = list(max_item.flatten())

        new_center_list.append(max_item)
    return new_center_list
# sorte_centers_by_areas([A,C])

def Kmeans(N,K,T,Anchors):
    #initial: centers, S={ci: anchors}, R
    # Anchors = np.array(Anchors)
    
    centers = Anchors[:K]
   
    stop = False
    t=0
    # update while not stop
    while not stop:
        # Remain = list(set(tuple(Anchors))-set(tuple(centers)))
        
        Remain = [ ]
        for i in Anchors:
            if not i in centers:
                Remain.append(i)
        S = dict()
        Dist_matrix=[]
        for i in range(K):
            ci=centers[i]
            S[i]=[ci]

            #对remain里的每个元素计算和ci的距离
            dist_to_ci = [cal_d(a, ci) for a in Remain]
            Dist_matrix.append(dist_to_ci)

        Dist_matrix = np.array(Dist_matrix)
        cls_list = Dist_matrix.argmin(axis=0)
        # print(cls_list)
        

        # get new_centers
        new_centers =[]
        for i in range(K):
            si = list(np.array(Remain)[np.array(cls_list)==i])
            S[i].extend(si)

            new_si = np.array(S[i]).mean(axis=0)
            new_si=list(map(int, new_si))
            new_centers.append(new_si)
        
        # print(new_centers)

        output_center = sorte_centers_by_areas(new_centers)

        # sum_move
        d=0        
        for i in range(K):
            d+= cal_d(centers[i],new_centers[i])
        if d<1e-4:
            stop= True
        t+=1
        if t>=T:
            stop = True
       
    return output_center 

def main():
    # ----------input--------

    line_1 = sys.stdin.readline().strip()
    line_1 = line_1.split(' ')
    N,K,T = list(map(int,line_1))
    
    Anchors =[]
    for i in range(N):
        line = sys.stdin.readline().strip()
        line = line.split(' ')
        anchor = list(map(int,line))
        Anchors.append(anchor)

    # ----------input--------
    # N,K,T=9,3,10
    # Anchors=[[100, 50], [30, 20], [10, 10], [102, 49], [98, 52], [29, 21], [31, 19], [11, 9], [9, 11]]
    centers = Kmeans(N,K,T,Anchors)

    for i in centers:
        print(i[0],i[1])
main()

# 9 3 10
# 100 50
# 30 20
# 10 10
# 102 49
# 98 52
# 29 21
# 31 19
# 11 9
# 9 11

发表于 2025-10-10 21:04:09 回复(0)