洗牌数据后 xarray rechunk

xarray rechunk after shuffling data

提问人:danny 提问时间:11/7/2023 最后编辑:danny 更新时间:11/7/2023 访问量:27

问:

我正在尝试将 xarray 与 tensorflow 一起使用。在每个纪元结束后,数据将沿时间维度进行随机排序。netcdf 文件具有 .我在时间维度上使用 1024 的块大小,它有大约 400k 个值。我遇到的问题是,在洗牌后,数据管道的性能会显着下降。我尝试在数据洗牌后设置块大小并重新分块,但似乎没有帮助。另一方面,如果我将洗牌数据写入磁盘,然后以 1024 的块大小再次读回,则它的处理速度与未洗牌的数据一样快。那么,与从磁盘再次读取相比,我错过了什么?这是我正在使用的代码(time,lat,lon,level)DataArray

    def on_epoch_end(self):
        "Shuffle dataset at the end of epoch"
        if self.shuffle == True:
            # Get the Dask array containing the data values
            dask_data = self.data.data

            # Create a shuffled index array along the 'time' dimension
            shuffled_indices = da.random.permutation(
                dask_data.shape[0]
            )

            # Use Dask delayed computation to perform the shuffling
            shuffled_data = dask_data[shuffled_indices, :,  :, :]

            # shuffled_data = da.rechunk(shuffled_data, chunks={0: 1024})

            # Create a new DataArray with the shuffled data
            self.data = xr.DataArray(
                shuffled_data, coords=self.data.coords, dims=self.data.dims
            )
            self.data = self.data.chunk({"time": 1024})
    
            ## save data to file and read it agan
            #self.data.to_netcdf("save.nc")
            #ds = xr.open_dataset(
            #    "save.nc",
            #    chunks={"time": 1024},
            #)
            #first_variable_name = list(ds.variables)[4]
            #self.data = ds[first_variable_name]

在此输入验证码

TensorFlow python-xarray netcdf

评论


答: 暂无答案