加载数据集

    MindSpore可以加载常见的标准数据集。支持的数据集如下表:

    加载常见数据集的详细步骤如下,以创建CIFAR-10对象为例,用于加载支持的数据集。

    • 下载CIFAR-10数据集,并解压。这里使用的是二进制格式的数据集(CIFAR-10 binary version)。

    • 配置数据集目录,定义需要加载的数据集实例。

    1. DATA_DIR = "cifar10_dataset_dir/"
    2.  
    3. cifar10_dataset = ds.Cifar10Dataset(DATA_DIR)
    • 创建迭代器,通过迭代器读取数据。
    1. for data in cifar10_dataset.create_dict_iterator():
    2. # In CIFAR-10 dataset, each dictionary of data has keys "image" and "label".
    3. print(data["image"])
    4. print(data["label"])

    MindSpore天然支持读取MindSpore数据格式——MindRecord存储的数据集,在性能和特性上有更好的支持。

    • 创建MindDataset,用于读取数据。
    1. CV_FILE_NAME = os.path.join(MODULE_PATH, "./imagenet.mindrecord")
    2. data_set = ds.MindDataset(dataset_file=CV_FILE_NAME)

    其中,dataset_file:指定MindRecord的文件,含路径及文件名。

    • 创建字典迭代器,通过迭代器读取数据记录。
    1. num_iter = 0
    2. print(data["label"])
    3. num_iter += 1

    是华为ModelArts支持的数据格式文件,详细说明请参见:。

    Mindspore对Manifest格式的数据集提供了对应的数据集类。如下所示,配置数据集目录,定义需要加载的数据集实例。

    目前ManifestDataset仅支持加载图片、标签类型的数据集,默认列名为”image”和”label”。

    MindSpore也支持读取TFRecord数据格式的数据集,可以通过TFRecordDataset对象进行数据集读取。

    • 只需传入数据集路径或.tfrecord文件列表,即可创建TFRecordDataset
    1. DATA_DIR = ["tfrecord_dataset_path/train-0000-of-0001.tfrecord"]
    2.  
    3. dataset = ds.TFRecordDataset(DATA_DIR)
    • 用户可以通过创建Schema文件或Schema类,设定数据集格式及特征。

    Schema文件示例如下所示:

    1. {
    2. "datasetType": "TF",
    3. "numRows": 3,
    4. "columns": {
    5. "image": {
    6. "type": "uint8",
    7. "rank": 1
    8. },
    9. "label" : {
    10. "type": "int64",
    11. "rank": 1
    12. }
    13. }
    14. }

    在创建TFRecordDataset时将Schema文件路径传入,使用样例如下:

    1. DATA_DIR = ["tfrecord_dataset_path/train-0000-of-0001.tfrecord"]
    2.  
    3. dataset = ds.TFRecordDataset(DATA_DIR, schema=SCHEMA_DIR)

    创建Schema类使用样例如下:

    1. import mindspore.common.dtype as mstype
    2. schema = ds.Schema()
    3. schema.add_column('image', de_type=mstype.uint8) # Binary data usually use uint8 here.
    4. schema.add_column('label', de_type=mstype.int32)
    5.  
    6. dataset = ds.TFRecordDataset(DATA_DIR, schema=schema)
    • 创建字典迭代器,通过迭代器读取数据。

    对于自定义数据集,可以通过GeneratorDataset对象加载。

    • 定义一个函数(示例函数名为Generator1D)用于生成数据集的函数。

    自定义的生成函数返回的是可调用的对象,每次返回numpy array的元组,作为一行数据。

    自定义函数示例如下:

    1. import numpy as np # Import numpy lib.
    2. def Generator1D():
    3. for i in range(64):
    4. yield (np.array([i]),) # Notice, tuple of only one element needs following a comma at the end.
    • Generator1D传入GeneratorDataset创建数据集,并设定名为“data”。
    1. dataset = ds.GeneratorDataset(Generator1D, ["data"])
    • 在创建数据集后,可以通过给数据创建迭代器的方式,获取相应的数据。有两种创建迭代器的方法。

      • 创建返回值为序列类型的迭代器。
    1. for data in dataset.create_tuple_iterator(): # each data is a sequence
    2. print(data[0])
    • 创建返回值为字典类型的迭代器。
    1. for data in dataset.create_dict_iterator(): # each data is a dictionary