MXNet Data Iterator
2021-02-09 20:16
标签:格式 ima 数据流 它的 poc 不能 iterable orm sub 本文先就DataBatch、DataDesc、DataIter三个主要用到的类进行介绍,然后引出Mxnet中常见的迭代器。 MXNet中的数据迭代器Data iterators类似于Python迭代器对象。在Python中,函数iter允许通过对可iterable对象(如Python列表)调用next()按顺序获取项。迭代器提供了一个抽象接口,用于遍历各种类型的iterable集合,而无需公开底层数据源的详细信息。 看看DataBatch类以及他的方法: class 参数: 这个类就是一个批量的样本,每次data iterator调用next(),就会返回一个DataBatch,也即一个批量的样本。如果输入的数据是图像的话,这些图像的shape取决于DataDesc中的provide_data参数: class DataDesc用于存储数据的名字,形状,类型和格式信息。 参数: 方法: 每个训练样本的名称、形状、类型和布局等信息及其相应的标签可以通过DataBatch中的provide_data和provide_label属性作为DataDesc数据描述符对象提供。这里定义了DataDesc的结构。 class 是mxnet中数据迭代器dataiter的基类。mxnet中所有的数据IO都由该类的子类来处理。mxnet中的dataiter迭代器是和python中的iterators很像,每次调用nxet都会返回一个Databatch代表了一个批量中的数据。 参数: 方法: MXNet中的所有IO都通过mx.io.DataIter以及它的子类来处理。本文将讨论MXNet提供的一些常用迭代器。 当所有内置的迭代器不能满足时,可以定制。 mxnet中的迭代器应当满足: 创建新迭代器时,可以从头开始定义迭代器,也可以重用现有迭代器之一。例如,在图像caption应用程序中,输入示例是图像,而标签是句子。因此,我们可以通过以下方法创建新的迭代器: 一个实例: 构建一个mlp: 通过mxnet的module模块来喂入数据。 因为data_iter是迭代器类型,所以可以有get_data()、get_label()、get_index()、next()等方法。 同样因为data_iter.next()返回的是一个DataBatch类型,所以可以有data_iter.next().data、data_iter.next().label等属性。 其余内容见:mxnet 数据读取 MXNet Data Iterator 标签:格式 ima 数据流 它的 poc 不能 iterable orm sub 原文地址:https://www.cnblogs.com/king-lps/p/13057643.htmlMXNet Data Iterator
DataBatch
在MXNet中,数据迭代器在每次调用next()时返回一批数据作为DataBatch。数据批处理通常包含n个训练示例及其相应的标签。这里n是迭代器的批处理大小。在数据流结束时,当没有更多的数据可读取时,迭代器会引发像Python iter那样的StopIteration异常。DataBatch结构在这。mxnet.io.
DataBatch
(data, label=None, pad=None, index=None, bucket_key=None, provide_data=None, provide_label=None)[source]
DataDesc
mxnet.io.
DataDesc
[source]
get_batch_axis
(layout):获取与批处理大小相对应的维度。get_list
(shapes, types):从属性列表中获取DataDesc列表。DataIter
mxnet.io.
DataIter
(batch_size=0)[source]
Data iterators:Mxnet中所有常用的迭代器
import mxnet as mx
%matplotlib inline
import os
import sys
import subprocess
import numpy as np
import matplotlib.pyplot as plt
import tarfile
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
Reading data in memory
import numpy as np
# fix the seed
np.random.seed(42)
mx.random.seed(42)
data = np.random.rand(100,3)
label = np.random.randint(0, 10, (100,))
data_iter = mx.io.NDArrayIter(data=data, label=label, batch_size=30)
for batch in data_iter:
print([batch.data, batch.label, batch.pad])
Reading data from CSV files
#lets save `data` into a csv file first and try reading it back
np.savetxt(‘data.csv‘, data, delimiter=‘,‘)
data_iter = mx.io.CSVIter(data_csv=‘data.csv‘, data_shape=(3,), batch_size=30)
for batch in data_iter:
print([batch.data, batch.pad])
Custom Iterator
1 class SimpleIter(mx.io.DataIter):
2 def __init__(self, data_names, data_shapes, data_gen,
3 label_names, label_shapes, label_gen, num_batches=10):
4 self._provide_data = list(zip(data_names, data_shapes))
5 self._provide_label = list(zip(label_names, label_shapes))
6 self.num_batches = num_batches
7 self.data_gen = data_gen
8 self.label_gen = label_gen
9 self.cur_batch = 0
10
11 def __iter__(self):
12 return self
13
14 def reset(self):
15 self.cur_batch = 0
16
17 def __next__(self):
18 return self.next()
19
20 @property
21 def provide_data(self):
22 return self._provide_data
23
24 @property
25 def provide_label(self):
26 return self._provide_label
27
28 def next(self):
29 if self.cur_batch self.num_batches:
30 self.cur_batch += 1
31 data = [mx.nd.array(g(d[1])) for d,g in zip(self._provide_data, self.data_gen)]
32 label = [mx.nd.array(g(d[1])) for d,g in zip(self._provide_label, self.label_gen)]
33 return mx.io.DataBatch(data, label)
34 else:
35 raise StopIteration
import mxnet as mx
num_classes = 10
net = mx.sym.Variable(‘data‘)
net = mx.sym.FullyConnected(data=net, name=‘fc1‘, num_hidden=64)
net = mx.sym.Activation(data=net, name=‘relu1‘, act_type="relu")
net = mx.sym.FullyConnected(data=net, name=‘fc2‘, num_hidden=num_classes)
net = mx.sym.SoftmaxOutput(data=net, name=‘softmax‘)
print(net.list_arguments())
print(net.list_outputs())
import logging
logging.basicConfig(level=logging.INFO)
n = 32
data_iter = SimpleIter([‘data‘], [(n, 100)],
[lambda s: np.random.uniform(-1, 1, s)],
[‘softmax_label‘], [(n,)],
[lambda s: np.random.randint(0, num_classes, s)])
mod = mx.mod.Module(symbol=net)
mod.fit(data_iter, num_epoch=5)
上一篇:HTML
文章标题:MXNet Data Iterator
文章链接:http://soscw.com/index.php/essay/53241.html