{-# LANGUAGE ForeignFunctionInterface #-}
{- |
Cholesky factorisation of symmetric positive definite band matrices and
solving the corresponding linear systems, via the LAPACK routines
@?PBTRF@ and @?PBTRS@.
-}
module Numeric.LinearAlgebra.Banded (
   UpperMatrix(..),
   SymmetricMatrix(..),
   choleskyDecompose,
   CholeskySolve,
   choleskySolve,
   ) where

import qualified Data.Packed.Development as Dev
import qualified Data.Packed.Matrix as Matrix
import Data.Packed.Matrix (Matrix)
import Data.Packed.Vector (Vector)

import qualified Foreign.Marshal.Utils as Marshal
import qualified Foreign.Marshal.Array as Array
import qualified Foreign.Marshal.Alloc as Alloc
import qualified Foreign.C.String as CStr
import qualified Foreign.C.Types as C
import Foreign.Storable (Storable, peek)
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr)

import Control.Monad.Trans.Cont (ContT(ContT), runContT)
import Control.Monad.IO.Class (liftIO)
import Control.Monad (when)
import Control.Applicative (pure, (<*>))

import Text.Printf (printf)
import System.IO.Unsafe (unsafePerformIO)


-- LAPACK ?PBTRF: Cholesky factorisation of a symmetric positive definite
-- band matrix.  Arguments (all by reference): UPLO, N, KD, AB, LDAB, INFO.
foreign import ccall "spbtrf_"
   spbtrf ::
      Ptr C.CChar -> Ptr C.CInt -> Ptr C.CInt ->
      Ptr Float -> Ptr C.CInt -> Ptr C.CInt -> IO ()

foreign import ccall "dpbtrf_"
   dpbtrf ::
      Ptr C.CChar -> Ptr C.CInt -> Ptr C.CInt ->
      Ptr Double -> Ptr C.CInt -> Ptr C.CInt -> IO ()

-- LAPACK ?PBTRS: solve A*X = B using the factor computed by ?PBTRF.
-- Arguments (all by reference): UPLO, N, KD, NRHS, AB, LDAB, B, LDB, INFO.
foreign import ccall "spbtrs_"
   spbtrs ::
      Ptr C.CChar -> Ptr C.CInt -> Ptr C.CInt -> Ptr C.CInt ->
      Ptr Float -> Ptr C.CInt -> Ptr Float -> Ptr C.CInt -> Ptr C.CInt ->
      IO ()

foreign import ccall "dpbtrs_"
   dpbtrs ::
      Ptr C.CChar -> Ptr C.CInt -> Ptr C.CInt -> Ptr C.CInt ->
      Ptr Double -> Ptr C.CInt -> Ptr Double -> Ptr C.CInt -> Ptr C.CInt ->
      IO ()


-- | Element types for which the LAPACK band Cholesky routines are bound.
class (Matrix.Element a) => C a where
   pbtrf ::
      Ptr C.CChar -> Ptr C.CInt -> Ptr C.CInt ->
      Ptr a -> Ptr C.CInt -> Ptr C.CInt -> IO ()
   pbtrs ::
      Ptr C.CChar -> Ptr C.CInt -> Ptr C.CInt -> Ptr C.CInt ->
      Ptr a -> Ptr C.CInt -> Ptr a -> Ptr C.CInt -> Ptr C.CInt -> IO ()

instance C Float where
   pbtrf = spbtrf
   pbtrs = spbtrs

instance C Double where
   pbtrf = dpbtrf
   pbtrs = dpbtrs


{- |
Stores an upper triangular band matrix in the form:

> a11 a12 a13
> a22 a23 a24
> a33 a34 a35
> a44 a45 a46
> a55 a56 a57

Row @i@ holds the diagonal element @a_ii@ followed by the entries to its
right, so the number of columns is the bandwidth plus one.
In the last rows, positions that lie outside the matrix
(here @a46@, @a56@, @a57@ for a 5x5 matrix) are padding;
LAPACK does not reference them.
-}
newtype UpperMatrix a = UpperMatrix (Matrix a)
   deriving (Show)

{- |
Stores the upper half of a symmetric band matrix
in the same layout as 'UpperMatrix'.
-}
newtype SymmetricMatrix a = SymmetricMatrix (Matrix a)
   deriving (Show)


{-
Example GHCi session:

import Numeric.Container ((<>))
let m = Matrix.fromLists [[1,2,0,0],[0,3,4,0],[0,0,5,6],[0,0,0,7 :: Double]]
let mm = Matrix.trans m <> m
-- 'a' is the band storage of 'mm' (last entry is padding)
let a = SymmetricMatrix $ Matrix.fromLists [[1,2],[13,12],[41,30],[85,0::Double]]
-- expected: u is the band storage of 'm', i.e. [[1,2],[3,4],[5,6],[7,0]]
let u = choleskyDecompose a
-}

{- |
Cholesky factorisation @A = U^T * U@ of a symmetric positive definite band
matrix, returning the upper triangular band factor @U@.

Calls 'error' if LAPACK reports an illegal argument or if the matrix is
not positive definite.
-}
choleskyDecompose :: (C a) => SymmetricMatrix a -> UpperMatrix a
choleskyDecompose (SymmetricMatrix a) =
   {- unsafePerformIO is justified because LAPACK overwrites only 'u',
      a private copy made by 'cloneVector'.  The result therefore depends
      only on 'a'. -}
   unsafePerformIO $ do
      u <- cloneVector $ Matrix.flatten a
      {- Our row-major n x (kd+1) array, read column-major by Fortran, is a
         (kd+1) x n array whose column j holds A(j,j), A(j,j+1), ...
         By symmetry these equal A(j,j), A(j+1,j), ..., which is LAPACK's
         lower ("L") band layout.  The lower factor L written back in that
         layout reads row-major as U = L^T. -}
      flip runContT return $
         runChecked "Banded.choleskyDecompose" $
         pure pbtrf
            <*> string "L"
            <*> cint n
            <*> cint (width - 1)   -- KD: number of off-diagonals
            <*> vector u
            <*> cint width         -- LDAB = KD + 1
      return $ UpperMatrix $ Dev.matrixFromVector Dev.RowMajor width u
   where
      n = Matrix.rows a
      width = Matrix.cols a


-- | Solve @A * X = B@ given the factor @U@ of @A = U^T * U@
-- as computed by 'choleskyDecompose'.
class CholeskySolve c where
   choleskySolve :: (C a) => UpperMatrix a -> c a -> c a

-- | A single right-hand side.
instance CholeskySolve Vector where
   choleskySolve = choleskySolveSingle

-- | Several right-hand sides, one per column.
instance CholeskySolve Matrix where
   choleskySolve = choleskySolveMany

-- | Single right-hand side, delegated to 'choleskySolveMany'
-- as a one-column matrix.
choleskySolveSingle :: (C a) => UpperMatrix a -> Vector a -> Vector a
choleskySolveSingle u =
   Matrix.flatten . choleskySolveMany u . Matrix.asColumn

{-
Example GHCi session (continuing the one above):

let x = Matrix.fromLists [[2,3],[5,7],[11,13],[17,19::Double]]
-- expected: approximately x
choleskySolveMany u (mm <> x)
-}

{- |
Solve for all columns of the right-hand side at once.

Calls 'error' if the row counts of @U@ and the right-hand side differ,
or if LAPACK reports a failure.
An empty right-hand side (0 rows or 0 columns) is returned unchanged,
since the solution has the same, empty, shape.
This also keeps LAPACK's @LDB >= max(1,N)@ requirement from being
violated when @N = 0@.
-}
choleskySolveMany :: (C a) => UpperMatrix a -> Matrix a -> Matrix a
choleskySolveMany (UpperMatrix u) rhs
   | n /= Matrix.rows rhs =
        error $
        printf "Banded.choleskySolve: number of rows mismatch (%i /= %i)"
           n (Matrix.rows rhs)
   | n == 0 || nrhs == 0 = rhs
   | otherwise =
        {- unsafePerformIO is justified because LAPACK overwrites only 'x',
           a private copy of the right-hand side.  AB is input-only for
           ?PBTRS, so 'u' itself can be passed without copying. -}
        unsafePerformIO $ do
           -- Rows of (trans rhs) are the columns of rhs, so 'x' is
           -- rhs in column-major order, as Fortran expects for B.
           x <- cloneVector $ Matrix.flatten $ Matrix.trans rhs
           flip runContT return $
              runChecked "Banded.choleskySolveMany" $
              pure pbtrs
                 <*> string "L"
                 <*> cint n
                 <*> cint (width - 1)   -- KD
                 <*> cint nrhs
                 <*> vector (Matrix.flatten u)
                 <*> cint width         -- LDAB
                 <*> vector x
                 <*> cint n             -- LDB (n >= 1 here)
           return $ Dev.matrixFromVector Dev.ColumnMajor nrhs x
   where
      n = Matrix.rows u
      width = Matrix.cols u
      nrhs = Matrix.cols rhs


-- | Copy a vector into freshly allocated storage, so that a LAPACK routine
-- can overwrite the copy without mutating the (immutable) original.
cloneVector :: (Storable a) => Vector a -> IO (Vector a)
cloneVector v =
   flip runContT return $ do
      (src, len) <- vectorLen v
      w <- liftIO $ Dev.createVector len
      dst <- vector w
      liftIO $ Array.copyArray dst src len
      return w


-- | Continuation-based IO, used to nest the scoped @with*@ / @alloca@
-- marshalling brackets around a single foreign call.
type FortranIO r = ContT r IO

-- | Execute the fully-applied foreign call inside all marshalling scopes.
run :: FortranIO r (IO a) -> FortranIO r a
run act = act >>= liftIO

-- | Supply a freshly allocated @INFO@ output to a LAPACK call, run it,
-- and then check @INFO@ with 'checkInfo', using @name@ as the message prefix.
runChecked :: String -> FortranIO r (Ptr C.CInt -> IO a) -> FortranIO r a
runChecked name act = do
   info <- alloca
   a <- run $ fmap ($ info) act
   liftIO $ checkInfo name =<< peek info
   return a

{- |
Interpret a LAPACK @INFO@ value and call 'error' if it signals failure.
@INFO < 0@ means argument number @-INFO@ had an illegal value.
@INFO > 0@ is only produced by ?PBTRF among the routines bound here.
There it means the leading minor of that order is not positive definite.
-}
checkInfo :: String -> C.CInt -> IO ()
checkInfo name info =
   when (info /= 0) $ error $
      if info < 0
        then name ++ ": LAPACK argument " ++ show (negate info) ++
             " had an illegal value"
        else name ++ ": matrix is not positive definite (leading minor of order " ++
             show info ++ ")"

-- | Marshal an 'Int' as a pointer to a temporary 'C.CInt'.
cint :: Int -> FortranIO r (Ptr C.CInt)
cint = ContT . Marshal.with . fromIntegral

-- | Allocate uninitialised temporary storage for one value.
alloca :: (Storable a) => FortranIO r (Ptr a)
alloca = ContT Alloc.alloca

-- | Marshal a 'String' as a temporary NUL-terminated C string.
string :: String -> FortranIO r (Ptr C.CChar)
string = ContT . CStr.withCString

-- | Pointer to the first element of a vector (offset applied) and its length,
-- valid for the rest of the continuation.
vectorLen :: (Storable a) => Vector a -> FortranIO r (Ptr a, Int)
vectorLen v =
   let (fptr, offset, len) = Dev.unsafeToForeignPtr v
   in  ContT $ \f ->
          withForeignPtr fptr $ \ptr -> f (Array.advancePtr ptr offset, len)

-- | Pointer to the first element of a vector, valid for the rest of the
-- continuation.
vector :: (Storable a) => Vector a -> FortranIO r (Ptr a)
vector = fmap fst . vectorLen