function x = tensor_sol(a, b, ka, kx)

% tensor_sol -- Tensor solve.
%  tensor_sol(a, b, ka, kx) solves for tensor x in a*x = b,
%   for summations along the respective ka and kx directions.
%   If the directions are not specified, the innermost ones
%   are assumed: ka defaults to the last direction of a, and
%   kx defaults to 1 (the first direction of x).
%  Size contract: the leading directions of b must equal the
%   directions of a with ka removed (in order); the remaining
%   directions of b are the directions of x with kx removed.
%   Trailing singleton directions of b may be omitted.
%  The flattened system is solved with the backslash operator,
%   so non-square systems receive its least-squares solution.
%  On incompatible sizes or invalid direction indices, a warning
%   is issued and x = [] is returned (when an output is requested).
%  tensor_sol('demo') demonstrates itself with two 2-by-2
%   random arrays.
%
% Also see: tensor_mul, tensor_test.

% Copyright (C) 1998 Dr. Charles R. Denham, ZYDECO.
%  All Rights Reserved.
%   Disclosure without explicit written consent from the
%    copyright owner does not constitute publication.

% Version of 16-Sep-1998 05:34:15.

if nargout > 0, x = []; end

if nargin < 1, help(mfilename), a = 'demo'; end

if isequal(a, 'demo')
    a = floor(rand(2, 2) * 10);
    b = floor(rand(2, 2) * 10);
    % Each column is one (ka, kx) pair to demonstrate.
    for k = [2 1; 1 1; 1 2; 2 2].'
        ka = k(1); kx = k(2);
        result = tensor_sol(a, b, ka, kx);
        begets(mfilename, 4, a, b, ka, kx, result)
    end
    return
end

if nargin < 2, help(mfilename), return, end

% Size of factors.

sa = size(a);
na = length(sa);

% Default directions.

if nargin < 3, ka = na; end
if nargin < 4, kx = 1; end

% Validate ka before it is used as an index below.

if numel(ka) ~= 1 || rem(ka, 1) ~= 0 || ka < 1 || ka > na
    s = num2str(ka);
    warning([' ## First direction index incompatible: ' s '.'])
    return
end

% size() drops trailing singletons; pad b's size with ones so it
%  spans at least as many directions as a.  This lets b's leading
%  directions line up with a's non-ka directions, and guarantees
%  x keeps at least one direction besides kx.

sb = size(b);
if length(sb) < na
    sb = [sb ones(1, na - length(sb))];
end

% Compatibility of sizes: the leading directions of b must be the
%  directions of a with ka removed.

u = sa;
u(ka) = [];
if ~isequal(u, sb(1:length(u)))
    s = [mat2str(size(b)) ' vs ' mat2str(size(b))];
    s = [mat2str(sa) ' vs ' mat2str(size(b))];
    warning([' ## Incompatible sizes: ' s '.'])
    return
end

% Directions of x other than kx, in order.

rest = sb(na:end);

if numel(kx) ~= 1 || rem(kx, 1) ~= 0 || kx < 1 || kx > length(rest) + 1
    s = num2str(kx);
    warning([' ## Second direction index incompatible: ' s '.'])
    return
end

% Size of result: the summed direction sits at position kx.

sx = [rest(1:kx-1) sa(ka) rest(kx:end)];

% Bring direction ka of a to the front, then flatten a into a
%  matrix with rows over a's remaining directions and columns over
%  direction ka.  (.' is the non-conjugating transpose.)

nu = prod(u);
a = permute(a, [ka 1:ka-1 ka+1:na]);
a = reshape(a, [sa(ka) nu]).';

% Flatten b to match: rows over a's remaining directions,
%  columns over x's remaining directions.

b = reshape(b, [nu prod(rest)]);

% Matrix solve.  The solution has the summed (ka-length) direction
%  first, followed by x's remaining directions.

x = reshape(a \ b, [sa(ka) rest]);

% Move that leading direction to position kx (identity when kx = 1).

x = ipermute(x, [kx 1:kx-1 kx+1:length(sx)]);