点集A
是一个Nx3
矩阵,从具有相同大小B
的两个点集C
和Mx3
中,我们可以得到在它们之间的线BC
。现在,我要计算从A
中的每个点到BC
中的每条线的距离。 B
是Mx3
并且C
是Mx3
,则线是从具有对应行的点开始的,因此BC
是Mx3
矩阵。基本方法计算如下:
D = torch.zeros((N, M), dtype=torch.float32)
for i in range(N):
p = A[i] # 1x3
for j in range(M):
p1 = B[j] # 1x3
p2 = C[j] # 1x3
D[i,j] = torch.norm(torch.cross(p1 - p2, p - p1)) / torch.norm(p1 - p2)
有没有更快的方法可以完成这项工作?谢谢。
最佳答案
您可以通过执行以下操作删除for
循环(除非M
和N
很小,否则它应以内存为代价加快速度):
diff_B_C = B - C
diff_A_C = A[:, None] - C
norm_lines = torch.norm(diff_B_C, dim=-1)
cross_result = torch.cross(diff_B_C[None, :].expand(N, -1, -1), diff_A_C, dim=-1)
norm_cross = torch.norm(cross_result, dim=-1)
D = norm_cross / norm_lines
当然,您不需要逐步进行操作。我只是想弄清楚变量名。
注意:如果不向
dim
提供torch.cross
,它将使用第一个dim=3
,如果N=3
(来自docs),则会给出错误的结果:如果未指定dim,则默认为尺寸为3的第一个尺寸。
如果您想知道,可以查看here为什么选择
expand
而不是repeat
。关于python - 如何在PyTorch中计算点集和线之间的成对距离?,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/58660031/