实现一个批量迭代器函数,用于将数据集分批处理。这在深度学习中特别有用,可以控制内存使用并实现小批量梯度下降等算法。
输入描述:
函数`batch_iterator`接收三个参数:1. X:特征数据,numpy数组2. y:标签数据(可选),numpy数组3. batch_size:批量大小,正整数,默认64


输出描述:
返回一个列表,包含所有批次:- 如果只有X,每个批次是X的子数组- 如果有y,每个批次是[X子数组, y子数组]的列表- 最后一个批次可能小于batch_size
示例1

输入

[[1,2], [3,4], [5,6], [7,8]]
[0, 1, 0, 1]
2

输出

[[array([[1, 2],
       [3, 4]]), array([0, 1])], [array([[5, 6],
       [7, 8]]), array([0, 1])]]

说明

    

备注:
1.对应的输入、输出已给出,您只用实现核心功能函数即可。2.支持numpy、scipy、pandas、scikit-learn库。
加载中...