Why does this Triton kernel crash?

Asked by Didier · Asked 11/15/2023 · Modified 11/15/2023 · Viewed 7 times

Q:

I can't find the bug in this code. It is supposed to merge p tensors of size (n, k) into a single tensor of size (n, k) containing the alternating elements of those p tensors, with no duplicates in each row. (Alternatively, I would welcome other code, in Triton or plain Python, that does the same thing and is efficient in both compute and memory.) Thanks for your help.

import torch
import triton
import triton.language as tl


@triton.jit
def _merge_edge_index_kernel(
    edge_index_stacked_ptr,  # (total, num_edges, num_edge_types), contiguous
    num_edges,
    num_edge_types,
    edge_index_merged_ptr,   # (total, num_edges)
    BLOCK_SIZE: tl.constexpr,
):
    i = tl.program_id(axis=0)  # one program instance per row

    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < num_edges

    # Merged row, initialised with -1 as an "empty slot" sentinel.
    edge_index_merged = tl.full((BLOCK_SIZE,), value=-1, dtype=tl.int32)
    stride = num_edges * num_edge_types  # elements per flattened row

    # Linear scan over the flattened row: `count` counts unique values
    # kept so far, `num_duplicates` counts values that were skipped.
    count = 0
    num_duplicates = 0
    while count < num_edges:
        j = tl.load(edge_index_stacked_ptr + i * stride + count + num_duplicates)
        # True iff j does not already appear in the merged row.
        not_already_in = tl.sum(edge_index_merged == j, axis=0) == 0
        if not_already_in:
            # Write j into slot `count` of the merged row.
            edge_index_merged = tl.where(offsets == count, j, edge_index_merged)
            count += 1
        else:
            num_duplicates += 1

    tl.store(edge_index_merged_ptr + i * num_edges + offsets, edge_index_merged, mask=mask)


def merge_edge_index(edge_index_stacked: torch.Tensor) -> torch.Tensor:
    """
    Inputs
    ------
        * edge_index_stacked: (sum_i seq_lens[i], num_edges, num_edge_types)

    Output
    ------
        * edge_index_merged: (sum_i seq_lens[i], num_edges)
    """

    assert edge_index_stacked.is_cuda, "edge_index_stacked is not on cuda"
    assert edge_index_stacked.is_contiguous(), "edge_index_stacked is not contiguous"

    total_size, num_edges, num_edge_types = edge_index_stacked.shape

    edge_index_merged = torch.empty_like(edge_index_stacked[..., 0])
    BLOCK_SIZE = triton.next_power_of_2(num_edges)

    grid = (total_size,)
    _merge_edge_index_kernel[grid](
        edge_index_stacked,
        num_edges,
        num_edge_types,
        edge_index_merged,
        BLOCK_SIZE=BLOCK_SIZE,  # type: ignore
    )

    return edge_index_merged
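For reference, the per-row operation the kernel is trying to perform can be sketched in plain Python. The `merge_row` helper below is hypothetical (not part of the code above); like the kernel's while loop, it assumes each flattened row contains at least `num_edges` distinct values, otherwise the scan would run past the end of the row:

```python
def merge_row(row, num_edges):
    """Scan the flattened row left to right, keeping the first
    occurrence of each value, until num_edges values are kept."""
    seen = set()
    merged = []
    for v in row:
        if v not in seen:
            seen.add(v)
            merged.append(v)
            if len(merged) == num_edges:
                break
    return merged


# A row of num_edge_types = 2 interleaved edge lists, num_edges = 3:
print(merge_row([5, 5, 7, 9, 7, 2], 3))  # → [5, 7, 9]
```

Applying this helper to every row of the reshaped `(total_size, num_edges * num_edge_types)` tensor would give the same result the kernel is meant to compute, at the cost of a Python-level loop.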
graph pytorch tensor triton

Comments


A: No answers yet