题解 | 数据集的批量迭代器

数据集的批量迭代器

https://www.nowcoder.com/practice/17ffa20b827449a0a5a78be08b84de80

import numpy as np

def batch_iterator(X, y=None, batch_size=64):
    n_samples = X.shape[0]
    batches = []
    for i in np.arange(0,n_samples,batch_size):
        begin,end = i,min(i+batch_size,n_samples)
        if y is not None:
            batches.append([X[begin:end],y[begin:end]])
        else:
            batches.append(X[begin:end])

            


    return batches
    
if __name__ == "__main__":
    X = np.array(eval(input()))
    y = np.array(eval(input()))
    batch_size = int(input())
    print(batch_iterator(X, y, batch_size))


全部评论

相关推荐

不愿透露姓名的神秘牛友
05-13 14:16
战争学院:你妈妈第一反应是骗子,我妈妈第一反应是培训贷,全国家长系统是统一的吗哈哈哈
点赞 评论 收藏
分享
点赞 评论 收藏
分享
评论
点赞
收藏
分享

创作者周榜

更多
牛客网
牛客网在线编程
牛客网题解
牛客企业服务