{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Numeric.LAPACK.Matrix (
   General,
   (Format.##),
   Format.Format,
   Format.FormatArray,
   ZeroInt,
   zeroInt,
   transpose,
   fromList,
   identity,
   diagonal,
   getDiagonal,
   fromRows,
   fromRowsWithSize,
   fromColumns,
   fromColumnsWithSize,
   singleRow,
   singleColumn,
   flattenRow,
   flattenColumn,
   pickRow,
   pickColumn,
   takeRows,
   takeColumns,
   dropRows,
   dropColumns,
   reverseRows,
   reverseColumns,
   fromRowMajor,
   toRowMajor,
   flatten,
   (|||),
   (===),
   rowSums,
   columnSums,
   scaleRows,
   scaleColumns,
   multiply,
   multiplyVector,
   trace,
   ) where

import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Private as Private
import qualified Numeric.LAPACK.Format as Format
import qualified Numeric.LAPACK.Vector as Vector
import Numeric.LAPACK.Matrix.Shape.Private
         (Order(RowMajor, ColumnMajor), charFromOrder)
import Numeric.LAPACK.Private
         (zero, one, pointerSeq, copyTransposed, copySubMatrix, copyBlock)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Internal as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Internal (Array(Array))
import Data.Array.Comfort.Shape ((:+:)((:+:)))

import Foreign.Marshal.Array (copyArray, advancePtr, pokeArray)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable, poke, peek)

import System.IO.Unsafe (unsafePerformIO)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)

import qualified Data.NonEmpty as NonEmpty
import Data.Foldable (forM_)
import Data.Bool.HT (if')


type General height width = Array (MatrixShape.General height width)

-- | Swap the roles of height and width.
-- Only the shape descriptor is mapped (via 'Array.mapShape');
-- no element copy is performed here.
transpose :: General height width a -> General width height a
transpose = Array.mapShape MatrixShape.transpose


type ZeroInt = Shape.ZeroBased Int

-- | Shorthand for a zero-based index range with the given number of elements.
zeroInt :: Int -> ZeroInt
zeroInt = Shape.ZeroBased

-- | Build a matrix from a list of elements given in row-major order.
fromList ::
   (Shape.C height, Shape.C width, Storable a) =>
   height -> width -> [a] -> General height width a
fromList height width =
   Array.fromList (MatrixShape.General RowMajor height width)


type Vector = Array

-- | Square identity matrix, stored column-major.
--
-- 'identity' fills the matrix with LAPACK @laset@
-- (off-diagonal elements set to 'zero', diagonal elements set to 'one').
--
-- '_identity' is an alternative implementation based on two BLAS @copy@
-- calls: first all @n*n@ elements are cleared by copying a single zero
-- with source stride 0, then ones are written with destination stride @n+1@,
-- which visits exactly the diagonal.
identity, _identity ::
   (Shape.C sh, Storable a, Class.Floating a) => sh -> General sh sh a
identity sh =
   Array.unsafeCreate (MatrixShape.General ColumnMajor sh sh) $ \aPtr ->
   evalContT $ do
      uploPtr <- Call.char 'A'
      nPtr <- Call.cint $ Shape.size sh
      alphaPtr <- Call.number zero
      betaPtr <- Call.number one
      liftIO $ LapackGen.laset uploPtr nPtr nPtr alphaPtr betaPtr aPtr nPtr

_identity sh =
   Array.unsafeCreate (MatrixShape.General ColumnMajor sh sh) $ \yPtr ->
   evalContT $ do
      nPtr <- Call.alloca
      xPtr <- Call.number zero
      incxPtr <- Call.cint 0
      incyPtr <- Call.cint 1
      liftIO $ do
         let n = fromIntegral $ Shape.size sh
         poke nPtr $ n*n
         BlasGen.copy nPtr xPtr incxPtr yPtr incyPtr
         poke nPtr n
         poke xPtr one
         poke incyPtr (n+1)
         BlasGen.copy nPtr xPtr incxPtr yPtr incyPtr

-- | Square matrix with the given vector on the main diagonal
-- and zeros elsewhere, stored column-major.
diagonal ::
   (Shape.C sh, Storable a, Class.Floating a) =>
   Vector sh a -> General sh sh a
diagonal (Array sh x) =
   Array.unsafeCreate (MatrixShape.General ColumnMajor sh sh) $ \yPtr ->
   evalContT $ do
      nPtr <- Call.alloca
      xPtr <- ContT $ withForeignPtr x
      zPtr <- Call.number zero
      incxPtr <- Call.cint 1
      incyPtr <- Call.cint 1
      inczPtr <- Call.cint 0
      liftIO $ do
         let n = fromIntegral $ Shape.size sh
         -- clear all n*n elements: source stride 0 repeats the single zero
         poke nPtr $ n*n
         BlasGen.copy nPtr zPtr inczPtr yPtr incyPtr
         -- destination stride n+1 hits exactly the diagonal entries
         poke nPtr n
         poke incyPtr (n+1)
         BlasGen.copy nPtr xPtr incxPtr yPtr incyPtr

-- | Extract the main diagonal of a square matrix.
-- Fails (via 'Call.assert') if height and width differ.
getDiagonal ::
   (Shape.C sh, Eq sh, Storable a, Class.Floating a) =>
   General sh sh a -> Vector sh a
getDiagonal (Array (MatrixShape.General _ height width) x) =
   Array.unsafeCreate height $ \yPtr -> do
   Call.assert "getDiagonal: non-square matrix" (height==width)
   evalContT $ do
      let n = Shape.size height
      nPtr <- Call.cint n
      xPtr <- ContT $ withForeignPtr x
      -- stride n+1 walks the diagonal for both storage orders
      incxPtr <- Call.cint (n+1)
      incyPtr <- Call.cint 1
      liftIO $ BlasGen.copy nPtr xPtr incxPtr yPtr incyPtr


-- | View a vector as a row-major matrix with a single row.
-- The element buffer is reused, not copied.
singleRow :: Vector width a -> General () width a
singleRow (Array sh fptr) = Array (MatrixShape.General RowMajor () sh) fptr

-- | View a vector as a column-major matrix with a single column.
-- The element buffer is reused, not copied.
singleColumn :: Vector width a -> General width () a
singleColumn (Array sh fptr) = Array (MatrixShape.General ColumnMajor sh ()) fptr

-- | View a single-row matrix as a vector, reusing the element buffer.
flattenRow :: General () width a -> Vector width a
flattenRow (Array (MatrixShape.General _ () sh) fptr) = Array sh fptr

-- | View a single-column matrix as a vector, reusing the element buffer.
flattenColumn :: General width () a -> Vector width a
flattenColumn (Array (MatrixShape.General _ sh ()) fptr) = Array sh fptr


-- | Stack vectors as the rows of a row-major matrix.
-- The width is taken from the first vector;
-- all other vectors must have the same shape.
fromRows ::
   (Shape.C width, Eq width, Storable a) =>
   NonEmpty.T [] (Vector width a) -> General ZeroInt width a
fromRows (NonEmpty.Cons row rows) =
   fromRowsWithSize (Array.shape row) (row:rows)

-- | Stack vectors as the rows of a row-major matrix of the given width.
-- Accepts an empty list. A vector whose shape differs from @width@
-- fails the assertion in 'gather'.
fromRowsWithSize ::
   (Shape.C width, Eq width, Storable a) =>
   width -> [Vector width a] -> General ZeroInt width a
fromRowsWithSize width rows =
   Array.unsafeCreate
      (MatrixShape.General RowMajor (zeroInt $ length rows) width)
      (gather width rows)

-- | Place vectors side by side as the columns of a column-major matrix.
-- The height is taken from the first vector;
-- all other vectors must have the same shape.
fromColumns ::
   (Shape.C height, Eq height, Storable a) =>
   NonEmpty.T [] (Vector height a) -> General height ZeroInt a
fromColumns (NonEmpty.Cons column columns) =
   fromColumnsWithSize (Array.shape column) (column:columns)

-- | Place vectors side by side as the columns of a column-major matrix
-- of the given height. Accepts an empty list. A vector whose shape
-- differs from @height@ fails the assertion in 'gather'.
fromColumnsWithSize ::
   (Shape.C height, Eq height, Storable a) =>
   height -> [Vector height a] -> General height ZeroInt a
fromColumnsWithSize height columns =
   Array.unsafeCreate
      (MatrixShape.General ColumnMajor height (zeroInt $ length columns))
      (gather height columns)

-- | Copy equally shaped vectors one after another into a contiguous buffer.
-- Each vector's shape is asserted to equal @width@.
gather ::
   (Shape.C width, Eq width, Storable a) =>
   width -> [Array width a] -> Ptr a -> IO ()
gather width rows dstPtr =
   let widthSize = Shape.size width
   in forM_ (zip (pointerSeq widthSize dstPtr) rows) $
         \(dstRowPtr, Array.Array rowWidth srcFPtr) ->
         withForeignPtr srcFPtr $ \srcPtr -> do
            Call.assert
               "Matrix.fromRows/fromColumns: non-matching vector size"
               (width == rowWidth)
            copyArray dstRowPtr srcPtr widthSize


-- | Extract a single row.
-- With row-major storage the row is one contiguous run;
-- with column-major storage its elements are strided by the number of rows.
pickRow ::
   (Shape.C height, Shape.C width, Shape.Index height ~ ix,
    Storable a, Class.Floating a) =>
   General height width a -> ix -> Vector width a
pickRow (Array (MatrixShape.General order height width) x) ix =
   case order of
      RowMajor -> pickConsecutive height width x ix
      ColumnMajor -> pickScattered width height x ix

-- | Extract a single column.
-- With column-major storage the column is one contiguous run;
-- with row-major storage its elements are strided by the number of columns.
pickColumn ::
   (Shape.C height, Shape.C width, Shape.Index width ~ ix,
    Storable a, Class.Floating a) =>
   General height width a -> ix -> Vector height a
pickColumn (Array (MatrixShape.General order height width) x) ix =
   case order of
      RowMajor -> pickScattered height width x ix
      ColumnMajor -> pickConsecutive width height x ix

-- | Copy the contiguous line number @ix@ (counted in @height@)
-- out of a buffer made of lines of @Shape.size width@ elements each.
pickConsecutive ::
   (Shape.C height, Shape.C width, Shape.Index height ~ ix,
    Storable a, Class.Floating a) =>
   height -> width -> ForeignPtr a -> ix -> Vector width a
pickConsecutive height width x ix =
   Array.unsafeCreate width $ \yPtr -> evalContT $ do
      let n = Shape.size width
      let offset = Shape.offset height ix
      nPtr <- Call.cint n
      xPtr <- ContT $ withForeignPtr x
      incxPtr <- Call.cint 1
      incyPtr <- Call.cint 1
      liftIO $
         BlasGen.copy nPtr (advancePtr xPtr (n*offset)) incxPtr yPtr incyPtr

-- | Gather the elements at position @ix@ (counted in @width@)
-- of each of the @Shape.size height@ lines,
-- i.e. a read with stride @Shape.size width@.
pickScattered ::
   (Shape.C height, Shape.C width, Shape.Index width ~ ix,
    Storable a, Class.Floating a) =>
   height -> width -> ForeignPtr a -> ix -> Vector height a
pickScattered height width x ix =
   Array.unsafeCreate height $ \yPtr -> evalContT $ do
      let n = Shape.size height
      let offset = Shape.offset width ix
      nPtr <- Call.cint n
      xPtr <- ContT $ withForeignPtr x
      incxPtr <- Call.cint $ Shape.size width
      incyPtr <- Call.cint 1
      liftIO $ BlasGen.copy nPtr (advancePtr xPtr offset) incxPtr yPtr incyPtr


-- | 'takeRows' keeps the first @k@ rows;
-- 'dropRows' removes the first @k@ rows.
-- If @k@ exceeds the height, all rows are kept or removed, respectively.
-- A negative @k@ is an error. The storage order is preserved.
takeRows, dropRows ::
   (Shape.C width, Storable a, Class.Floating a) =>
   Int -> General ZeroInt width a -> General ZeroInt width a
takeRows k
      (Array (MatrixShape.General order (Shape.ZeroBased heightA) width) a) =
   let heightB = min k heightA
       n = Shape.size width
   in if' (k<0) (error "take: negative number") $
      Array.unsafeCreate
         (MatrixShape.General order (Shape.ZeroBased heightB) width) $ \bPtr ->
      withForeignPtr a $ \aPtr ->
      case order of
         -- row-major: the leading rows form one contiguous block
         RowMajor -> copyBlock (heightB*n) aPtr bPtr
         -- column-major: copy a sub-matrix out of columns of length heightA
         ColumnMajor -> copySubMatrix heightB n heightA aPtr heightB bPtr

dropRows k0
      (Array (MatrixShape.General order (Shape.ZeroBased heightA) width) a) =
   let k = min k0 heightA
       heightB = heightA - k
       n = Shape.size width
   in if' (k<0) (error "drop: negative number") $
      Array.unsafeCreate
         (MatrixShape.General order (Shape.ZeroBased heightB) width) $ \bPtr ->
      withForeignPtr a $ \aPtr ->
      case order of
         -- row-major: the trailing rows form one contiguous block
         RowMajor -> copyBlock (heightB*n) (advancePtr aPtr (k*n)) bPtr
         -- column-major: skip k elements at the start of every column
         ColumnMajor ->
            copySubMatrix heightB n heightA (advancePtr aPtr k) heightB bPtr

-- | Column counterparts of 'takeRows' and 'dropRows', defined via 'transpose'.
takeColumns, dropColumns ::
   (Shape.C height, Storable a, Class.Floating a) =>
   Int -> General height ZeroInt a -> General height ZeroInt a
takeColumns k = transpose . takeRows k . transpose
dropColumns k = transpose . dropRows k . transpose


-- | Reverse the order of the rows.
-- The elements are copied, then permuted in place by LAPACK:
-- @lapmt@ (permuting the columns of the column-major view) for row-major
-- storage, and @lapmr@ (permuting the rows) for column-major storage.
-- alternative: laswp
reverseRows ::
   (Shape.C width, Storable a, Class.Floating a) =>
   General ZeroInt width a -> General ZeroInt width a
reverseRows (Array shape@(MatrixShape.General order height width) a) =
   Array.unsafeCreate shape $ \bPtr -> evalContT $ do
      let n = Shape.size height
      let m = Shape.size width
      fwdPtr <- Call.bool True
      nPtr <- Call.cint n
      mPtr <- Call.cint m
      kPtr <- Call.allocaArray n
      aPtr <- ContT $ withForeignPtr a
      liftIO $ do
         copyBlock (n*m) aPtr bPtr
         -- permutation [n, n-1 .. 1]; LAPACK permutation vectors are 1-based
         pokeArray kPtr $ take n $ iterate (subtract 1) $ fromIntegral n
         case order of
            RowMajor -> LapackGen.lapmt fwdPtr mPtr nPtr bPtr mPtr kPtr
            ColumnMajor -> LapackGen.lapmr fwdPtr nPtr mPtr bPtr nPtr kPtr

-- | Reverse the order of the columns, defined via 'transpose' and 'reverseRows'.
reverseColumns ::
   (Shape.C height, Storable a, Class.Floating a) =>
   General height ZeroInt a -> General height ZeroInt a
reverseColumns = transpose . reverseRows . transpose


-- | Reinterpret an array indexed by @(row, column)@ as a row-major matrix.
-- The element buffer is reused, not copied.
fromRowMajor ::
   (Shape.C height, Shape.C width, Storable a, Class.Floating a) =>
   Array (height,width) a -> General height width a
fromRowMajor (Array (height,width) x) =
   Array (MatrixShape.General RowMajor height width) x

-- | Convert to an array indexed by @(row, column)@.
-- Row-major input is returned without copying;
-- column-major input is rearranged into a fresh row-major buffer.
toRowMajor ::
   (Shape.C height, Shape.C width, Storable a, Class.Floating a) =>
   General height width a -> Array (height,width) a
toRowMajor (Array (MatrixShape.General order height width) x) =
   let shape = (height, width)
   in case order of
      RowMajor -> Array shape x
      ColumnMajor ->
         Array.unsafeCreate shape $ \yPtr -> evalContT $ do
            let n = Shape.size width
            let m = Shape.size height
            nPtr <- Call.cint n
            xPtr <- ContT $ withForeignPtr x
            incxPtr <- Call.cint m
            incyPtr <- Call.cint 1
            -- row i: gather n elements from x+i with stride m into y+i*n
            liftIO $ sequence_ $ take m $
               zipWith
                  (\xkPtr ykPtr ->
                     BlasGen.copy nPtr xkPtr incxPtr ykPtr incyPtr)
                  (pointerSeq 1 xPtr)
                  (pointerSeq n yPtr)

-- | All elements in row-major order as one vector.
flatten ::
   (Shape.C height, Shape.C width, Storable a, Class.Floating a) =>
   General height width a -> Vector ZeroInt a
flatten x =
   case toRowMajor x of
      Array (height,width) fptr ->
         Array (zeroInt $ Shape.size height * Shape.size width) fptr


infixl 3 |||
infixl 2 ===

-- | Concatenate two matrices horizontally (side by side).
-- Both operands must have equal heights, otherwise 'error' is called.
-- The result is row-major when both operands are row-major,
-- and column-major in every other case.
(|||) ::
   (Shape.C height, Eq height,
    Shape.C widtha, Shape.C widthb,
    Storable a, Class.Floating a) =>
   General height widtha a ->
   General height widthb a ->
   General height (widtha:+:widthb) a
(|||)
      (Array (MatrixShape.General orderA heightA widthA) a)
      (Array (MatrixShape.General orderB heightB widthB) b) =
   if' (heightA /= heightB) (error "(|||): mismatching heights") $
   let n = Shape.size heightA
       ma = Shape.size widthA
       mb = Shape.size widthB
       widthC = widthA:+:widthB
   in case (orderA,orderB) of
      (RowMajor,RowMajor) ->
         Array.unsafeCreate (MatrixShape.General RowMajor heightA widthC) $
         \cPtr -> evalContT $ do
            let m = ma+mb
            maPtr <- Call.cint ma
            mbPtr <- Call.cint mb
            aPtr <- ContT $ withForeignPtr a
            bPtr <- ContT $ withForeignPtr b
            incxPtr <- Call.cint 1
            incyPtr <- Call.cint 1
            -- every output row: the row of a followed by the row of b
            liftIO $ sequence_ $ take n $
               zipWith3
                  (\akPtr bkPtr ckPtr -> do
                     BlasGen.copy maPtr akPtr incxPtr ckPtr incyPtr
                     BlasGen.copy mbPtr bkPtr incxPtr
                        (ckPtr `advancePtr` ma) incyPtr)
                  (pointerSeq ma aPtr)
                  (pointerSeq mb bPtr)
                  (pointerSeq m cPtr)
      (RowMajor,ColumnMajor) ->
         Array.unsafeCreate (MatrixShape.General ColumnMajor heightA widthC) $
         \cPtr -> evalContT $ do
            aPtr <- ContT $ withForeignPtr a
            bPtr <- ContT $ withForeignPtr b
            liftIO $ do
               -- row-major a goes through copyTransposed (leading dimension n)
               copyTransposed n ma aPtr n cPtr
               -- column-major b is already laid out as the trailing columns
               copyBlock (n*mb) bPtr (advancePtr cPtr (n*ma))
      (ColumnMajor,RowMajor) ->
         Array.unsafeCreate (MatrixShape.General ColumnMajor heightA widthC) $
         \cPtr -> evalContT $ do
            let volA = n*ma
            aPtr <- ContT $ withForeignPtr a
            bPtr <- ContT $ withForeignPtr b
            liftIO $ do
               copyBlock volA aPtr cPtr
               copyTransposed n mb bPtr n (advancePtr cPtr volA)
      (ColumnMajor,ColumnMajor) ->
         Array.unsafeCreate (MatrixShape.General ColumnMajor heightA widthC) $
         \cPtr -> evalContT $ do
            let na = n*ma
            let nb = n*mb
            naPtr <- Call.cint na
            nbPtr <- Call.cint nb
            aPtr <- ContT $ withForeignPtr a
            bPtr <- ContT $ withForeignPtr b
            incxPtr <- Call.cint 1
            incyPtr <- Call.cint 1
            -- both inputs are contiguous column blocks: append them
            liftIO $ do
               BlasGen.copy naPtr aPtr incxPtr cPtr incyPtr
               BlasGen.copy nbPtr bPtr incxPtr (cPtr `advancePtr` na) incyPtr

-- | Concatenate two matrices vertically (one above the other).
-- Implemented by transposing both operands, joining them with '|||'
-- (which requires equal widths here) and transposing back.
(===) ::
   (Shape.C width, Eq width,
    Shape.C heighta, Shape.C heightb,
    Storable a, Class.Floating a) =>
   General heighta width a ->
   General heightb width a ->
   General (heighta:+:heightb) width a
(===) a b = transpose (transpose a ||| transpose b)


-- | Sum of every row, computed as the product with a vector of ones.
rowSums ::
   (Shape.C height, Shape.C width, Storable a, Class.Floating a) =>
   General height width a -> Vector height a
rowSums m =
   let MatrixShape.General _ _ width = Array.shape m
   in multiplyVectorUnchecked m (Vector.constant width one)

-- | Sum of every column, computed as the product of the transposed matrix
-- with a vector of ones.
columnSums ::
   (Shape.C height, Shape.C width, Storable a, Class.Floating a) =>
   General height width a -> Vector width a
columnSums m =
   let MatrixShape.General _ height _ = Array.shape m
   in multiplyVectorUnchecked (transpose m) (Vector.constant height one)


-- | Matrix-vector product.
-- Calls 'error' if the matrix width and the vector shape differ.
multiplyVector ::
   (Shape.C height, Shape.C width, Eq width, Storable a, Class.Floating a) =>
   General height width a -> Vector width a -> Vector height a
multiplyVector a x =
   let MatrixShape.General _order _height width = Array.shape a
   in if width == Array.shape x
         then multiplyVectorUnchecked a x
         else error "multiplyVector: width shapes mismatch"

-- | Matrix-vector product without checking the vector's shape.
-- Either storage order is handed to BLAS @gemv@ directly,
-- with the transposition flag taken from 'charFromOrder'.
multiplyVectorUnchecked ::
   (Shape.C height, Shape.C width, Storable a, Class.Floating a) =>
   General height width a -> Vector width a -> Vector height a
multiplyVectorUnchecked
      (Array shape@(MatrixShape.General order height _width) a) (Array _ x) =
   Array.unsafeCreate height $ \yPtr -> do
   let (m,n) = MatrixShape.dimensions shape
   if m == 0 || n == 0
      then
         -- BLAS gemv returns early without writing y when a dimension is 0,
         -- and it rejects a leading dimension of 0.
         -- The product is the zero vector (possibly empty) in this case.
         pokeArray yPtr $ replicate (Shape.size height) zero
      else evalContT $ do
         let lda = m
         transPtr <- Call.char $ charFromOrder order
         mPtr <- Call.cint m
         nPtr <- Call.cint n
         alphaPtr <- Call.number one
         aPtr <- ContT $ withForeignPtr a
         ldaPtr <- Call.cint lda
         xPtr <- ContT $ withForeignPtr x
         incxPtr <- Call.cint 1
         betaPtr <- Call.number zero
         incyPtr <- Call.cint 1
         liftIO $
            BlasGen.gemv transPtr mPtr nPtr alphaPtr aPtr ldaPtr
               xPtr incxPtr betaPtr yPtr incyPtr

-- | Matrix product.
-- Fails (via 'Call.assert') if the inner shapes differ.
-- The result is always stored column-major. Operands in either storage order
-- are handed to BLAS @gemm@ directly, with transposition flags taken from
-- 'charFromOrder' instead of copying.
multiply ::
   (Shape.C height,
    Shape.C fuse, Eq fuse,
    Shape.C width,
    Storable a, Class.Floating a) =>
   General height fuse a ->
   General fuse width a ->
   General height width a
multiply
      (Array (MatrixShape.General orderA height fuseA) a)
      (Array (MatrixShape.General orderB fuseB width) b) =
   Array.unsafeCreate (MatrixShape.General ColumnMajor height width) $ \cPtr -> do
   Call.assert "multiply: fuse shapes mismatch" (fuseA == fuseB)
   let m = Shape.size height
   let n = Shape.size width
   let k = Shape.size fuseA
   if m == 0 || n == 0 || k == 0
      then
         -- Degenerate product: all zeros, or no elements at all.
         -- Handling it here also avoids passing leading dimensions of 0,
         -- which BLAS gemm rejects (it requires them to be at least 1).
         pokeArray cPtr $ replicate (m*n) zero
      else do
         let lda = case orderA of RowMajor -> k; ColumnMajor -> m
         let ldb = case orderB of RowMajor -> n; ColumnMajor -> k
         let ldc = m
         evalContT $ do
            transaPtr <- Call.char $ charFromOrder orderA
            transbPtr <- Call.char $ charFromOrder orderB
            mPtr <- Call.cint m
            nPtr <- Call.cint n
            kPtr <- Call.cint k
            alphaPtr <- Call.number one
            aPtr <- ContT $ withForeignPtr a
            ldaPtr <- Call.cint lda
            bPtr <- ContT $ withForeignPtr b
            ldbPtr <- Call.cint ldb
            betaPtr <- Call.number zero
            ldcPtr <- Call.cint ldc
            liftIO $
               BlasGen.gemm
                  transaPtr transbPtr mPtr nPtr kPtr alphaPtr
                  aPtr ldaPtr bPtr ldbPtr betaPtr cPtr ldcPtr


-- | Multiply row @i@ of the matrix by element @i@ of the vector.
-- Fails (via 'Call.assert') if the vector shape differs from the height.
-- The storage order is preserved.
scaleRows ::
   (Shape.C height, Eq height, Shape.C width, Storable a, Class.Floating a) =>
   Vector height a -> General height width a -> General height width a
scaleRows
      (Array heightX x)
      (Array shape@(MatrixShape.General order height width) a) =
   Array.unsafeCreate shape $ \bPtr -> do
   Call.assert "scaleRows: sizes mismatch" (heightX == height)
   case order of
      RowMajor -> evalContT $ do
         let m = Shape.size height
         let n = Shape.size width
         alphaPtr <- Call.alloca
         nPtr <- Call.cint n
         xPtr <- ContT $ withForeignPtr x
         aPtr <- ContT $ withForeignPtr a
         incaPtr <- Call.cint 1
         incbPtr <- Call.cint 1
         -- row k: copy it, then scale the copy in place by x[k]
         liftIO $ sequence_ $ take m $
            zipWith3
               (\xkPtr akPtr bkPtr -> do
                  poke alphaPtr =<< peek xkPtr
                  BlasGen.copy nPtr akPtr incaPtr bkPtr incbPtr
                  BlasGen.scal nPtr alphaPtr bkPtr incbPtr)
               (pointerSeq 1 xPtr)
               (pointerSeq n aPtr)
               (pointerSeq n bPtr)
      ColumnMajor -> evalContT $ do
         let m = Shape.size width
         let n = Shape.size height
         transPtr <- Call.char 'N'
         nPtr <- Call.cint n
         klPtr <- Call.cint 0
         kuPtr <- Call.cint 0
         alphaPtr <- Call.number one
         xPtr <- ContT $ withForeignPtr x
         ldxPtr <- Call.cint 1
         aPtr <- ContT $ withForeignPtr a
         incaPtr <- Call.cint 1
         betaPtr <- Call.number zero
         incbPtr <- Call.cint 1
         -- column k: b_k = diag(x) * a_k, with x handed to gbmv as a band
         -- matrix without sub- or super-diagonals (kl = ku = 0, lda = 1)
         liftIO $ sequence_ $ take m $
            zipWith
               (\akPtr bkPtr ->
                  BlasGen.gbmv transPtr nPtr nPtr klPtr kuPtr alphaPtr
                     xPtr ldxPtr akPtr incaPtr betaPtr bkPtr incbPtr)
               (pointerSeq n aPtr)
               (pointerSeq n bPtr)

-- | Multiply column @j@ of the matrix by element @j@ of the vector,
-- defined via 'transpose' and 'scaleRows'.
scaleColumns ::
   (Shape.C height, Shape.C width, Eq width, Storable a, Class.Floating a) =>
   Vector width a -> General height width a -> General height width a
scaleColumns x = transpose . scaleRows x . transpose


-- | Sum of the main diagonal of a square matrix.
-- Fails (via 'Call.assert') if height and width differ.
trace :: (Shape.C sh, Eq sh, Class.Floating a) => General sh sh a -> a
trace (Array (MatrixShape.General _ height width) x) =
   unsafePerformIO $ do
   Call.assert "trace: non-square matrix" (height==width)
   let n = Shape.size height
   -- stride n+1 visits the diagonal for both storage orders
   -- NOTE(review): purity relies on Private.sum only reading from xPtr
   withForeignPtr x $ \xPtr -> Private.sum n xPtr (n+1)