直接使用
请打开基于COMMON_IO使用指南,并点击右上角 “ 在DSW中打开” 。
为了方便读写MaxCompute Table数据,我们基于MaxCompute Tunnel开发了COMMON_IO模块,它提供了TableReader和TableWriter两个接口,使用TableReader可以读取MaxCompute Table中的数据,使用TableWriter可以将数据写入MaxCompute Table,使用这两个接口时需要先在配置文件中配置账户AK等信息,否则无权读写MaxCompute Table。
说明:
- COMMON_IO已在DLC/DSW官方镜像中安装, 暂不支持自定义镜像;
- COMMON_IO适用PyTorch任务读取MaxCompute Table数据场景;
- COMMON_IO适用机器学习任务写MaxCompute Table场景;
1. 准备工作:配置账户信息
配置文件内容格式如下所示,包含了MaxCompute access_id、access_key以及endpoint信息。
access_id/access_key获取方式参见链接。
end_point填入您的MaxCompute项目所在区域对应的Endpoint,可参考链接,例如杭州region endpoint为:http://service.cn-hangzhou.maxcompute.aliyun.com/api
access_id=xxxx
access_key=xxxx
end_point=http://xxxx
在代码中通过以下方式指定配置文件路径
os.environ['ODPS_CONFIG_FILE_PATH'] = '<your MaxCompute config file path>'
2. TableReader使用说明
2.1 接口说明
接口定义
common_io.table.TableReader( table, selected_cols="", excluded_cols="", slice_id=0, slice_count=1)
接口方法
reader.read(num_records=1, allow_smaller_final_batch=False)
- 顺序读取num_records值对应的行数并返回,默认读取1行。当num_records参数超出未读的行数时,返回读取到的所有行。当未读取到记录时,抛出异常(Exception: "End of table reached!")。
- Read读取操作返回一个python数组,数组中每个元素为表的一行数据组成的一个tuple。
reader.start_pos
- 获取读取的表(分片)起始位置
reader.end_pos
- 获取读取的表(分片)结束位置
reader.offset_pos
- 获取正在读取的位置
reader.get_row_count()
- 返回表的行数。如果设置slice_id和slice_count,则返回分片大小
reader.get_schema()
- 获取表的schema
reader.seek(offset=0)
- 定位到相应行,下一个Read操作将从定位的行开始
reader.close()
- 关闭reader
2.2 使用示例
假设在algo_platform_dev项目中存储了一张名为test的表,内容如下所示。
以下代码实现了使用TableReader读取itemid、name及price列的数据。
import os import common_io # 指定配置文件路径 os.environ['ODPS_CONFIG_FILE_PATH'] = "/mnt/workspace/tunnel_io/odps_config.ini" # 打开一个表,返回reader对象 reader = common_io.table.TableReader( "odps://algo_platform_dev/tables/test", selected_cols="itemid,name,price") # 获得表的总行数 total_records_num = reader.get_row_count() print("total_records_num:", total_records_num) batch_size = 2 # 读表,返回值将是一个python数组,形式为[(itemid, name, price)*2] records = reader.read(batch_size) print("records:", records) records = reader.read(batch_size, True) print("records:", records) try: # 继续读取将抛出异常, 原因是数据已全部读取完毕 records = reader.read(batch_size, True) except common_io.exception.OutOfRangeException: pass # 关闭reader reader.close()
total_records_num: 3 records: [(25, 'Apple', 5.0), (38, 'Pear', 4.5)] records: [(17, 'Watermelon', 2.2)]
3. TableWriter使用说明
3.1 接口说明
接口定义
common_io.table.TableWriter( table, slice_id=0)
接口方法
writer.write(values, indices)
- values为需要写入的数据,类型为python数组、或者np.ndarray
- col_indices为写入数据对应的列号,类型为python tuple
writer.close()
- 关闭close,close调用后数据才会真正写入
3.2 使用示例
假设在algo_platform_dev项目中存储了一张名为test的表,一共四列数据,分别为:
itemid(bigint)、name(string)、price(double)、virtual(bool)
下面的示例展示了如何将数据写入test表。
import os import common_io # 指定配置文件路径 os.environ['ODPS_CONFIG_FILE_PATH'] = "/mnt/workspace/tunnel_io/odps_config.ini" # 准备数据 values = [(25, "Apple", 5.0, False), (38, "Pear", 4.5, False), (17, "Watermelon", 2.2, False)] # 打开一个表,返回writer对象 writer = common_io.table.TableWriter("odps://algo_platform_dev/tables/test") # 将数据写至表中的第0-3列 records = writer.write(values, col_indices=[0, 1, 2, 3]) # 关闭writer, 执行close后,数据才会真正写入 writer.close()
4. 最佳实践
4.1 构建pytorch dataset
基于common_io构建dataset示例如下,构建了一个流式dataset。
import os import re import torch import common_io from torch.utils.data import Dataset train_table = "odps://algo_platform_dev/tables/common_io_test" class TableDataset(torch.utils.data.IterableDataset): def __init__(self, table_path, slice_id=0, slice_count=1): self.table_path = table_path reader = common_io.table.TableReader(table_path, slice_id=slice_id, slice_count=slice_count, num_threads=0) self.row_count = reader.get_row_count() self.start_pos = reader.start_pos self.end_pos = reader.end_pos reader.close() super(TableDataset, self).__init__() print("table total_row_count:{}, start_pos:{}, end_pos:{}".format(self.row_count, self.start_pos, self.end_pos)) def __iter__(self): worker_info = torch.utils.data.get_worker_info() if worker_info is None: worker_id = 0 num_workers = 1 else: worker_id = worker_info.id num_workers = worker_info.num_workers print("worker_id:{}, num_workers:{}".format(worker_id, num_workers)) table_start, table_end = self._get_slice_range(self.row_count, worker_id, num_workers, self.start_pos) table_path = "{}?start={}&end={}".format(self.table_path, table_start, table_end) print("table_path:%s" % table_path) def table_data_iterator(): reader = common_io.table.TableReader(table_path, num_threads=1, capacity=1024) while True: try: data = reader.read(num_records=1, allow_smaller_final_batch=True) except common_io.exception.OutOfRangeException: reader.close() break yield data return table_data_iterator() def _get_slice_range(self, row_count, worker_id, num_workers, baseline=0): # div-mod split, each slice data count max diff 1 size = int(row_count / num_workers) split_point = row_count % num_workers if worker_id < split_point: start = worker_id * (size + 1) + baseline end = start + (size + 1) else: start = split_point * (size + 1) + (worker_id - split_point) * size + baseline end = start + size return start, end slice_id = int(os.environ.get('RANK', 0)) slice_count = int(os.environ.get('WORLD_SIZE', 1)) train_dataset = TableDataset(train_table, slice_id, slice_count) train_ld = torch.utils.data.DataLoader( train_dataset, batch_size=3, shuffle=False, pin_memory=False, sampler=None, num_workers=5, collate_fn=lambda x: x ) for data in train_ld: print(data)
5. FAQ
5.1 错误 No such file: /root/.odps_config.ini
该错误表示未找到配置文件,参考使用说明准备工作部分。