{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{- |
General (rectangular) matrices:
construction, slicing, reshaping, concatenation, scaling,
and re-exports of the multiplication operations.
-}
module Numeric.LAPACK.Matrix (
   General,
   (##),
   Format,
   FormatArray,
   ZeroInt,
   zeroInt,
   transpose,
   adjoint,
   fromScalar,
   toScalar,
   fromList,
   identity,
   diagonal,
   getDiagonal,
   fromRows,
   fromRowsWithSize,
   fromColumns,
   fromColumnsWithSize,
   singleRow,
   singleColumn,
   flattenRow,
   flattenColumn,
   pickRow,
   pickColumn,
   takeRows,
   takeColumns,
   dropRows,
   dropColumns,
   reverseRows,
   reverseColumns,
   fromRowMajor,
   toRowMajor,
   flatten,
   (|||),
   (===),
   rowSums,
   columnSums,
   scaleRows,
   scaleColumns,
   multiply,
   multiplyVector,
   Multiply, (<#>),
   MultiplyLeft, (<#),
   MultiplyRight, (#>),
   ) where

import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Square as Square
import qualified Numeric.LAPACK.Vector as Vector
import Numeric.LAPACK.Format (Format, FormatArray, (##))
import Numeric.LAPACK.Matrix.Shape.Private (Order(RowMajor, ColumnMajor))
import Numeric.LAPACK.Matrix.Multiply
         (Multiply((<#>)), MultiplyLeft((<#)), MultiplyRight((#>)),
          transpose, multiplyVector, multiply, multiplyVectorUnchecked)
import Numeric.LAPACK.Matrix.Private (General, ZeroInt, zeroInt)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Private
         (zero, one, pointerSeq, copyTransposed, copySubMatrix, copyBlock)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Internal as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Internal (Array(Array))
import Data.Array.Comfort.Shape ((:+:)((:+:)))

import Foreign.Marshal.Array (copyArray, advancePtr, pokeArray)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable, poke, peek)

import System.IO.Unsafe (unsafePerformIO)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)

import qualified Data.NonEmpty as NonEmpty
import Data.Foldable (forM_)
import Data.Bool.HT (if')


{- |
Conjugate transpose:
conjugates every element (via 'Vector.conjugate')
and then swaps the roles of rows and columns (via 'transpose').
-}
adjoint ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   General height width a -> General width height a
adjoint = transpose . Vector.conjugate


-- | Wrap a single value as a 1×1 matrix.
fromScalar :: (Storable a) => a -> General () () a
fromScalar = Square.toGeneral . Square.fromScalar

{- |
Read the only element of a 1×1 matrix.

'unsafePerformIO' is used only to 'peek' the first element of the buffer.
This is referentially transparent provided the array buffer
is not mutated after construction.
-}
toScalar :: (Storable a) => General () () a -> a
toScalar (Array (MatrixShape.General _ () ()) a) =
   unsafePerformIO $ withForeignPtr a peek


{- |
Build a matrix of the given height and width from a list of elements.
The elements are interpreted in row-major order.
-}
fromList ::
   (Shape.C height, Shape.C width, Storable a) =>
   height -> width -> [a] -> General height width a
fromList height width =
   Array.fromList (MatrixShape.General RowMajor height width)


-- | Identity matrix with the given shape for both rows and columns.
identity ::
   (Shape.C sh, Class.Floating a) =>
   sh -> General sh sh a
identity = Square.toGeneral . Square.identity

-- | Square matrix with the vector on its diagonal.
diagonal ::
   (Shape.C sh, Class.Floating a) =>
   Vector sh a -> General sh sh a
diagonal = Square.toGeneral . Square.diagonal

{- |
Extract the diagonal of a matrix whose row and column shapes agree.
The conversion goes through 'Square.fromGeneral'.
NOTE(review): the 'Eq' constraint presumably serves a squareness check
in 'Square.fromGeneral'; confirm there.
-}
getDiagonal ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   General sh sh a -> Vector sh a
getDiagonal = Square.getDiagonal . Square.fromGeneral


-- | Reinterpret a vector as a matrix with one row (no copying).
singleRow :: Vector width a -> General () width a
singleRow (Array sh fptr) = Array (MatrixShape.General RowMajor () sh) fptr

-- | Reinterpret a vector as a matrix with one column (no copying).
singleColumn :: Vector width a -> General width () a
singleColumn (Array sh fptr) =
   Array (MatrixShape.General ColumnMajor sh ()) fptr

-- | Reinterpret a one-row matrix as a vector (no copying).
flattenRow :: General () width a -> Vector width a
flattenRow (Array (MatrixShape.General _ () sh) fptr) = Array sh fptr

-- | Reinterpret a one-column matrix as a vector (no copying).
flattenColumn :: General width () a -> Vector width a
flattenColumn (Array (MatrixShape.General _ sh ()) fptr) = Array sh fptr


{- |
Stack a non-empty list of vectors as the rows of a row-major matrix.
The width is taken from the first vector.
Every other vector must have the same shape; this is checked with an assertion.
-}
fromRows ::
   (Shape.C width, Eq width, Storable a) =>
   NonEmpty.T [] (Vector width a) -> General ZeroInt width a
fromRows (NonEmpty.Cons row rows) =
   fromRowsWithSize (Array.shape row) (row:rows)

{- |
Like 'fromRows', but with an explicitly given width.
The list may therefore be empty.
-}
fromRowsWithSize ::
   (Shape.C width, Eq width, Storable a) =>
   width -> [Vector width a] -> General ZeroInt width a
fromRowsWithSize width rows =
   Array.unsafeCreate
      (MatrixShape.General RowMajor (zeroInt $ length rows) width)
      (gather width rows)

{- |
Place a non-empty list of vectors as the columns of a column-major matrix.
The height is taken from the first vector.
Every other vector must have the same shape; this is checked with an assertion.
-}
fromColumns ::
   (Shape.C height, Eq height, Storable a) =>
   NonEmpty.T [] (Vector height a) -> General height ZeroInt a
fromColumns (NonEmpty.Cons column columns) =
   fromColumnsWithSize (Array.shape column) (column:columns)

{- |
Like 'fromColumns', but with an explicitly given height.
The list may therefore be empty.
-}
fromColumnsWithSize ::
   (Shape.C height, Eq height, Storable a) =>
   height -> [Vector height a] -> General height ZeroInt a
fromColumnsWithSize height columns =
   Array.unsafeCreate
      (MatrixShape.General ColumnMajor height (zeroInt $ length columns))
      (gather height columns)

{- |
Copy the vectors one after another into the destination buffer.
Each vector occupies @Shape.size width@ consecutive elements.
Each vector's shape is asserted to equal @width@ before it is copied.
-}
gather ::
   (Shape.C width, Eq width, Storable a) =>
   width -> [Array width a] -> Ptr a -> IO ()
gather width rows dstPtr =
   let widthSize = Shape.size width
   in  forM_ (zip (pointerSeq widthSize dstPtr) rows) $
         \(dstRowPtr, Array.Array rowWidth srcFPtr) ->
         withForeignPtr srcFPtr $ \srcPtr -> do
            Call.assert
               "Matrix.fromRows/fromColumns: non-matching vector size"
               (width == rowWidth)
            copyArray dstRowPtr srcPtr widthSize


{- |
Copy out the row with the given index.
In row-major storage the row is contiguous.
In column-major storage it is read with a stride equal to the height.
-}
pickRow ::
   (Shape.C height, Shape.C width, Shape.Index height ~ ix,
    Class.Floating a) =>
   General height width a -> ix -> Vector width a
pickRow (Array (MatrixShape.General order height width) x) ix =
   case order of
      RowMajor -> pickConsecutive height width x ix
      ColumnMajor -> pickScattered width height x ix

{- |
Copy out the column with the given index.
In column-major storage the column is contiguous.
In row-major storage it is read with a stride equal to the width.
-}
pickColumn ::
   (Shape.C height, Shape.C width, Shape.Index width ~ ix,
    Class.Floating a) =>
   General height width a -> ix -> Vector height a
pickColumn (Array (MatrixShape.General order height width) x) ix =
   case order of
      RowMajor -> pickScattered height width x ix
      ColumnMajor -> pickConsecutive width height x ix

{- |
Copy the @ix@-th block of @Shape.size width@ consecutive elements.
The block starts at element @Shape.size width * Shape.offset height ix@.
-}
pickConsecutive ::
   (Shape.C height, Shape.C width, Shape.Index height ~ ix,
    Class.Floating a) =>
   height -> width -> ForeignPtr a -> ix -> Vector width a
pickConsecutive height width x ix =
   Array.unsafeCreateWithSize width $ \n yPtr -> evalContT $ do
      let offset = Shape.offset height ix
      nPtr <- Call.cint n
      xPtr <- ContT $ withForeignPtr x
      incxPtr <- Call.cint 1
      incyPtr <- Call.cint 1
      liftIO $
         BlasGen.copy nPtr (advancePtr xPtr (n*offset)) incxPtr yPtr incyPtr

{- |
Copy @Shape.size height@ elements into a contiguous vector.
Reading starts at @Shape.offset width ix@ and advances by @Shape.size width@.
-}
pickScattered ::
   (Shape.C height, Shape.C width, Shape.Index width ~ ix,
    Class.Floating a) =>
   height -> width -> ForeignPtr a -> ix -> Vector height a
pickScattered height width x ix =
   Array.unsafeCreateWithSize height $ \n yPtr -> evalContT $ do
      let offset = Shape.offset width ix
      nPtr <- Call.cint n
      xPtr <- ContT $ withForeignPtr x
      incxPtr <- Call.cint $ Shape.size width
      incyPtr <- Call.cint 1
      liftIO $ BlasGen.copy nPtr (advancePtr xPtr offset) incxPtr yPtr incyPtr


{- |
@takeRows k@ keeps the first @k@ rows; all rows are kept if @k@ exceeds the height.
@dropRows k@ removes the first @k@ rows; the result is empty if @k@ exceeds the height.

Both raise an error for negative @k@.
The storage order of the input is preserved.
-}
takeRows, dropRows ::
   (Shape.C width, Class.Floating a) =>
   Int -> General ZeroInt width a -> General ZeroInt width a
takeRows k
      (Array (MatrixShape.General order (Shape.ZeroBased heightA) width) a) =
   let heightB = min k heightA
       n = Shape.size width
   in  if' (k<0) (error "take: negative number") $
       Array.unsafeCreateWithSize
          (MatrixShape.General order (Shape.ZeroBased heightB) width) $
          \blockSize bPtr ->
          withForeignPtr a $ \aPtr ->
          case order of
             -- the first heightB rows form one contiguous prefix
             RowMajor -> copyBlock blockSize aPtr bPtr
             -- the first heightB entries of each of the n columns
             ColumnMajor ->
                copySubMatrix heightB n heightA aPtr heightB bPtr

dropRows k0
      (Array (MatrixShape.General order (Shape.ZeroBased heightA) width) a) =
   let k = min k0 heightA
       heightB = heightA - k
       n = Shape.size width
   in  if' (k<0) (error "drop: negative number") $
       Array.unsafeCreateWithSize
          (MatrixShape.General order (Shape.ZeroBased heightB) width) $
          \blockSize bPtr ->
          withForeignPtr a $ \aPtr ->
          case order of
             -- skip the first k rows of n elements each
             RowMajor -> copyBlock blockSize (advancePtr aPtr (k*n)) bPtr
             -- skip the first k entries of every column
             ColumnMajor ->
                copySubMatrix heightB n heightA (advancePtr aPtr k) heightB bPtr

{- |
Column counterparts of 'takeRows' and 'dropRows'.
They are implemented by transposing, applying the row operation, and transposing back.
-}
takeColumns, dropColumns ::
   (Shape.C height, Class.Floating a) =>
   Int -> General height ZeroInt a -> General height ZeroInt a
takeColumns k = transpose . takeRows k . transpose
dropColumns k = transpose . dropRows k . transpose


{- |
Reverse the order of the rows.

The matrix is first copied.
The copy is then permuted in place by LAPACK with the 1-based permutation @[n, n-1 .. 1]@.
For row-major storage the buffer, read column-major, has our rows as its columns,
hence @lapmt@ (column permutation).
For column-major storage @lapmr@ (row permutation) is used.
-}
-- alternative: laswp
reverseRows ::
   (Shape.C width, Class.Floating a) =>
   General ZeroInt width a -> General ZeroInt width a
reverseRows (Array shape@(MatrixShape.General order height width) a) =
   Array.unsafeCreateWithSize shape $ \blockSize bPtr -> evalContT $ do
      let n = Shape.size height
      let m = Shape.size width
      fwdPtr <- Call.bool True
      nPtr <- Call.cint n
      mPtr <- Call.cint m
      kPtr <- Call.allocaArray n
      aPtr <- ContT $ withForeignPtr a
      liftIO $ do
         copyBlock blockSize aPtr bPtr
         pokeArray kPtr $ take n $ iterate (subtract 1) $ fromIntegral n
         case order of
            RowMajor -> LapackGen.lapmt fwdPtr mPtr nPtr bPtr mPtr kPtr
            ColumnMajor -> LapackGen.lapmr fwdPtr nPtr mPtr bPtr nPtr kPtr

-- | Reverse the order of the columns (via 'reverseRows' on the transpose).
reverseColumns ::
   (Shape.C height, Class.Floating a) =>
   General height ZeroInt a -> General height ZeroInt a
reverseColumns = transpose . reverseRows . transpose


{- |
Reinterpret an array indexed by @(row, column)@ as a row-major matrix.
No data is copied.
-}
fromRowMajor ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   Array (height,width) a -> General height width a
fromRowMajor (Array (height,width) x) =
   Array (MatrixShape.General RowMajor height width) x

{- |
Convert to an array indexed by @(row, column)@ with row-major layout.
Row-major input shares its buffer.
Column-major input is copied row by row into a fresh buffer,
reading each row with a stride equal to the height.
-}
toRowMajor ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   General height width a -> Array (height,width) a
toRowMajor (Array (MatrixShape.General order height width) x) =
   let shape = (height, width)
   in  case order of
         RowMajor -> Array shape x
         ColumnMajor ->
            Array.unsafeCreate shape $ \yPtr -> evalContT $ do
               let n = Shape.size width
               let m = Shape.size height
               nPtr <- Call.cint n
               xPtr <- ContT $ withForeignPtr x
               incxPtr <- Call.cint m
               incyPtr <- Call.cint 1
               liftIO $ sequence_ $ take m $
                  zipWith
                     (\xkPtr ykPtr ->
                        BlasGen.copy nPtr xkPtr incxPtr ykPtr incyPtr)
                     (pointerSeq 1 xPtr)
                     (pointerSeq n yPtr)

-- | All elements as one vector in row-major order.
flatten ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   General height width a -> Vector ZeroInt a
flatten x =
   case toRowMajor x of
      Array shape fptr -> Array (zeroInt $ Shape.size shape) fptr


infixl 3 |||
infixl 2 ===

{- |
Horizontal concatenation: place the second matrix to the right of the first.
The heights must be equal; otherwise an error is raised.
The result is row-major only when both operands are row-major,
and column-major in all other cases.
-}
(|||) ::
   (Shape.C height, Eq height, Shape.C widtha, Shape.C widthb,
    Class.Floating a) =>
   General height widtha a ->
   General height widthb a ->
   General height (widtha:+:widthb) a
(|||)
      (Array (MatrixShape.General orderA heightA widthA) a)
      (Array (MatrixShape.General orderB heightB widthB) b)
   | heightA /= heightB = error "(|||): mismatching heights"
   | otherwise =
      case (orderA,orderB) of
         -- interleave: each result row is a row of a followed by a row of b
         (RowMajor,RowMajor) ->
            Array.unsafeCreate
               (MatrixShape.General RowMajor heightA (widthA:+:widthB)) $
               \cPtr -> evalContT $ do
                  let n = Shape.size heightA
                  let ma = Shape.size widthA
                  let mb = Shape.size widthB
                  let m = ma+mb
                  maPtr <- Call.cint ma
                  mbPtr <- Call.cint mb
                  aPtr <- ContT $ withForeignPtr a
                  bPtr <- ContT $ withForeignPtr b
                  incxPtr <- Call.cint 1
                  incyPtr <- Call.cint 1
                  liftIO $ sequence_ $ take n $
                     zipWith3
                        (\akPtr bkPtr ckPtr -> do
                           BlasGen.copy maPtr akPtr incxPtr ckPtr incyPtr
                           BlasGen.copy mbPtr bkPtr incxPtr
                              (ckPtr `advancePtr` ma) incyPtr)
                        (pointerSeq ma aPtr)
                        (pointerSeq mb bPtr)
                        (pointerSeq m cPtr)
         -- transpose a into the leading columns, then append b's columns
         (RowMajor,ColumnMajor) ->
            Array.unsafeCreate
               (MatrixShape.General ColumnMajor heightA (widthA:+:widthB)) $
               \cPtr -> evalContT $ do
                  let n = Shape.size heightA
                  let ma = Shape.size widthA
                  let mb = Shape.size widthB
                  aPtr <- ContT $ withForeignPtr a
                  bPtr <- ContT $ withForeignPtr b
                  liftIO $ do
                     copyTransposed n ma aPtr n cPtr
                     copyBlock (n*mb) bPtr (advancePtr cPtr (n*ma))
         -- copy a's columns as is, then transpose b behind them
         (ColumnMajor,RowMajor) ->
            Array.unsafeCreate
               (MatrixShape.General ColumnMajor heightA (widthA:+:widthB)) $
               \cPtr -> evalContT $ do
                  let n = Shape.size heightA
                  let ma = Shape.size widthA
                  let mb = Shape.size widthB
                  let volA = n*ma
                  aPtr <- ContT $ withForeignPtr a
                  bPtr <- ContT $ withForeignPtr b
                  liftIO $ do
                     copyBlock volA aPtr cPtr
                     copyTransposed n mb bPtr n (advancePtr cPtr volA)
         -- both buffers are simply placed one after the other
         (ColumnMajor,ColumnMajor) ->
            Array.unsafeCreate
               (MatrixShape.General ColumnMajor heightA (widthA:+:widthB)) $
               \cPtr -> evalContT $ do
                  let n = Shape.size heightA
                  let na = n * Shape.size widthA
                  let nb = n * Shape.size widthB
                  naPtr <- Call.cint na
                  nbPtr <- Call.cint nb
                  aPtr <- ContT $ withForeignPtr a
                  bPtr <- ContT $ withForeignPtr b
                  incxPtr <- Call.cint 1
                  incyPtr <- Call.cint 1
                  liftIO $ do
                     BlasGen.copy naPtr aPtr incxPtr cPtr incyPtr
                     BlasGen.copy nbPtr bPtr incxPtr
                        (cPtr `advancePtr` na) incyPtr

{- |
Vertical concatenation: place the second matrix below the first.
Implemented as '(|||)' on the transposed operands, transposed back.
-}
(===) ::
   (Shape.C width, Eq width, Shape.C heighta, Shape.C heightb,
    Class.Floating a) =>
   General heighta width a ->
   General heightb width a ->
   General (heighta:+:heightb) width a
(===) a b = transpose (transpose a ||| transpose b)


-- | Sum of each row, computed as the product with an all-ones vector.
rowSums ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   General height width a -> Vector height a
rowSums m =
   let MatrixShape.General _ _ width = Array.shape m
   in  multiplyVectorUnchecked m (Vector.constant width one)

-- | Sum of each column, computed as the transpose times an all-ones vector.
columnSums ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   General height width a -> Vector width a
columnSums m =
   let MatrixShape.General _ height _ = Array.shape m
   in  multiplyVectorUnchecked (transpose m) (Vector.constant height one)


{- |
Multiply the @k@-th row of the matrix by the @k@-th element of the vector.
The vector shape must equal the matrix height; this is checked with an assertion.

Row-major: each row is copied and then scaled with BLAS @scal@.
Column-major: each column is multiplied by the diagonal matrix of the vector.
That diagonal matrix is a band matrix without sub- or super-diagonals,
passed to BLAS @gbmv@ with leading dimension 1.
-}
scaleRows ::
   (Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   Vector height a -> General height width a -> General height width a
scaleRows
      (Array heightX x)
      (Array shape@(MatrixShape.General order height width) a) =
   Array.unsafeCreate shape $ \bPtr -> do
      Call.assert "scaleRows: sizes mismatch" (heightX == height)
      case order of
         RowMajor -> evalContT $ do
            let m = Shape.size height
            let n = Shape.size width
            alphaPtr <- Call.alloca
            nPtr <- Call.cint n
            xPtr <- ContT $ withForeignPtr x
            aPtr <- ContT $ withForeignPtr a
            incaPtr <- Call.cint 1
            incbPtr <- Call.cint 1
            liftIO $ sequence_ $ take m $
               zipWith3
                  (\xkPtr akPtr bkPtr -> do
                     poke alphaPtr =<< peek xkPtr
                     BlasGen.copy nPtr akPtr incaPtr bkPtr incbPtr
                     BlasGen.scal nPtr alphaPtr bkPtr incbPtr)
                  (pointerSeq 1 xPtr)
                  (pointerSeq n aPtr)
                  (pointerSeq n bPtr)
         ColumnMajor -> evalContT $ do
            let m = Shape.size width
            let n = Shape.size height
            transPtr <- Call.char 'N'
            nPtr <- Call.cint n
            klPtr <- Call.cint 0
            kuPtr <- Call.cint 0
            alphaPtr <- Call.number one
            xPtr <- ContT $ withForeignPtr x
            ldxPtr <- Call.cint 1
            aPtr <- ContT $ withForeignPtr a
            incaPtr <- Call.cint 1
            -- beta = 0, so the uninitialised destination is never read
            betaPtr <- Call.number zero
            incbPtr <- Call.cint 1
            liftIO $ sequence_ $ take m $
               zipWith
                  (\akPtr bkPtr ->
                     BlasGen.gbmv transPtr nPtr nPtr klPtr kuPtr
                        alphaPtr xPtr ldxPtr akPtr incaPtr
                        betaPtr bkPtr incbPtr)
                  (pointerSeq n aPtr)
                  (pointerSeq n bPtr)

{- |
Multiply the @k@-th column of the matrix by the @k@-th element of the vector.
Implemented via 'scaleRows' on the transpose.
-}
scaleColumns ::
   (Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   Vector width a -> General height width a -> General height width a
scaleColumns x = transpose . scaleRows x . transpose