题解 | #二分 K-means子网分割#
二分 K-means子网分割
http://www.nowcoder.com/questionTerminal/308edb1f38464124ba30197b48ebc3cf
import math
# If you need to import additional packages or classes, please import here.
N = int(input().strip())  # desired number of sub-networks
M = int(input().strip())  # total number of sites in the whole network


def _read_site():
    """Read one "x y" line from stdin and return it as an (int, int) tuple."""
    x_text, y_text = input().strip().split()
    return int(x_text), int(y_text)


# net_list[k] is the (x, y) coordinate tuple of site k, in input order.
net_list = [_read_site() for _ in range(M)]
def distance(x1, x2):
    """Return the *squared* Euclidean distance between 2-D points x1 and x2.

    No square root is taken. Either argument may be a tuple or a list whose
    first two items are the x and y coordinates.
    """
    dx = x1[0] - x2[0]
    dy = x1[1] - x2[1]
    return dx ** 2 + dy ** 2
def kmeans(nodes_id, max_iter=1000, tol=1e-6):
    """Split the sites listed in ``nodes_id`` into two clusters (Lloyd's 2-means).

    The initial centers are the lexicographically smallest and largest (x, y)
    coordinates among the sites, so the result is deterministic. Iteration
    stops after ``max_iter`` rounds, or once the summed movement of the two
    centers (as measured by ``cal_error``) is no longer greater than ``tol``.

    Args:
        nodes_id: non-empty list of indices into ``net_list``.
        max_iter: maximum number of assign/update rounds (default 1000).
        tol: convergence threshold on the summed center shift (default 1e-6).

    Returns:
        ``(group, centers)``: ``group`` is ``[ids_0, ids_1]`` and ``centers`` the
        matching centers. ``ids_1`` can be empty, e.g. when all sites coincide,
        because ties are assigned to cluster 0.

    Raises:
        ValueError: if ``nodes_id`` is empty.
    """
    if not nodes_id:
        raise ValueError("kmeans() needs at least one site")
    points = [net_list[node] for node in nodes_id]
    # min/max on (x, y) tuples yield the same seeds as sorting and taking both
    # ends, but in O(n) instead of O(n log n).
    centers = [min(points), max(points)]
    group = [[], []]
    # Start "infinitely far from converged" so at least one round always runs,
    # whatever tol the caller passes.
    error = float('inf')
    step = 0
    while step < max_iter and error > tol:
        group = cal_group(nodes_id, centers)
        centers_new = cal_centers(group)
        error = cal_error(centers, centers_new)
        centers = centers_new
        step += 1
    return group, centers
def cal_error(centers, centers_new):
    """Return how far the two cluster centers moved in one k-means round.

    The result is the sum of the Euclidean shifts of center 0 and center 1.
    """
    shift_first = math.sqrt(distance(centers[0], centers_new[0]))
    shift_second = math.sqrt(distance(centers[1], centers_new[1]))
    return shift_first + shift_second
def cal_centers(group, points=None):
    """Return the centroid of every cluster in ``group``.

    Args:
        group: list of clusters, each a list of indices into ``points``.
        points: coordinate list to index into; defaults to the module-level
            ``net_list``.

    Returns:
        One ``[x, y]`` mean per cluster, in the same order as ``group``. An empty
        cluster gets the placeholder ``[0, 0]``.
    """
    if points is None:
        points = net_list
    centers = []
    for members in group:
        if not members:
            centers.append([0, 0])  # placeholder for an empty cluster
            continue
        count = len(members)
        centers.append([
            sum(points[idx][0] for idx in members) / count,
            sum(points[idx][1] for idx in members) / count,
        ])
    return centers
def cal_group(nodes_id, centers):
    """Assign each site to the nearer of the two centers.

    Returns ``[ids_0, ids_1]``, where each list keeps the order of ``nodes_id``.
    A site equidistant from both centers goes to cluster 0.
    """
    buckets = [[], []]
    for node in nodes_id:
        point = net_list[node]
        side = 1 if distance(point, centers[0]) > distance(point, centers[1]) else 0
        buckets[side].append(node)
    return buckets
def _split_cluster(groups, centers, idx):
    """Return new lists in which cluster ``idx`` is replaced by its 2-means halves.

    The halves are appended at the end. An empty half (the cluster could not
    be split) is dropped. The input lists are not modified.
    """
    new_groups = groups[:idx] + groups[idx + 1:]
    new_centers = centers[:idx] + centers[idx + 1:]
    halves, half_centers = kmeans(groups[idx])
    for half, half_center in zip(halves, half_centers):
        if half:
            new_groups.append(half)
            new_centers.append(half_center)
    return new_groups, new_centers


def bin_kmeans(groups, centers):
    """Perform one bisecting k-means step.

    Every current cluster is tentatively split in two. The split giving the
    lowest total SSE (``cal_SSE``) is kept; on ties the earliest cluster wins.

    Args:
        groups: list of clusters, each a list of indices into ``net_list``.
        centers: cluster centers, parallel to ``groups``.

    Returns:
        ``(groups, centers)`` as new lists; the inputs are not modified. If no
        cluster can actually be split, the result has the same number of
        clusters as the input. Empty ``groups`` yields empty copies.
    """
    best_split = (list(groups), list(centers))
    best_sse = float('inf')
    for idx in range(len(groups)):
        # Keep each trial result so the winner need not be recomputed afterwards
        # (kmeans is deterministic, so re-running it would give the same split).
        candidate = _split_cluster(groups, centers, idx)
        candidate_sse = cal_SSE(*candidate)
        if candidate_sse < best_sse:
            best_sse = candidate_sse
            best_split = candidate
    return best_split
def cal_SSE(groups, centers):
    """Return the total within-cluster sum of squared errors.

    ``centers[k]`` is the center used for ``groups[k]``.
    """
    return sum(
        dist_in_group(members, centers[idx])
        for idx, members in enumerate(groups)
    )
def dist_in_group(group, center):
    """Return the sum of squared distances from each site in ``group`` to ``center``."""
    return sum(distance(net_list[node], center) for node in group)
def _print_sizes(groups):
    """Print the cluster sizes in descending order, separated by single spaces."""
    sizes = sorted((len(members) for members in groups), reverse=True)
    print(' '.join(str(size) for size in sizes))


def func():
    """Run bisecting k-means over all sites and print cluster sizes after each split.

    Reads the module-level ``N`` (target number of sub-networks), ``M`` (number
    of sites) and ``net_list``. When ``N == 1`` only ``M`` is printed. Otherwise
    the sizes are printed after the initial 2-way split and after every further
    bisection, until ``N`` clusters exist or no cluster can be split any more
    (for example when the remaining clusters consist of coincident sites).
    """
    if N == 1:
        print(M)
        return
    groups, centers = kmeans(list(range(M)))
    # kmeans leaves one half empty when every site coincides; drop it so no
    # zero-size sub-network is reported (empty halves are never split again).
    kept = [(members, center) for members, center in zip(groups, centers) if members]
    groups = [members for members, _ in kept]
    centers = [center for _, center in kept]
    _print_sizes(groups)
    while len(groups) < N:
        groups_next, centers_next = bin_kmeans(groups, centers)
        if len(groups_next) <= len(groups):
            # No cluster could be split further; continuing would loop forever.
            break
        groups, centers = groups_next, centers_next
        _print_sizes(groups)
if __name__ == "__main__":
    # Entry point: run the clustering and printing when executed as a script.
    func()

