Asked by: Didier · Asked: 11/15/2023 · Updated: 11/15/2023 · Views: 7
Why does this Triton kernel crash?
Q:

I can't find the bug in this code. It is supposed to merge p tensors of size (n, k) so that the merged tensor, also of size (n, k), contains the elements of those p tensors in alternating order, with no duplicates within a row. (Or, if you can suggest other code, in Triton or plain Python, that does the same thing and is efficient in both compute and memory, that works too.) Thanks for your help.
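For example (a concrete reading of the behavior the code below aims for, with p = 2): merging the rows [3, 5, 3] and [5, 7, 9] alternately visits 3, 5, 5, 7, 3, 9; skipping values already taken leaves the merged row [3, 5, 7].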
import torch
import triton
import triton.language as tl

@triton.jit
def _merge_edge_index_kernel(
    edge_index_stacked_ptr,
    num_edges,
    num_edge_types,
    edge_index_merged_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    # One program per row of the (total_size, num_edges, num_edge_types) input.
    i = tl.program_id(axis=0)
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < num_edges
    # Output row, initialized to -1.
    edge_index_merged = tl.full((BLOCK_SIZE,), value=-1, dtype=tl.int32)
    stride = num_edges * num_edge_types
    count = 0
    num_duplicates = 0
    # Linear scan over the flattened row; count + num_duplicates is the read offset.
    while count < num_edges:
        j = tl.load(edge_index_stacked_ptr + i * stride + count + num_duplicates)
        # j is new if it does not already appear in the merged row.
        not_already_in = tl.sum(edge_index_merged == j, axis=0) == 0
        if not_already_in:
            # Place j at position `count` of the output row.
            edge_index_merged = tl.where(offsets == count, j, edge_index_merged)
            count += 1
        else:
            num_duplicates += 1
    tl.store(edge_index_merged_ptr + i * num_edges + offsets, edge_index_merged, mask=mask)
def merge_edge_index(edge_index_stacked: torch.Tensor) -> torch.Tensor:
    """
    Inputs
    ------
    * edge_index_stacked: (sum_i seq_lens[i], num_edges, num_edge_types)

    Output
    ------
    * edge_index_merged: (sum_i seq_lens[i], num_edges)
    """
    assert edge_index_stacked.is_cuda, "edge_index_stacked is not on cuda"
    assert edge_index_stacked.is_contiguous(), "edge_index_stacked is not contiguous"
    total_size, num_edges, num_edge_types = edge_index_stacked.shape
    edge_index_merged = torch.empty_like(edge_index_stacked[..., 0])
    BLOCK_SIZE = triton.next_power_of_2(num_edges)
    grid = (total_size,)
    _merge_edge_index_kernel[grid](
        edge_index_stacked,
        num_edges,
        num_edge_types,
        edge_index_merged,
        BLOCK_SIZE=BLOCK_SIZE,  # type: ignore
    )
    return edge_index_merged
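For comparison, here is a slow pure-PyTorch sketch of what I believe the kernel is meant to compute (the name merge_edge_index_reference and the -1 padding for exhausted rows are my own illustrative choices, not part of the kernel's contract). Running both on the same input and diffing the outputs is one way to see where the kernel diverges.

import torch

def merge_edge_index_reference(edge_index_stacked: torch.Tensor) -> torch.Tensor:
    # Same shapes as merge_edge_index, but computed row by row in plain Python.
    total_size, num_edges, num_edge_types = edge_index_stacked.shape
    out = torch.full(
        (total_size, num_edges), -1,
        dtype=edge_index_stacked.dtype, device=edge_index_stacked.device,
    )
    for row in range(total_size):
        # Flattening the contiguous (num_edges, num_edge_types) row visits the
        # num_edge_types stacked tensors in alternating order, matching the
        # kernel's linear walk over the row.
        seen = set()
        count = 0
        for v in edge_index_stacked[row].reshape(-1).tolist():
            if v not in seen:
                seen.add(v)
                out[row, count] = v
                count += 1
                if count == num_edges:
                    break
        # Unlike the kernel's while loop, this stops cleanly when a row has
        # fewer than num_edges distinct values, leaving -1 padding behind.
    return out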
A: No answers yet