{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{- |
Matrix-vector and matrix-matrix multiplication for the
General, Square, Hermitian and Triangular matrix shapes,
exposed through the overloaded operators @(#>)@, @(<#)@ and @(<#>)@.
-}
module Numeric.LAPACK.Matrix.Multiply where

import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Triangular as Triangular
import qualified Numeric.LAPACK.Matrix.Hermitian as Hermitian
import qualified Numeric.LAPACK.Matrix.Square as Square
import qualified Numeric.LAPACK.Vector as Vector
import qualified Numeric.LAPACK.Private as Private
import Numeric.LAPACK.Matrix.Shape.Private
         (HeightOf, WidthOf, Order(ColumnMajor), transposeFromOrder)
import Numeric.LAPACK.Matrix.Triangular (Triangular)
import Numeric.LAPACK.Matrix.Private (General)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Private (zero, one)

import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Internal as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Internal (Array(Array))

import Foreign.Marshal.Array (pokeArray)
import Foreign.ForeignPtr (withForeignPtr)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)


-- | Transpose a general matrix by rewriting its shape descriptor
-- with 'MatrixShape.transpose' (through 'Array.mapShape').
-- No element data is processed in this function.
transpose :: General height width a -> General width height a
transpose = Array.mapShape MatrixShape.transpose

-- | Multiply a general matrix by a vector from the right.
--
-- Calls 'error' if the matrix width differs from the vector shape.
multiplyVector ::
   (Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   General height width a -> Vector width a -> Vector height a
multiplyVector a x =
   let MatrixShape.General _order _height width = Array.shape a
   in if width == Array.shape x
         then multiplyVectorUnchecked a x
         else error "multiplyVector: width shapes mismatch"

-- | Like 'multiplyVector', but without comparing the vector shape
-- to the matrix width. The caller is responsible for that check.
--
-- Uses BLAS @gemv@. The transposition flag is derived from the storage
-- order via 'transposeFromOrder', so both storage orders are handled
-- by a single @gemv@ call.
multiplyVectorUnchecked ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   General height width a -> Vector width a -> Vector height a
multiplyVectorUnchecked
      (Array shape@(MatrixShape.General order height width) a) (Array _ x) =
   Array.unsafeCreate height $ \yPtr ->
      -- With an empty width the product is the zero vector of the
      -- given height. BLAS gemv quick-returns when M or N is zero
      -- without writing to y. The buffer from 'Array.unsafeCreate'
      -- would otherwise be returned unfilled, so fill it here.
      if Shape.size width == 0
         then pokeArray yPtr (replicate (Shape.size height) zero)
         else do
            let (m,n) = MatrixShape.dimensions shape
            -- BLAS requires LDA >= max(1,M) even when M is zero,
            -- and checks it before its quick-return.
            let lda = max 1 m
            evalContT $ do
               transPtr <- Call.char $ transposeFromOrder order
               mPtr <- Call.cint m
               nPtr <- Call.cint n
               alphaPtr <- Call.number one
               aPtr <- ContT $ withForeignPtr a
               ldaPtr <- Call.cint lda
               xPtr <- ContT $ withForeignPtr x
               incxPtr <- Call.cint 1
               betaPtr <- Call.number zero
               incyPtr <- Call.cint 1
               liftIO $
                  BlasGen.gemv transPtr mPtr nPtr alphaPtr aPtr ldaPtr
                     xPtr incxPtr betaPtr yPtr incyPtr

-- | Product of two general matrices.
--
-- The result is always created in 'ColumnMajor' order.
-- The two inner (fuse) shapes are compared with 'Call.assert',
-- using the message "multiply: fuse shapes mismatch".
multiply ::
   (Shape.C height, Shape.C fuse, Eq fuse, Shape.C width, Class.Floating a) =>
   General height fuse a -> General fuse width a -> General height width a
multiply
      (Array (MatrixShape.General orderA height fuseA) a)
      (Array (MatrixShape.General orderB fuseB width) b) =
   Array.unsafeCreate (MatrixShape.General ColumnMajor height width) $ \cPtr -> do
      Call.assert "multiply: fuse shapes mismatch" (fuseA == fuseB)
      let m = Shape.size height
      let n = Shape.size width
      let k = Shape.size fuseA
      Private.multiplyMatrix orderA orderB m k n a b cPtr


infixl 7 <#, <#>
infixr 7 #>

-- | Matrix shapes that can be multiplied by a vector from the right.
class MultiplyRight shape where
   -- | @m #> v@: matrix times column vector.
   (#>) ::
      (Class.Floating a) =>
      Array shape a -> Array (WidthOf shape) a -> Array (HeightOf shape) a

-- | Matrix shapes that can be multiplied by a vector from the left.
class MultiplyLeft shape where
   -- | @v <# m@: row vector times matrix.
   (<#) ::
      (Class.Floating a) =>
      Array (HeightOf shape) a -> Array shape a -> Array (WidthOf shape) a

-- | Pairs of matrix shapes that can be multiplied.
-- The shape of the product is given by the associated type 'Multiplied'.
class Multiply shapeA shapeB where
   type Multiplied shapeA shapeB
   (<#>) ::
      (Class.Floating a) =>
      Array shapeA a -> Array shapeB a -> Array (Multiplied shapeA shapeB) a


instance
   (Eq width, Shape.C width, Shape.C height) =>
      MultiplyRight (MatrixShape.General height width) where
   (#>) = multiplyVector

instance
   (Eq height, Shape.C width, Shape.C height) =>
      MultiplyLeft (MatrixShape.General height width) where
   -- v^T * m is computed as m^T * v.
   v <# m = multiplyVector (transpose m) v

instance
   (Eq shape, Shape.C shape) =>
      MultiplyRight (MatrixShape.Square shape) where
   m #> v = multiplyVector (Square.toGeneral m) v

instance
   (Eq shape, Shape.C shape) =>
      MultiplyLeft (MatrixShape.Square shape) where
   v <# m = multiplyVector (transpose $ Square.toGeneral m) v

instance
   (Eq shape, Shape.C shape) =>
      MultiplyRight (MatrixShape.Hermitian shape) where
   m #> v = Hermitian.multiplyVector m v

instance
   (Eq shape, Shape.C shape) =>
      MultiplyLeft (MatrixShape.Hermitian shape) where
   -- v^T * m = (m^T * v)^T, and for Hermitian m we have m^T = conj m,
   -- which is again Hermitian.
   v <# m = Hermitian.multiplyVector (Vector.conjugate m) v

instance
   (MatrixShape.Uplo uplo, Eq shape, Shape.C shape) =>
      MultiplyRight (MatrixShape.Triangular uplo shape) where
   m #> v = Triangular.multiplyVectorRight m v

instance
   (MatrixShape.Uplo uplo, Eq shape, Shape.C shape) =>
      MultiplyLeft (MatrixShape.Triangular uplo shape) where
   v <# m = Triangular.multiplyVectorLeft m v


instance
   (Shape.C heightA, Shape.C widthA, Shape.C widthB,
    widthA ~ heightB, Eq heightB) =>
      Multiply
         (MatrixShape.General heightA widthA)
         (MatrixShape.General heightB widthB) where
   type Multiplied
         (MatrixShape.General heightA widthA)
         (MatrixShape.General heightB widthB) =
            MatrixShape.General heightA widthB
   (<#>) = multiply

instance
   (Shape.C shapeA, Shape.C widthB, shapeA ~ heightB, Eq heightB) =>
      Multiply
         (MatrixShape.Square shapeA)
         (MatrixShape.General heightB widthB) where
   type Multiplied
         (MatrixShape.Square shapeA)
         (MatrixShape.General heightB widthB) =
            MatrixShape.General heightB widthB
   a <#> b = multiply (Square.toGeneral a) b

instance
   (Shape.C heightA, Shape.C widthA, widthA ~ shapeB, Eq shapeB) =>
      Multiply
         (MatrixShape.General heightA widthA)
         (MatrixShape.Square shapeB) where
   type Multiplied
         (MatrixShape.General heightA widthA)
         (MatrixShape.Square shapeB) =
            MatrixShape.General heightA widthA
   a <#> b = multiply a (Square.toGeneral b)

instance
   (Shape.C shapeA, shapeA ~ shapeB, Eq shapeB) =>
      Multiply
         (MatrixShape.Square shapeA)
         (MatrixShape.Square shapeB) where
   type Multiplied
         (MatrixShape.Square shapeA)
         (MatrixShape.Square shapeB) =
            MatrixShape.Square shapeA
   (<#>) = Square.multiply


instance
   (Shape.C shapeA, shapeA ~ width, Eq width, Shape.C height) =>
      Multiply
         (MatrixShape.General height width)
         (MatrixShape.Hermitian shapeA) where
   type Multiplied
         (MatrixShape.General height width)
         (MatrixShape.Hermitian shapeA) =
            MatrixShape.General height width
   (<#>) = Hermitian.multiplyGeneralLeft

instance
   (Shape.C shapeA, shapeA ~ shapeB, Eq shapeB) =>
      Multiply
         (MatrixShape.Square shapeB)
         (MatrixShape.Hermitian shapeA) where
   type Multiplied
         (MatrixShape.Square shapeB)
         (MatrixShape.Hermitian shapeA) =
            MatrixShape.Square shapeA
   (<#>) = Hermitian.multiplySquareLeft

instance
   (Shape.C shapeA, shapeA ~ height, Eq height, Shape.C width) =>
      Multiply
         (MatrixShape.Hermitian shapeA)
         (MatrixShape.General height width) where
   type Multiplied
         (MatrixShape.Hermitian shapeA)
         (MatrixShape.General height width) =
            MatrixShape.General height width
   (<#>) = Hermitian.multiplyGeneralRight

instance
   (Shape.C shapeA, shapeA ~ shapeB, Eq shapeB) =>
      Multiply
         (MatrixShape.Hermitian shapeA)
         (MatrixShape.Square shapeB) where
   type Multiplied
         (MatrixShape.Hermitian shapeA)
         (MatrixShape.Square shapeB) =
            MatrixShape.Square shapeA
   (<#>) = Hermitian.multiplySquareRight

instance
   (Shape.C shapeA, shapeA ~ shapeB, Eq shapeB) =>
      Multiply
         (MatrixShape.Hermitian shapeA)
         (MatrixShape.Hermitian shapeB) where
   -- A product of two Hermitian matrices is Hermitian only if they
   -- commute, so the result is a general Square matrix.
   type Multiplied
         (MatrixShape.Hermitian shapeA)
         (MatrixShape.Hermitian shapeB) =
            MatrixShape.Square shapeA
   a <#> b = Hermitian.multiplySquareRight a (Hermitian.toSquare b)


instance
   (MatrixShape.Uplo uplo, Shape.C shapeA, shapeA ~ width, Eq width,
    Shape.C height) =>
      Multiply
         (MatrixShape.General height width)
         (MatrixShape.Triangular uplo shapeA) where
   type Multiplied
         (MatrixShape.General height width)
         (MatrixShape.Triangular uplo shapeA) =
            MatrixShape.General height width
   (<#>) = Triangular.multiplyGeneralLeft

instance
   (MatrixShape.Uplo uplo, Shape.C shapeA, shapeA ~ shapeB, Eq shapeB) =>
      Multiply
         (MatrixShape.Square shapeB)
         (MatrixShape.Triangular uplo shapeA) where
   type Multiplied
         (MatrixShape.Square shapeB)
         (MatrixShape.Triangular uplo shapeA) =
            MatrixShape.Square shapeA
   (<#>) = Triangular.multiplySquareLeft

instance
   (MatrixShape.Uplo uplo, Shape.C shapeA, shapeA ~ height, Eq height,
    Shape.C width) =>
      Multiply
         (MatrixShape.Triangular uplo shapeA)
         (MatrixShape.General height width) where
   type Multiplied
         (MatrixShape.Triangular uplo shapeA)
         (MatrixShape.General height width) =
            MatrixShape.General height width
   (<#>) = Triangular.multiplyGeneralRight

instance
   (MatrixShape.Uplo uplo, Shape.C shapeA, shapeA ~ shapeB, Eq shapeB) =>
      Multiply
         (MatrixShape.Triangular uplo shapeA)
         (MatrixShape.Square shapeB) where
   type Multiplied
         (MatrixShape.Triangular uplo shapeA)
         (MatrixShape.Square shapeB) =
            MatrixShape.Square shapeA
   (<#>) = Triangular.multiplySquareRight

instance
   (Shape.C shapeA, shapeA ~ shapeB, Eq shapeB,
    MultiplyTriangular uploA uploB) =>
      Multiply
         (MatrixShape.Triangular uploA shapeA)
         (MatrixShape.Triangular uploB shapeB) where
   type Multiplied
         (MatrixShape.Triangular uploA shapeA)
         (MatrixShape.Triangular uploB shapeB) =
            MultipliedTriangular uploA uploB shapeB
   (<#>) = multiplyTriangular


-- | Dispatch for products of two triangular matrices,
-- selected by the triangle parts (Lower/Upper) of both factors.
class MultiplyTriangular uploA uploB where
   -- | Shape constructor of the product, still to be applied to the size shape.
   type MultipliedTriangular uploA uploB :: * -> *
   multiplyTriangular ::
      (Class.Floating a, Shape.C shape, Eq shape) =>
      Triangular uploA shape a -> Triangular uploB shape a ->
      Array (MultipliedTriangular uploA uploB shape) a

-- Products of two matrices of the same triangle part keep that part.
instance MultiplyTriangular MatrixShape.Lower MatrixShape.Lower where
   type MultipliedTriangular MatrixShape.Lower MatrixShape.Lower =
      MatrixShape.Triangular MatrixShape.Lower
   multiplyTriangular = Triangular.multiply

instance MultiplyTriangular MatrixShape.Upper MatrixShape.Upper where
   type MultipliedTriangular MatrixShape.Upper MatrixShape.Upper =
      MatrixShape.Triangular MatrixShape.Upper
   multiplyTriangular = Triangular.multiply

-- Mixed triangle parts give a full matrix, so both factors are
-- expanded to Square and multiplied as such.
instance MultiplyTriangular MatrixShape.Lower MatrixShape.Upper where
   type MultipliedTriangular MatrixShape.Lower MatrixShape.Upper =
      MatrixShape.Square
   multiplyTriangular a b =
      Square.multiply (Triangular.toSquare a) (Triangular.toSquare b)

instance MultiplyTriangular MatrixShape.Upper MatrixShape.Lower where
   type MultipliedTriangular MatrixShape.Upper MatrixShape.Lower =
      MatrixShape.Square
   multiplyTriangular a b =
      Square.multiply (Triangular.toSquare a) (Triangular.toSquare b)