{-# LANGUAGE TypeFamilies #-}
-- | Triangular matrices kept in packed storage, together with
-- construction, conversion and multiplication routines that delegate to
-- the BLAS routines @tpmv@ and @trmm@.
module Numeric.LAPACK.Matrix.Triangular (
   Triangular,
   MatrixShape.Uplo(..),
   Upper,
   Lower,
   fromList,
   autoFromList,
   lowerFromList,
   autoLowerFromList,
   upperFromList,
   autoUpperFromList,
   identity,
   diagonal,
   getDiagonal,
   transposeUp,
   transposeDown,
   adjointUp,
   adjointDown,
   toSquare,
   multiplyVectorLeft,
   multiplyVectorRight,
   square,
   multiply,
   multiplySquareLeft,
   multiplyGeneralLeft,
   multiplySquareRight,
   multiplyGeneralRight,
   ) where

import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Vector as Vector
import Numeric.LAPACK.Matrix.Triangular.Private
         (diagonalPointers, pack, unpack, unpackZero, unpackToTemp)
import Numeric.LAPACK.Matrix.Shape.Private
         (Order(RowMajor,ColumnMajor), flipOrder, transposeFromOrder,
          uploFromOrder, uploOrder)
import Numeric.LAPACK.Matrix.Square (Square)
import Numeric.LAPACK.Matrix.Private (General, ZeroInt, zeroInt)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Private (fill, zero, one, copyBlock)

import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Internal as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Internal (Array(Array))

import Foreign.ForeignPtr (ForeignPtr, withForeignPtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable, poke, peek)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)

import Data.Foldable (forM_)


-- | Array whose shape is a triangular matrix shape tagged with @uplo@.
type Triangular uplo sh = Array (MatrixShape.Triangular uplo sh)

-- | Array with a lower triangular matrix shape.
type Lower sh = Array (MatrixShape.LowerTriangular sh)

-- | Array with an upper triangular matrix shape.
type Upper sh = Array (MatrixShape.UpperTriangular sh)


-- | Turn a lower triangular matrix into an upper triangular one.
-- Only the shape is rewritten; the element buffer is shared unchanged.
transposeUp :: Lower sh a -> Upper sh a
transposeUp (Array shape payload) =
   Array (MatrixShape.triangularTransposeUp shape) payload

-- | Turn an upper triangular matrix into a lower triangular one.
-- Only the shape is rewritten; the element buffer is shared unchanged.
transposeDown :: Upper sh a -> Lower sh a
transposeDown (Array shape payload) =
   Array (MatrixShape.triangularTransposeDown shape) payload

-- | 'transposeUp' followed by 'Vector.conjugate'.
adjointUp :: (Shape.C sh, Class.Floating a) => Lower sh a -> Upper sh a
adjointUp lower = Vector.conjugate (transposeUp lower)

-- | 'transposeDown' followed by 'Vector.conjugate'.
adjointDown :: (Shape.C sh, Class.Floating a) => Upper sh a -> Lower sh a
adjointDown upper = Vector.conjugate (transposeDown upper)


-- | Build a triangular matrix of the given order and size from a list of
-- packed elements. The @uplo@ tag is chosen by the result type
-- via 'MatrixShape.autoUplo'.
fromList ::
   (MatrixShape.Uplo uplo, Shape.C sh, Storable a) =>
   Order -> sh -> [a] -> Triangular uplo sh a
fromList order sh =
   Array.fromList $ MatrixShape.Triangular MatrixShape.autoUplo order sh

-- | 'fromList' specialised to lower triangular matrices.
lowerFromList :: (Shape.C sh, Storable a) => Order -> sh -> [a] -> Lower sh a
lowerFromList order sh xs = fromList order sh xs

-- | 'fromList' specialised to upper triangular matrices.
upperFromList :: (Shape.C sh, Storable a) => Order -> sh -> [a] -> Upper sh a
upperFromList order sh xs = fromList order sh xs

-- | Like 'fromList', but the size is derived from the length of the list
-- by means of 'MatrixShape.triangleExtent'.
autoFromList ::
   (MatrixShape.Uplo uplo, Storable a) =>
   Order -> [a] -> Triangular uplo ZeroInt a
autoFromList order xs = fromList order (zeroInt extent) xs
   where
      extent =
         MatrixShape.triangleExtent "Triangular.autoFromList" (length xs)

-- | 'autoFromList' specialised to lower triangular matrices.
autoLowerFromList :: (Storable a) => Order -> [a] -> Lower ZeroInt a
autoLowerFromList order xs = autoFromList order xs

-- | 'autoFromList' specialised to upper triangular matrices.
autoUpperFromList :: (Storable a) => Order -> [a] -> Upper ZeroInt a
autoUpperFromList order xs = autoFromList order xs


-- | Expand the packed triangle into a full square matrix of the same
-- order, using 'unpackZero' for the expansion.
toSquare ::
   (MatrixShape.Uplo uplo, Shape.C sh, Class.Floating a) =>
   Triangular uplo sh a -> Square sh a
toSquare (Array (MatrixShape.Triangular uplo order sh) packed) =
   Array.unsafeCreate (MatrixShape.Square order sh) $ \fullPtr ->
   withForeignPtr packed $ \packedPtr ->
      unpackZero (uploOrder uplo order) (Shape.size sh) packedPtr fullPtr


-- | Triangular matrix that is 'zero' everywhere except for 'one' at the
-- positions enumerated by 'diagonalPointers'.
identity ::
   (MatrixShape.Uplo uplo, Shape.C sh, Class.Floating a) =>
   Order -> sh -> Triangular uplo sh a
identity order sh =
   Array.unsafeCreate (MatrixShape.Triangular uplo order sh) $ \aPtr -> do
      let n = Shape.size sh
      -- clear the whole packed triangle first, then set the diagonal
      fill zero (MatrixShape.triangleSize n) aPtr
      mapM_ (\(_, diagPtr) -> poke diagPtr one) $
         diagonalPointers realOrder n aPtr aPtr
   where
      (realOrder, uplo) = autoUploOrder order

-- | Triangular matrix that is 'zero' everywhere except for the elements
-- of the given vector at the positions enumerated by 'diagonalPointers'.
diagonal ::
   (MatrixShape.Uplo uplo, Shape.C sh, Class.Floating a) =>
   Order -> Vector sh a -> Triangular uplo sh a
diagonal order (Array sh x) =
   Array.unsafeCreate (MatrixShape.Triangular uplo order sh) $ \aPtr -> do
      let n = Shape.size sh
      fill zero (MatrixShape.triangleSize n) aPtr
      withForeignPtr x $ \xPtr ->
         mapM_
            (\(srcPtr, dstPtr) -> peek srcPtr >>= poke dstPtr)
            (diagonalPointers realOrder n xPtr aPtr)
   where
      (realOrder, uplo) = autoUploOrder order

-- | Copy the elements at the positions enumerated by 'diagonalPointers'
-- into a freshly created vector.
getDiagonal ::
   (MatrixShape.Uplo uplo, Shape.C sh, Class.Floating a) =>
   Triangular uplo sh a -> Vector sh a
getDiagonal (Array (MatrixShape.Triangular uplo order sh) a) =
   Array.unsafeCreate sh $ \xPtr ->
   withForeignPtr a $ \aPtr ->
      forM_
         (diagonalPointers (uploOrder uplo order) (Shape.size sh) xPtr aPtr)
         (\(dstPtr, srcPtr) -> peek srcPtr >>= poke dstPtr)


-- | Products of a triangular matrix and a vector.
-- 'multiplyVectorLeft' runs 'multiplyVector' with the transposition flag
-- set, 'multiplyVectorRight' with the flag cleared.
multiplyVectorLeft, multiplyVectorRight ::
   (MatrixShape.Uplo uplo, Shape.C sh, Eq sh, Class.Floating a) =>
   Triangular uplo sh a -> Vector sh a -> Vector sh a
multiplyVectorLeft = multiplyVector True
multiplyVectorRight = multiplyVector False

-- | Shared worker for 'multiplyVectorLeft' and 'multiplyVectorRight'.
-- The vector is copied to the result buffer and then multiplied there
-- by BLAS @tpmv@. If the flag is 'True' the storage order is flipped
-- before it is translated to the @trans@ argument.
multiplyVector ::
   (MatrixShape.Uplo uplo, Shape.C sh, Eq sh, Class.Floating a) =>
   Bool -> Triangular uplo sh a -> Vector sh a -> Vector sh a
multiplyVector transp
      (Array (MatrixShape.Triangular uplo order shA) a) (Array shX x) =
   Array.unsafeCreate shX $ \yPtr -> do
      Call.assert "Triangular.multiplyVector: width shapes mismatch"
         (shA == shX)
      let n = Shape.size shA
          effectiveOrder = if transp then flipOrder order else order
      evalContT $ do
         uploPtr <- Call.char $ uploFromOrder $ uploOrder uplo order
         transPtr <- Call.char $ transposeFromOrder effectiveOrder
         diagPtr <- Call.char 'N'
         nPtr <- Call.cint n
         aPtr <- ContT $ withForeignPtr a
         xPtr <- ContT $ withForeignPtr x
         incyPtr <- Call.cint 1
         liftIO $ do
            -- tpmv overwrites its vector argument, thus work on a copy
            copyBlock n xPtr yPtr
            BlasGen.tpmv uploPtr transPtr diagPtr nPtr aPtr yPtr incyPtr


-- | Product of a triangular matrix with itself.
-- Both operands are unpacked into temporary full matrices, multiplied
-- by BLAS @trmm@, and the result is packed again.
square ::
   (MatrixShape.Uplo uplo, Shape.C sh, Eq sh, Class.Floating a) =>
   Triangular uplo sh a -> Triangular uplo sh a
square (Array shape@(MatrixShape.Triangular uplo order sh) a) =
   Array.unsafeCreate shape $ \resultPtr -> do
      let n = Shape.size sh
          realOrder = uploOrder uplo order
      evalContT $ do
         sidePtr <- Call.char 'L'
         uploPtr <- Call.char $ uploFromOrder realOrder
         transPtr <- Call.char 'N'
         diagPtr <- Call.char 'N'
         -- the same cell serves as dimension and as leading dimension
         nPtr <- Call.cint n
         aPtr <- unpackToTemp (unpack realOrder) n a
         bPtr <- unpackToTemp (unpackZero realOrder) n a
         alphaPtr <- Call.number one
         liftIO $ do
            BlasGen.trmm sidePtr uploPtr transPtr diagPtr
               nPtr nPtr alphaPtr aPtr nPtr bPtr nPtr
            pack realOrder n bPtr resultPtr

-- | Product of two triangular matrices with the same @uplo@ tag.
-- The result takes the shape of the second operand. Both operands are
-- unpacked into temporary full matrices and multiplied by BLAS @trmm@;
-- side and transposition are chosen according to the storage order of
-- the second operand.
multiply ::
   (MatrixShape.Uplo uplo, Shape.C sh, Eq sh, Class.Floating a) =>
   Triangular uplo sh a -> Triangular uplo sh a -> Triangular uplo sh a
multiply
      (Array (MatrixShape.Triangular uploA orderA shA) a)
      (Array shapeB@(MatrixShape.Triangular uploB orderB shB) b) =
   Array.unsafeCreate shapeB $ \resultPtr -> do
      Call.assert "Triangular.multiply: width shapes mismatch" (shA == shB)
      let n = Shape.size shA
          realOrderA = uploOrder uploA orderA
          realOrderB = uploOrder uploB orderB
          (side, trans) =
             case orderB of
                ColumnMajor -> ('L', orderA)
                RowMajor -> ('R', flipOrder orderA)
      evalContT $ do
         sidePtr <- Call.char side
         uploPtr <- Call.char $ uploFromOrder realOrderA
         transPtr <- Call.char $ transposeFromOrder trans
         diagPtr <- Call.char 'N'
         -- the same cell serves as dimension and as leading dimension
         nPtr <- Call.cint n
         aPtr <- unpackToTemp (unpack realOrderA) n a
         bPtr <- unpackToTemp (unpackZero realOrderB) n b
         alphaPtr <- Call.number one
         liftIO $ do
            BlasGen.trmm sidePtr uploPtr transPtr diagPtr
               nPtr nPtr alphaPtr aPtr nPtr bPtr nPtr
            pack realOrderB n bPtr resultPtr


-- | Product of a square matrix (first argument) and a triangular matrix
-- (second argument). Runs 'multiplyAux' with both storage orders flipped
-- and with the @uplo@ tag selected through 'MatrixShape.caseUplo'
-- (presumably the transposed formulation of the product; confirm against
-- the definition of @caseUplo@).
multiplySquareLeft ::
   (MatrixShape.Uplo uplo, Shape.C sh, Eq sh, Class.Floating a) =>
   Square sh a -> Triangular uplo sh a -> Square sh a
multiplySquareLeft
      (Array shapeB@(MatrixShape.Square orderB shB) b)
      (Array (MatrixShape.Triangular uploA orderA shA) a) =
   Array.unsafeCreate shapeB $ \cPtr -> do
      Call.assert "Triangular.multiplySquareLeft: shapes mismatch"
         (shA == shB)
      let n = Shape.size shB
      MatrixShape.caseUplo uploA
         (multiplyAux MatrixShape.Upper)
         (multiplyAux MatrixShape.Lower)
         (flipOrder orderA) n a
         (flipOrder orderB) n b
         cPtr

-- | Product of a general matrix (first argument) and a triangular matrix
-- (second argument) whose size must equal the width of the general
-- matrix. Implemented like 'multiplySquareLeft': flipped orders, @uplo@
-- tag selected through 'MatrixShape.caseUplo', and width and height
-- passed to 'multiplyAux' in swapped roles.
multiplyGeneralLeft ::
   (MatrixShape.Uplo uplo, Shape.C height, Shape.C width, Eq width,
    Class.Floating a) =>
   General height width a -> Triangular uplo width a ->
   General height width a
multiplyGeneralLeft
      (Array shapeB@(MatrixShape.General orderB height width) b)
      (Array (MatrixShape.Triangular uploA orderA shA) a) =
   Array.unsafeCreate shapeB $ \cPtr -> do
      Call.assert "Triangular.multiplyGeneralLeft: shapes mismatch"
         (shA == width)
      MatrixShape.caseUplo uploA
         (multiplyAux MatrixShape.Upper)
         (multiplyAux MatrixShape.Lower)
         (flipOrder orderA) (Shape.size width) a
         (flipOrder orderB) (Shape.size height) b
         cPtr

-- | Product of a triangular matrix (first argument) and a square matrix
-- (second argument) of the same size.
multiplySquareRight ::
   (MatrixShape.Uplo uplo, Shape.C sh, Eq sh, Class.Floating a) =>
   Triangular uplo sh a -> Square sh a -> Square sh a
multiplySquareRight
      (Array (MatrixShape.Triangular uploA orderA shA) a)
      (Array shapeB@(MatrixShape.Square orderB shB) b) =
   Array.unsafeCreate shapeB $ \cPtr -> do
      Call.assert "Triangular.multiplySquareRight: shapes mismatch"
         (shA == shB)
      let n = Shape.size shB
      multiplyAux uploA orderA n a orderB n b cPtr

-- | Product of a triangular matrix (first argument) and a general matrix
-- (second argument) whose height must equal the size of the triangular
-- matrix.
multiplyGeneralRight ::
   (MatrixShape.Uplo uplo, Shape.C height, Eq height, Shape.C width,
    Class.Floating a) =>
   Triangular uplo height a -> General height width a ->
   General height width a
multiplyGeneralRight
      (Array (MatrixShape.Triangular uploA orderA shA) a)
      (Array shapeB@(MatrixShape.General orderB height width) b) =
   Array.unsafeCreate shapeB $ \cPtr -> do
      Call.assert "Triangular.multiplyGeneralRight: shapes mismatch"
         (shA == height)
      multiplyAux uploA orderA (Shape.size height) a
         orderB (Shape.size width) b cPtr

-- | Common worker of the matrix products above.
--
-- Arguments: @uplo@ tag and storage order of the packed triangular
-- operand, its size @m0@ and buffer; storage order of the full operand,
-- its other dimension @n0@ and buffer; and the destination buffer.
--
-- The triangular operand is unpacked into a temporary @m0 x m0@ matrix,
-- the @m0*n0@ elements of the full operand are copied to the destination
-- and BLAS @trmm@ multiplies in place there. For a row-major full operand
-- the multiplication is issued from the right with the two dimensions
-- exchanged and the transposition of the triangular operand flipped.
multiplyAux ::
   (MatrixShape.Uplo uplo, Class.Floating a) =>
   uplo -> Order -> Int -> ForeignPtr a ->
   Order -> Int -> ForeignPtr a -> Ptr a -> IO ()
multiplyAux uploA orderA m0 a orderB n0 b cPtr =
   let realOrderA = uploOrder uploA orderA
       (side, trans, (m, n)) =
          case orderB of
             ColumnMajor -> ('L', orderA, (m0, n0))
             RowMajor -> ('R', flipOrder orderA, (n0, m0))
   in  evalContT $ do
          sidePtr <- Call.char side
          uploPtr <- Call.char $ uploFromOrder realOrderA
          transPtr <- Call.char $ transposeFromOrder trans
          diagPtr <- Call.char 'N'
          mPtr <- Call.cint m
          nPtr <- Call.cint n
          alphaPtr <- Call.number one
          aPtr <- unpackToTemp (unpack realOrderA) m0 a
          ldaPtr <- Call.cint m0
          bPtr <- ContT $ withForeignPtr b
          ldbPtr <- Call.cint m
          liftIO $ do
             -- trmm overwrites its second operand, thus work on a copy
             copyBlock (m0*n0) bPtr cPtr
             BlasGen.trmm sidePtr uploPtr transPtr diagPtr
                mPtr nPtr alphaPtr aPtr ldaPtr cPtr ldbPtr


-- | Pick the @uplo@ tag demanded by the result type via
-- 'MatrixShape.autoUplo' and return it together with the order that
-- 'uploOrder' computes from it and the given order.
autoUploOrder :: MatrixShape.Uplo uplo => Order -> (Order, uplo)
autoUploOrder order =
   let uplo = MatrixShape.autoUplo
   in  (uploOrder uplo order, uplo)