我在Matlab中有以下四个嵌套循环:
timesteps = 5;
inputsize = 10;
additionalinputsize = 3;
outputsize = 7;
input = randn(timesteps, inputsize);
additionalinput = randn(timesteps, additionalinputsize);
factor = randn(inputsize, additionalinputsize, outputsize);
output = zeros(timesteps,outputsize);
for t=1:timesteps
for i=1:inputsize
for o=1:outputsize
for a=1:additionalinputsize
output(t,o) = output(t,o) + factor(i,a,o) * input(t,i) * additionalinput(t,a);
end
end
end
end
有三个向量:一个输入向量,一个附加输入向量和一个输出向量。所有的都是由因素连接的。每个向量在给定的时间步都有值。我需要每个给定时间步长的所有组合输入,其他输入和因子的总和。稍后,我需要从输出到输入进行计算:
result2 = zeros(timesteps,inputsize);
for t=1:timesteps
for i=1:inputsize
for o=1:outputsize
for a=1:additionalinputsize
result2(t,i) = result2(t,i) + factor(i,a,o) * output(t,o) * additionalinput(t,a);
end
end
end
end
在第三种情况下,我需要在每个时间步求和所有三个向量的乘积:
product = zeros(inputsize,additionalinputsize,outputsize)
for t=1:timesteps
for i=1:inputsize
for o=1:outputsize
for a=1:additionalinputsize
product(i,a,o) = product(i,a,o) + input(t,i) * output(t,o) * additionalinput(t,a);
end
end
end
end
这两个代码段可以工作,但是速度却非常慢。如何删除嵌套循环?
编辑:增加了值并更改了次要的东西,因此代码片段是可执行的
Edit2:添加了其他用例
最佳答案
第一部分
一种方法-
t1 = bsxfun(@times,additionalinput,permute(input,[1 3 2]));
t2 = bsxfun(@times,t1,permute(factor,[4 2 1 3]));
t3 = permute(t2,[2 3 1 4]);
output = squeeze(sum(sum(t3)));
或略作改动以避免
squeeze
-t1 = bsxfun(@times,additionalinput,permute(input,[1 3 2]));
t2 = bsxfun(@times,t1,permute(factor,[4 2 1 3]));
t3 = permute(t2,[1 4 2 3]);
output = sum(sum(t3,3),4);
第二部分
t11 = bsxfun(@times,additionalinput,permute(output,[1 3 2]));
t22 = bsxfun(@times,permute(t11,[1 4 2 3]),permute(factor,[4 1 2 3]));
result2=sum(sum(t22,3),4);
第三方
t11 = bsxfun(@times,permute(output,[4 3 2 1]),permute(additionalinput,[4 2 3 1]));
t22 = bsxfun(@times,permute(input,[2 4 3 1]),t11);
product = sum(t22,4);