{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-- | Operations on dense matrices whose elements are stored row by row.
--
-- A 'Matrix' is an 'Array' indexed by a @(height,width)@ shape pair.
-- The functions here pass the raw element buffers to BLAS routines
-- ('Blas.copy', 'Blas.scal', 'Blas.gemm', 'Blas.gemv', 'Private.gbmv')
-- or copy them directly with 'copyArray'.
module Numeric.BLAS.Matrix.RowMajor (
   Matrix,
   Vector,
   takeRow,
   takeColumn,
   fromRows,
   tensorProduct,
   decomplex,
   recomplex,
   scaleRows,
   scaleColumns,
   multiplyVectorLeft,
   multiplyVectorRight,
   ) where

import qualified Numeric.BLAS.Matrix.Modifier as Modifier
import qualified Numeric.BLAS.Private as Private
import Numeric.BLAS.Matrix.Modifier (Conjugation(NonConjugated,Conjugated))
import Numeric.BLAS.Scalar (zero, one)
import Numeric.BLAS.Private
         (ShapeInt, shapeInt, ComplexShape, pointerSeq, fill)

import qualified Numeric.BLAS.FFI.Generic as Blas
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import Foreign.Marshal.Array (copyArray, advancePtr)
import Foreign.ForeignPtr (withForeignPtr, castForeignPtr)
import Foreign.Storable (Storable)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))

import Data.Foldable (forM_)
import Data.Complex (Complex)
import Data.Tuple.HT (swap)


-- | A matrix is an array whose shape is the pair @(height,width)@.
type Matrix height width = Array (height,width)

-- | A vector is an array over a single shape.
type Vector = Array


-- | Copy one row of the matrix into a fresh vector.
--
-- The row is read as @n@ consecutive elements
-- starting at element offset @n * Shape.offset height ix@,
-- where @n@ is the size of @width@.
takeRow ::
   (Shape.Indexed height, Shape.C width, Shape.Index height ~ ix,
    Storable a) =>
   ix -> Matrix height width a -> Vector width a
takeRow ix (Array (height,width) x) =
   Array.unsafeCreateWithSize width $ \n yPtr ->
      withForeignPtr x $ \xPtr ->
         copyArray yPtr (advancePtr xPtr (n * Shape.offset height ix)) n

-- | Copy one column of the matrix into a fresh vector.
--
-- The column is gathered with 'Blas.copy':
-- reading starts at element offset @Shape.offset width ix@
-- and advances by @Shape.size width@ elements per step,
-- the destination is written with stride 1.
takeColumn ::
   (Shape.C height, Shape.Indexed width, Shape.Index width ~ ix,
    Class.Floating a) =>
   ix -> Matrix height width a -> Vector height a
takeColumn ix (Array (height,width) x) =
   Array.unsafeCreateWithSize height $ \n yPtr -> evalContT $ do
      let offset = Shape.offset width ix
      nPtr <- Call.cint n
      xPtr <- ContT $ withForeignPtr x
      incxPtr <- Call.cint $ Shape.size width
      incyPtr <- Call.cint 1
      liftIO $
         Blas.copy nPtr (advancePtr xPtr offset) incxPtr yPtr incyPtr

-- | Build a matrix by stacking the given vectors as rows.
--
-- The height of the result is the number of list elements.
-- Every vector shape is compared against @width@ using 'Call.assert'
-- before its elements are copied into the corresponding row.
fromRows ::
   (Shape.C width, Eq width, Storable a) =>
   width -> [Vector width a] -> Matrix ShapeInt width a
fromRows width rows =
   Array.unsafeCreate (shapeInt $ length rows, width) $ \dstPtr ->
      let widthSize = Shape.size width
      in forM_ (zip (pointerSeq widthSize dstPtr) rows) $
            \(dstRowPtr, Array.Array rowWidth srcFPtr) ->
               withForeignPtr srcFPtr $ \srcPtr -> do
                  Call.assert
                     "Matrix.fromRows: non-matching vector size"
                     (width == rowWidth)
                  copyArray dstRowPtr srcPtr widthSize

-- ToDo: use lapack:Private.multiplyMatrix
-- | Product of a @height@ vector and a @width@ vector
-- yielding a @height@ by @width@ matrix.
--
-- It is computed by a single 'Blas.gemm' call with inner dimension 1,
-- @alpha = one@ and @beta = zero@.
-- The @width@ vector is passed as first gemm operand
-- and the @height@ vector as second one.
-- @Left c@ puts the flag selected by @c@
-- (@T@ for 'NonConjugated', @C@ for 'Conjugated') on the first operand,
-- @Right c@ puts it on the second operand;
-- the respective other operand gets the flag @N@.
tensorProduct ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   Either Conjugation Conjugation ->
   Vector height a -> Vector width a -> Matrix height width a
tensorProduct side (Array height x) (Array width y) =
   Array.unsafeCreate (height,width) $ \cPtr -> do
      let m = Shape.size width
      let n = Shape.size height
      let trans conjugated =
            case conjugated of
               NonConjugated -> 'T'
               Conjugated -> 'C'
      -- Leading dimensions depend on whether an operand is handed over
      -- as a single row (1) or as a single column (its length).
      let ((transa,transb),(lda,ldb)) =
            case side of
               Left c -> ((trans c, 'N'),(1,1))
               Right c -> (('N', trans c),(m,n))
      evalContT $ do
         transaPtr <- Call.char transa
         transbPtr <- Call.char transb
         mPtr <- Call.cint m
         nPtr <- Call.cint n
         kPtr <- Call.cint 1
         alphaPtr <- Call.number one
         aPtr <- ContT $ withForeignPtr y
         ldaPtr <- Call.leadingDim lda
         bPtr <- ContT $ withForeignPtr x
         ldbPtr <- Call.leadingDim ldb
         betaPtr <- Call.number zero
         ldcPtr <- Call.leadingDim m
         liftIO $
            Blas.gemm
               transaPtr transbPtr mPtr nPtr kPtr alphaPtr
               aPtr ldaPtr bPtr ldbPtr betaPtr cPtr ldcPtr


-- | Reinterpret a matrix of complex numbers as a matrix of real numbers
-- whose width shape is paired with 'ComplexShape'.
--
-- The buffer is not copied, only the pointer is cast.
decomplex ::
   (Class.Real a) =>
   Matrix height width (Complex a) ->
   Matrix height (width, ComplexShape) a
decomplex (Array (height,width) a) =
   Array (height, (width, Shape.static)) (castForeignPtr a)

-- | Inverse of 'decomplex':
-- drop the 'ComplexShape' component from the width shape
-- and cast the buffer pointer back to complex elements.
recomplex ::
   (Class.Real a) =>
   Matrix height (width, ComplexShape) a ->
   Matrix height width (Complex a)
recomplex (Array (height, (width, Shape.NestedTuple _)) a) =
   Array (height,width) (castForeignPtr a)


-- | Scale every row of the matrix by the matching vector element.
--
-- The vector shape is compared against the matrix height
-- using 'Call.assert'.
-- Each row is first copied to the result with 'Blas.copy'
-- and then scaled in place with 'Blas.scal'.
scaleRows ::
   (Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   Vector height a -> Matrix height width a -> Matrix height width a
scaleRows (Array heightX x) (Array shape@(height,width) a) =
   Array.unsafeCreate shape $ \bPtr -> do
      Call.assert "scaleRows: sizes mismatch" (heightX == height)
      evalContT $ do
         let m = Shape.size height
         let n = Shape.size width
         nPtr <- Call.cint n
         xPtr <- ContT $ withForeignPtr x
         aPtr <- ContT $ withForeignPtr a
         incaPtr <- Call.cint 1
         incbPtr <- Call.cint 1
         -- One iteration per row: the factor pointer advances by one
         -- element, source and destination advance by a whole row.
         liftIO $ sequence_ $ take m $
            zipWith3
               (\xkPtr akPtr bkPtr -> do
                  Blas.copy nPtr akPtr incaPtr bkPtr incbPtr
                  Blas.scal nPtr xkPtr bkPtr incbPtr)
               (pointerSeq 1 xPtr)
               (pointerSeq n aPtr)
               (pointerSeq n bPtr)

-- | Scale every column of the matrix by the matching vector element.
--
-- The vector shape is compared against the matrix width
-- using 'Call.assert'.
-- Each row is multiplied elementwise with the vector
-- by a 'Private.gbmv' call with zero sub- and super-diagonals,
-- that is, the vector is passed as the band storage of an
-- @n@ by @n@ matrix that has only a main diagonal.
scaleColumns ::
   (Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   Vector width a -> Matrix height width a -> Matrix height width a
scaleColumns (Array widthX x) (Array shape@(height,width) a) =
   Array.unsafeCreate shape $ \bPtr -> do
      Call.assert "scaleColumns: sizes mismatch" (widthX == width)
      evalContT $ do
         let m = Shape.size height
         let n = Shape.size width
         transPtr <- Call.char 'N'
         nPtr <- Call.cint n
         klPtr <- Call.cint 0
         kuPtr <- Call.cint 0
         alphaPtr <- Call.number one
         xPtr <- ContT $ withForeignPtr x
         ldxPtr <- Call.leadingDim 1
         aPtr <- ContT $ withForeignPtr a
         incaPtr <- Call.cint 1
         betaPtr <- Call.number zero
         incbPtr <- Call.cint 1
         liftIO $ sequence_ $ take m $
            zipWith
               (\akPtr bkPtr ->
                  Private.gbmv transPtr
                     nPtr nPtr klPtr kuPtr alphaPtr xPtr ldxPtr
                     akPtr incaPtr betaPtr bkPtr incbPtr)
               (pointerSeq n aPtr)
               (pointerSeq n bPtr)


-- | Multiply a vector over the matrix height with the matrix,
-- yielding a vector over the matrix width.
multiplyVectorLeft ::
   (Eq height, Shape.C height, Shape.C width, Class.Floating a) =>
   Vector height a -> Matrix height width a -> Vector width a
multiplyVectorLeft = multiplyVector nonTransposed

-- | Multiply the matrix with a vector over the matrix width,
-- yielding a vector over the matrix height.
multiplyVectorRight ::
   (Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   Matrix height width a -> Vector width a -> Vector height a
multiplyVectorRight = flip $ multiplyVector transposed


-- | Bundles what 'multiplyVector' needs to know about the orientation
-- of the matrix operand:
-- the transposition tag that is inspected for swapping the dimensions,
-- the character that is passed to gemv
-- and a function that maps the stored matrix shape @(heightA,widthA)@
-- to the pair of input shape and output shape @(heightB,widthB)@.
data Transposition heightA widthA heightB widthB =
   Transposition
      Modifier.Transposition
      Char
      ((heightA,widthA) -> (heightB,widthB))

-- | Orientation used by 'multiplyVectorRight':
-- the shape pair is swapped and gemv is called with @T@.
transposed :: Transposition height width width height
transposed = Transposition Modifier.Transposed 'T' swap

-- | Orientation used by 'multiplyVectorLeft':
-- the shape pair is kept and gemv is called with @N@.
nonTransposed :: Transposition height width height width
nonTransposed = Transposition Modifier.NonTransposed 'N' id

-- | Common implementation of 'multiplyVectorLeft'
-- and 'multiplyVectorRight' on top of 'Blas.gemv'.
--
-- After applying the shape function of the 'Transposition',
-- @heightB@ is the shape the input vector must have
-- (checked with 'Call.assert') and @widthB@ is the shape of the result.
--
-- If the input vector is empty, the result is filled with 'zero'
-- and gemv is not called.
multiplyVector ::
   (Shape.C heightB, Shape.C widthB, Eq heightB, Class.Floating a) =>
   Transposition heightA widthA heightB widthB ->
   Vector heightB a -> Matrix heightA widthA a -> Vector widthB a
multiplyVector
      (Transposition trans transChar assignDims) (Array sh x) (Array shA a) =
   let (height,width) = assignDims shA
   in Array.unsafeCreateWithSize width $ \m0 yPtr -> do
         Call.assert
            "Matrix.RowMajor.multiplyVector: shapes mismatch"
            (height == sh)
         -- m0 is the number of result elements (supplied by
         -- unsafeCreateWithSize), n0 the number of input elements,
         -- both independent of the transposition.
         let n0 = Shape.size height
         -- Dimensions as gemv wants them; they are swapped
         -- with respect to (m0,n0) in the transposed case.
         let (m,n) =
               case trans of
                  Modifier.NonTransposed -> (m0,n0)
                  Modifier.Transposed -> (n0,m0)
         -- The empty-sum case must be decided on the untransposed sizes:
         -- yPtr holds m0 elements and the sum runs over n0 elements.
         -- Testing the gemv dimensions (n, m) instead would, in the
         -- transposed case, write n0 zeros into a result of size m0
         -- and miss the actually empty sum.
         -- Reference gemv is documented to return early when one of
         -- its dimensions is zero, so it cannot be relied on for
         -- clearing the result.
         if n0 == 0
            then fill zero m0 yPtr
            else evalContT $ do
               let lda = m
               transPtr <- Call.char transChar
               mPtr <- Call.cint m
               nPtr <- Call.cint n
               alphaPtr <- Call.number one
               aPtr <- ContT $ withForeignPtr a
               ldaPtr <- Call.leadingDim lda
               xPtr <- ContT $ withForeignPtr x
               incxPtr <- Call.cint 1
               betaPtr <- Call.number zero
               incyPtr <- Call.cint 1
               liftIO $
                  Blas.gemv
                     transPtr mPtr nPtr alphaPtr aPtr ldaPtr
                     xPtr incxPtr betaPtr yPtr incyPtr