{-# LANGUAGE TypeFamilies #-}
module Numeric.LAPACK.Linear.Triangular (
   solve,
   inverse,
   ) where

import Numeric.LAPACK.Matrix.Triangular (Triangular)
import Numeric.LAPACK.Matrix (General)
import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import Numeric.LAPACK.Matrix.Shape.Private
         (Order(ColumnMajor), transposeFromOrder, uploFromOrder, uploOrder,
          triangleSize)
import Numeric.LAPACK.Private (copyBlock, copyToTemp, copyToColumnMajor)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Internal as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Internal (Array(Array))

import Foreign.Marshal.Alloc (alloca)
import Foreign.C.Types (CInt)
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (peek)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Applicative ((<$>))

import Text.Printf (printf)


-- | Solve @A·X = B@ for @X@, where @A@ is a packed triangular matrix,
-- using LAPACK @?tptrs@.
--
-- * The diagonal of @A@ is not assumed to be unit (LAPACK @DIAG = 'N'@).
-- * The result is always stored in column-major order, whatever the
--   storage order of @B@.
-- * Fails via 'Call.assert' if the height of @B@ differs from the size
--   of @A@.
-- * Fails via 'withInfo' if LAPACK reports a problem, for example a zero
--   on the diagonal of @A@.
-- * An empty system (size 0) yields an empty result rather than a LAPACK
--   argument error.
solve ::
   (MatrixShape.Uplo uplo, Shape.C sh, Eq sh, Shape.C nrhs,
    Class.Floating a) =>
   Triangular uplo sh a -> General sh nrhs a -> General sh nrhs a
solve
      (Array (MatrixShape.Triangular uplo orderA shA) a)
      (Array (MatrixShape.General orderB heightB widthB) b) =
   Array.unsafeCreate (MatrixShape.General ColumnMajor heightB widthB) $
      \xPtr -> do
         Call.assert "Triangular.solve: height shapes mismatch"
            (shA == heightB)
         let n = Shape.size heightB
         let nrhs = Shape.size widthB
         -- LAPACK requires LDB >= max(1,N); plain @n@ would be rejected
         -- (INFO = -8) for an empty system.  For n > 0 this is just @n@.
         let ldb = max 1 n
         evalContT $ do
            -- NOTE(review): uploOrder / transposeFromOrder presumably map a
            -- row-major packed triangle to the transposed column-major
            -- problem LAPACK expects; confirm in Matrix.Shape.Private.
            uploPtr <- Call.char $ uploFromOrder $ uploOrder uplo orderA
            transPtr <- Call.char $ transposeFromOrder orderA
            -- 'N': the diagonal of A is used as stored, not assumed unit.
            diagPtr <- Call.char 'N'
            nPtr <- Call.cint n
            nrhsPtr <- Call.cint nrhs
            apPtr <- copyToTemp (triangleSize n) a
            bPtr <- ContT $ withForeignPtr b
            ldbPtr <- Call.cint ldb
            liftIO $ do
               -- ?tptrs overwrites its B argument with the solution X,
               -- so the right-hand side is first copied into the result.
               copyToColumnMajor orderB n nrhs bPtr xPtr
               withInfo "tptrs" $
                  LapackGen.tptrs uploPtr transPtr diagPtr
                     nPtr nrhsPtr apPtr xPtr ldbPtr


-- | Invert a packed triangular matrix using LAPACK @?tptri@.
--
-- * The diagonal is not assumed to be unit (LAPACK @DIAG = 'N'@).
-- * The result has the same shape, including storage order, as the input.
-- * Fails via 'withInfo' if LAPACK reports a problem, for example a zero
--   diagonal element, which means the matrix is singular.
inverse ::
   (MatrixShape.Uplo uplo, Shape.C sh, Class.Floating a) =>
   Triangular uplo sh a -> Triangular uplo sh a
inverse (Array shape@(MatrixShape.Triangular uplo order sh) a) =
   Array.unsafeCreateWithSize shape $ \triSize bPtr ->
      evalContT $ do
         uploPtr <- Call.char $ uploFromOrder $ uploOrder uplo order
         diagPtr <- Call.char 'N'
         nPtr <- Call.cint $ Shape.size sh
         aPtr <- ContT $ withForeignPtr a
         liftIO $ do
            -- ?tptri inverts in place, so work on a copy in the result buffer.
            copyBlock triSize aPtr bPtr
            withInfo "tptri" $ LapackGen.tptri uploPtr diagPtr nPtr bPtr


-- | Run a LAPACK routine that reports its status through an @INFO@ output
-- argument, and turn a non-zero status into an 'error' tagged with @name@.
--
-- * @INFO < 0@: the @(-INFO)@-th argument had an illegal value.
-- * @INFO > 0@: the @INFO@-th diagonal element is zero.  This is the
--   meaning used by @?tptrs@ and @?tptri@, the callers visible here.
withInfo :: String -> (Ptr CInt -> IO ()) -> IO ()
withInfo name computation = alloca $ \infoPtr -> do
   computation infoPtr
   info <- fromIntegral <$> peek infoPtr
   case compare info (0::Int) of
      EQ -> return ()
      LT -> error $ printf "%s: illegal value in %d-th argument" name (-info)
      GT -> error $ printf "%s: %d-th diagonal element zero" name info