// This file is part of Eigen, a lightweight C++ template library // for linear algebra. // // Copyright (C) 2008-2016 Gael Guennebaud // // This Source Code Form is subject to the terms of the Mozilla // Public License v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. #ifndef EIGEN_GENERAL_MATRIX_VECTOR_H #define EIGEN_GENERAL_MATRIX_VECTOR_H namespace Eigen { namespace internal { /* Optimized col-major matrix * vector product: * This algorithm processes the matrix in vertical panels, * which are then processed horizontally in chunks of 8*PacketSize x 1 vertical segments. * * Mixing type logic: C += alpha * A * B * | A | B |alpha| comments * |real |cplx |cplx | no vectorization * |real |cplx |real | alpha is converted to a cplx when calling the run function, no vectorization * |cplx |real |cplx | invalid, the caller has to do tmp = A * B; C += alpha*tmp * |cplx |real |real | optimal case, vectorization possible via real-cplx mul * * The same reasoning applies to the transposed case. * * NOTE(review): this copy appears to have lost its line breaks and every <...> template argument list (e.g. "template struct general_matrix_vector_product", "packet_traits::size", "conj_helper cj"), so the leading line comments now swallow the code that follows them on the same line; restore from upstream Eigen before compiling. */ template struct general_matrix_vector_product { typedef typename ScalarBinaryOpTraits::ReturnType ResScalar; /* Packet sizes fall back to 1 (scalar path) unless both operand packet types are vectorizable and have equal sizes. */ enum { Vectorizable = packet_traits::Vectorizable && packet_traits::Vectorizable && int(packet_traits::size)==int(packet_traits::size), LhsPacketSize = Vectorizable ? packet_traits::size : 1, RhsPacketSize = Vectorizable ? packet_traits::size : 1, ResPacketSize = Vectorizable ?
packet_traits::size : 1 }; typedef typename packet_traits::type _LhsPacket; typedef typename packet_traits::type _RhsPacket; typedef typename packet_traits::type _ResPacket; typedef typename conditional::type LhsPacket; typedef typename conditional::type RhsPacket; typedef typename conditional::type ResPacket; /* Accumulates res += alpha * lhs * rhs. NOTE(review): alpha is RhsScalar here while the row-major specialization further down takes ResScalar; confirm the asymmetry is intended. */ EIGEN_DONT_INLINE static void run( Index rows, Index cols, const LhsMapper& lhs, const RhsMapper& rhs, ResScalar* res, Index resIncr, RhsScalar alpha); }; template EIGEN_DONT_INLINE void general_matrix_vector_product::run( Index rows, Index cols, const LhsMapper& alhs, const RhsMapper& rhs, ResScalar* res, Index resIncr, RhsScalar alpha) { /* Only a contiguous result vector is supported: resIncr is checked by the assertion and does not appear in the visible res indexing. */ EIGEN_UNUSED_VARIABLE(resIncr); eigen_internal_assert(resIncr==1); // The following copy tells the compiler that lhs's attributes are not modified outside this function // This helps GCC to generate proper code. LhsMapper lhs(alhs); conj_helper cj; conj_helper pcj; const Index lhsStride = lhs.stride(); // TODO: for padded aligned inputs, we could enable aligned reads enum { LhsAlignment = Unaligned }; const Index n8 = rows-8*ResPacketSize+1; const Index n4 = rows-4*ResPacketSize+1; const Index n3 = rows-3*ResPacketSize+1; const Index n2 = rows-2*ResPacketSize+1; const Index n1 = rows-1*ResPacketSize+1; // TODO: improve the following heuristic: const Index block_cols = cols<128 ?
cols : (lhsStride*sizeof(LhsScalar)<32000?16:4); ResPacket palpha = pset1(alpha); for(Index j2=0; j2(ResScalar(0)), c1 = pset1(ResScalar(0)), c2 = pset1(ResScalar(0)), c3 = pset1(ResScalar(0)), c4 = pset1(ResScalar(0)), c5 = pset1(ResScalar(0)), c6 = pset1(ResScalar(0)), c7 = pset1(ResScalar(0)); for(Index j=j2; j(rhs(j,0)); c0 = pcj.pmadd(lhs.template load(i+LhsPacketSize*0,j),b0,c0); c1 = pcj.pmadd(lhs.template load(i+LhsPacketSize*1,j),b0,c1); c2 = pcj.pmadd(lhs.template load(i+LhsPacketSize*2,j),b0,c2); c3 = pcj.pmadd(lhs.template load(i+LhsPacketSize*3,j),b0,c3); c4 = pcj.pmadd(lhs.template load(i+LhsPacketSize*4,j),b0,c4); c5 = pcj.pmadd(lhs.template load(i+LhsPacketSize*5,j),b0,c5); c6 = pcj.pmadd(lhs.template load(i+LhsPacketSize*6,j),b0,c6); c7 = pcj.pmadd(lhs.template load(i+LhsPacketSize*7,j),b0,c7); } pstoreu(res+i+ResPacketSize*0, pmadd(c0,palpha,ploadu(res+i+ResPacketSize*0))); pstoreu(res+i+ResPacketSize*1, pmadd(c1,palpha,ploadu(res+i+ResPacketSize*1))); pstoreu(res+i+ResPacketSize*2, pmadd(c2,palpha,ploadu(res+i+ResPacketSize*2))); pstoreu(res+i+ResPacketSize*3, pmadd(c3,palpha,ploadu(res+i+ResPacketSize*3))); pstoreu(res+i+ResPacketSize*4, pmadd(c4,palpha,ploadu(res+i+ResPacketSize*4))); pstoreu(res+i+ResPacketSize*5, pmadd(c5,palpha,ploadu(res+i+ResPacketSize*5))); pstoreu(res+i+ResPacketSize*6, pmadd(c6,palpha,ploadu(res+i+ResPacketSize*6))); pstoreu(res+i+ResPacketSize*7, pmadd(c7,palpha,ploadu(res+i+ResPacketSize*7))); } if(i(ResScalar(0)), c1 = pset1(ResScalar(0)), c2 = pset1(ResScalar(0)), c3 = pset1(ResScalar(0)); for(Index j=j2; j(rhs(j,0)); c0 = pcj.pmadd(lhs.template load(i+LhsPacketSize*0,j),b0,c0); c1 = pcj.pmadd(lhs.template load(i+LhsPacketSize*1,j),b0,c1); c2 = pcj.pmadd(lhs.template load(i+LhsPacketSize*2,j),b0,c2); c3 = pcj.pmadd(lhs.template load(i+LhsPacketSize*3,j),b0,c3); } pstoreu(res+i+ResPacketSize*0, pmadd(c0,palpha,ploadu(res+i+ResPacketSize*0))); pstoreu(res+i+ResPacketSize*1, 
pmadd(c1,palpha,ploadu(res+i+ResPacketSize*1))); pstoreu(res+i+ResPacketSize*2, pmadd(c2,palpha,ploadu(res+i+ResPacketSize*2))); pstoreu(res+i+ResPacketSize*3, pmadd(c3,palpha,ploadu(res+i+ResPacketSize*3))); i+=ResPacketSize*4; } if(i(ResScalar(0)), c1 = pset1(ResScalar(0)), c2 = pset1(ResScalar(0)); for(Index j=j2; j(rhs(j,0)); c0 = pcj.pmadd(lhs.template load(i+LhsPacketSize*0,j),b0,c0); c1 = pcj.pmadd(lhs.template load(i+LhsPacketSize*1,j),b0,c1); c2 = pcj.pmadd(lhs.template load(i+LhsPacketSize*2,j),b0,c2); } pstoreu(res+i+ResPacketSize*0, pmadd(c0,palpha,ploadu(res+i+ResPacketSize*0))); pstoreu(res+i+ResPacketSize*1, pmadd(c1,palpha,ploadu(res+i+ResPacketSize*1))); pstoreu(res+i+ResPacketSize*2, pmadd(c2,palpha,ploadu(res+i+ResPacketSize*2))); i+=ResPacketSize*3; } if(i(ResScalar(0)), c1 = pset1(ResScalar(0)); for(Index j=j2; j(rhs(j,0)); c0 = pcj.pmadd(lhs.template load(i+LhsPacketSize*0,j),b0,c0); c1 = pcj.pmadd(lhs.template load(i+LhsPacketSize*1,j),b0,c1); } pstoreu(res+i+ResPacketSize*0, pmadd(c0,palpha,ploadu(res+i+ResPacketSize*0))); pstoreu(res+i+ResPacketSize*1, pmadd(c1,palpha,ploadu(res+i+ResPacketSize*1))); i+=ResPacketSize*2; } if(i(ResScalar(0)); for(Index j=j2; j(rhs(j,0)); c0 = pcj.pmadd(lhs.template load(i+0,j),b0,c0); } pstoreu(res+i+ResPacketSize*0, pmadd(c0,palpha,ploadu(res+i+ResPacketSize*0))); i+=ResPacketSize; } for(;i struct general_matrix_vector_product { typedef typename ScalarBinaryOpTraits::ReturnType ResScalar; enum { Vectorizable = packet_traits::Vectorizable && packet_traits::Vectorizable && int(packet_traits::size)==int(packet_traits::size), LhsPacketSize = Vectorizable ? packet_traits::size : 1, RhsPacketSize = Vectorizable ? packet_traits::size : 1, ResPacketSize = Vectorizable ? 
packet_traits::size : 1 }; typedef typename packet_traits::type _LhsPacket; typedef typename packet_traits::type _RhsPacket; typedef typename packet_traits::type _ResPacket; typedef typename conditional::type LhsPacket; typedef typename conditional::type RhsPacket; typedef typename conditional::type ResPacket; EIGEN_DONT_INLINE static void run( Index rows, Index cols, const LhsMapper& lhs, const RhsMapper& rhs, ResScalar* res, Index resIncr, ResScalar alpha); }; template EIGEN_DONT_INLINE void general_matrix_vector_product::run( Index rows, Index cols, const LhsMapper& alhs, const RhsMapper& rhs, ResScalar* res, Index resIncr, ResScalar alpha) { // The following copy tells the compiler that lhs's attributes are not modified outside this function // This helps GCC to generate propoer code. LhsMapper lhs(alhs); eigen_internal_assert(rhs.stride()==1); conj_helper cj; conj_helper pcj; // TODO: fine tune the following heuristic. The rationale is that if the matrix is very large, // processing 8 rows at once might be counter productive wrt cache. const Index n8 = lhs.stride()*sizeof(LhsScalar)>32000 ? 
0 : rows-7; const Index n4 = rows-3; const Index n2 = rows-1; // TODO: for padded aligned inputs, we could enable aligned reads enum { LhsAlignment = Unaligned }; Index i=0; for(; i(ResScalar(0)), c1 = pset1(ResScalar(0)), c2 = pset1(ResScalar(0)), c3 = pset1(ResScalar(0)), c4 = pset1(ResScalar(0)), c5 = pset1(ResScalar(0)), c6 = pset1(ResScalar(0)), c7 = pset1(ResScalar(0)); Index j=0; for(; j+LhsPacketSize<=cols; j+=LhsPacketSize) { RhsPacket b0 = rhs.template load(j,0); c0 = pcj.pmadd(lhs.template load(i+0,j),b0,c0); c1 = pcj.pmadd(lhs.template load(i+1,j),b0,c1); c2 = pcj.pmadd(lhs.template load(i+2,j),b0,c2); c3 = pcj.pmadd(lhs.template load(i+3,j),b0,c3); c4 = pcj.pmadd(lhs.template load(i+4,j),b0,c4); c5 = pcj.pmadd(lhs.template load(i+5,j),b0,c5); c6 = pcj.pmadd(lhs.template load(i+6,j),b0,c6); c7 = pcj.pmadd(lhs.template load(i+7,j),b0,c7); } ResScalar cc0 = predux(c0); ResScalar cc1 = predux(c1); ResScalar cc2 = predux(c2); ResScalar cc3 = predux(c3); ResScalar cc4 = predux(c4); ResScalar cc5 = predux(c5); ResScalar cc6 = predux(c6); ResScalar cc7 = predux(c7); for(; j(ResScalar(0)), c1 = pset1(ResScalar(0)), c2 = pset1(ResScalar(0)), c3 = pset1(ResScalar(0)); Index j=0; for(; j+LhsPacketSize<=cols; j+=LhsPacketSize) { RhsPacket b0 = rhs.template load(j,0); c0 = pcj.pmadd(lhs.template load(i+0,j),b0,c0); c1 = pcj.pmadd(lhs.template load(i+1,j),b0,c1); c2 = pcj.pmadd(lhs.template load(i+2,j),b0,c2); c3 = pcj.pmadd(lhs.template load(i+3,j),b0,c3); } ResScalar cc0 = predux(c0); ResScalar cc1 = predux(c1); ResScalar cc2 = predux(c2); ResScalar cc3 = predux(c3); for(; j(ResScalar(0)), c1 = pset1(ResScalar(0)); Index j=0; for(; j+LhsPacketSize<=cols; j+=LhsPacketSize) { RhsPacket b0 = rhs.template load(j,0); c0 = pcj.pmadd(lhs.template load(i+0,j),b0,c0); c1 = pcj.pmadd(lhs.template load(i+1,j),b0,c1); } ResScalar cc0 = predux(c0); ResScalar cc1 = predux(c1); for(; j(ResScalar(0)); Index j=0; for(; j+LhsPacketSize<=cols; j+=LhsPacketSize) { RhsPacket b0 = 
rhs.template load(j,0); c0 = pcj.pmadd(lhs.template load(i,j),b0,c0); } ResScalar cc0 = predux(c0); for(; j