题解 | 数据集的批量迭代器
数据集的批量迭代器
https://www.nowcoder.com/practice/17ffa20b827449a0a5a78be08b84de80
import numpy as np
def batch_iterator(X, y=None, batch_size=64):
    """Split a dataset into consecutive mini-batches along its first axis.

    Samples are taken in their original order. The final batch holds the
    remainder and may therefore be smaller than ``batch_size``.

    Args:
        X: Samples, indexed along the first axis (read via ``X.shape[0]`` and
            sliced with ``X[begin:end]``, e.g. a NumPy array).
        y: Optional targets aligned with ``X`` sample-for-sample.
        batch_size: Maximum number of samples per batch; must be >= 1.

    Returns:
        A list of batches. Each batch is the two-element list
        ``[X_batch, y_batch]`` when ``y`` is given, otherwise just ``X_batch``.
        An ``X`` with zero samples yields ``[]``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1, or if ``y`` does not
            contain the same number of samples as ``X``.
    """
    if batch_size < 1:
        # A negative step would otherwise produce no batches at all (silently
        # discarding the data) and a zero step fails with an obscure error.
        raise ValueError(
            f"batch_size must be a positive integer, got {batch_size!r}"
        )
    n_samples = X.shape[0]
    if y is not None and len(y) != n_samples:
        # Without this check, mismatched inputs give truncated or misaligned
        # target batches with no error.
        raise ValueError(
            "X and y must have the same number of samples, "
            f"got {n_samples} and {len(y)}"
        )

    batches = []
    # Builtin range yields plain ints for the slice bounds.
    for begin in range(0, n_samples, batch_size):
        end = min(begin + batch_size, n_samples)
        if y is None:
            batches.append(X[begin:end])
        else:
            batches.append([X[begin:end], y[begin:end]])
    return batches
if __name__ == "__main__":
    # Script entry point: read the feature matrix, the targets and the batch
    # size from three consecutive stdin lines, then print the batches.
    # NOTE(review): eval() executes arbitrary code from stdin. If the input is
    # only ever plain list literals, ast.literal_eval would be a safer choice.
    features = np.array(eval(input()))
    labels = np.array(eval(input()))
    size = int(input())
    result = batch_iterator(features, labels, batch_size=size)
    print(result)
查看10道真题和解析
