function c = tensor_mul(a, b, ka, kb)

% tensor_mul -- Tensor multiplication.
%  tensor_mul(a, b, ka, kb) returns the tensor-product a*b,
%   summed along the respective ka and kb directions, which
%   must have the same length.  If the directions are not
%   specified, the innermost ones are assumed, i.e. the last
%   direction of a and the first direction of b, as in an
%   ordinary matrix product.
%
%   a, b    Arrays to be multiplied.
%   ka, kb  Single positive integers naming the direction of a
%            and of b to be summed over.  Each must not exceed
%            the number of directions reported by size().
%
%   The size of the result is the size of a with direction ka
%   removed, followed by the size of b with direction kb removed.
%   A summed direction of length zero yields an all-zero result.
%
%   If a direction is invalid, or the two directions differ in
%   length, a warning is issued and [] is returned (when an
%   output is requested).
%  tensor_mul('demo') demonstrates itself with two 2-by-2
%   random arrays.
%
% Also see: tensor_sol, tensor_test.

% Copyright (C) 1998 Dr. Charles R. Denham, ZYDECO.
%  All Rights Reserved.
%   Disclosure without explicit written consent from the
%    copyright owner does not constitute publication.

% Version of 16-Sep-1998 05:34:15.

% Pre-assign the output only when one is requested, so that the
%  demo and the early-return paths do not display "ans = []".
if nargout > 0, c = []; end

if nargin < 1, help(mfilename), a = 'demo'; end

% Use strcmp rather than isequal: isequal compares values without
%  regard to class, so a numeric array holding the character codes
%  of 'demo' would otherwise be mistaken for the demo request.
if strcmp(a, 'demo')
    a = floor(rand(2, 2) * 10);
    b = floor(rand(2, 2) * 10);
    % Each row is one [ka kb] pair to demonstrate.
    directions = [2 1; 1 1; 1 2; 2 2];
    for i = 1:size(directions, 1)
        ka = directions(i, 1);
        kb = directions(i, 2);
        result = tensor_mul(a, b, ka, kb);
        % begets is defined elsewhere; presumably it displays the
        %  call and its result -- confirm against that file.
        begets(mfilename, 4, a, b, ka, kb, result)
    end
    return
end

if nargin < 2, help(mfilename), return, end

% Size of factors.
sa = size(a);
sb = size(b);

% Default directions: last of a, first of b.
if nargin < 3, ka = length(sa); end
if nargin < 4, kb = 1; end

% Check size compatibilities.  A direction must be exactly one
%  integer in range; the element count is tested first so that
%  empty or vector arguments are rejected here instead of
%  failing later inside the indexing or reshape.
okay = (prod(size(ka)) == 1);
if okay
    okay = (rem(ka, 1) == 0) & (ka >= 1) & (ka <= length(sa));
end
if ~okay
    s = num2str(ka(:).');
    warning([' ## First direction index incompatible: ' s '.'])
    return
end

okay = (prod(size(kb)) == 1);
if okay
    okay = (rem(kb, 1) == 0) & (kb >= 1) & (kb <= length(sb));
end
if ~okay
    s = num2str(kb(:).');
    warning([' ## Second direction index incompatible: ' s '.'])
    return
end

if sa(ka) ~= sb(kb)
    s = [int2str(sa(ka)) ' ~= ' int2str(sb(kb))];
    warning([' ## Specified directions not same length: ' s '.'])
    return
end

% Size of product: all directions of a, then all of b, with the
%  two summed directions deleted.  The b-entry is removed first
%  so that the index of the a-entry is not shifted.
sc = [sa sb];
sc(length(sa)+kb) = [];
sc(ka) = [];

% Permute a so ka is leftmost direction, others in original order.
if ka > 1
    a = permute(a, [ka, 1:ka-1, ka+1:length(sa)]);
end

% Permute b so kb is leftmost direction, others in original order.
if kb > 1
    b = permute(b, [kb, 1:kb-1, kb+1:length(sb)]);
end

% New sizes.
sa = size(a);
sb = size(b);

% Form two-dimensional arrays, with the summed direction along
%  the rows.  The column count is the product of the remaining
%  sizes; computing it as prod(sa)/sa(1) would give NaN whenever
%  the summed direction has length zero.
a = reshape(a, [sa(1) prod(sa(2:end))]);
b = reshape(b, [sb(1) prod(sb(2:end))]);

% Matrix multiply (non-conjugate transpose of a) and reshape.
c = reshape(a.' * b, sc);