module Numeric.LinearAlgebra.Banded (
UpperMatrix(..),
SymmetricMatrix(..),
choleskyDecompose,
CholeskySolve, choleskySolve,
) where
import qualified Data.Packed.Development as Dev
import qualified Data.Packed.Matrix as Matrix
import Data.Packed.Matrix (Matrix)
import Data.Packed.Vector (Vector)
import qualified Foreign.Marshal.Utils as Marshal
import qualified Foreign.Marshal.Array as Array
import qualified Foreign.Marshal.Alloc as Alloc
import qualified Foreign.C.String as CStr
import qualified Foreign.C.Types as C
import Foreign.Storable (Storable, peek)
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr)
import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Monad (when)
import Control.Applicative (pure, (<*>))
import Text.Printf (printf)
import System.IO.Unsafe (unsafePerformIO)
-- | Raw binding to the single precision LAPACK band Cholesky routine
-- @spbtrf@ (reached through the 'C' class from 'choleskyDecompose').
-- Fortran calling convention: every argument, scalars included, is passed
-- by pointer.  The band array is overwritten in place.
foreign import ccall safe "spbtrf_"
   spbtrf
      :: Ptr C.CChar   -- UPLO
      -> Ptr C.CInt    -- N
      -> Ptr C.CInt    -- KD
      -> Ptr Float     -- AB (overwritten with the factor)
      -> Ptr C.CInt    -- LDAB
      -> Ptr C.CInt    -- INFO
      -> IO ()
-- | Raw binding to the double precision LAPACK band Cholesky routine
-- @dpbtrf@ (reached through the 'C' class from 'choleskyDecompose').
-- Fortran calling convention: every argument, scalars included, is passed
-- by pointer.  The band array is overwritten in place.
foreign import ccall safe "dpbtrf_"
   dpbtrf
      :: Ptr C.CChar   -- UPLO
      -> Ptr C.CInt    -- N
      -> Ptr C.CInt    -- KD
      -> Ptr Double    -- AB (overwritten with the factor)
      -> Ptr C.CInt    -- LDAB
      -> Ptr C.CInt    -- INFO
      -> IO ()
-- | Raw binding to the single precision LAPACK band Cholesky solver
-- @spbtrs@ (reached through the 'C' class from 'choleskySolveMany').
-- All arguments are passed by pointer; the right-hand-side array is
-- overwritten with the solution.
foreign import ccall safe "spbtrs_"
   spbtrs
      :: Ptr C.CChar   -- UPLO
      -> Ptr C.CInt    -- N
      -> Ptr C.CInt    -- KD
      -> Ptr C.CInt    -- NRHS
      -> Ptr Float     -- AB (factor from spbtrf)
      -> Ptr C.CInt    -- LDAB
      -> Ptr Float     -- B (overwritten with the solution)
      -> Ptr C.CInt    -- LDB
      -> Ptr C.CInt    -- INFO
      -> IO ()
-- | Raw binding to the double precision LAPACK band Cholesky solver
-- @dpbtrs@ (reached through the 'C' class from 'choleskySolveMany').
-- All arguments are passed by pointer; the right-hand-side array is
-- overwritten with the solution.
foreign import ccall safe "dpbtrs_"
   dpbtrs
      :: Ptr C.CChar   -- UPLO
      -> Ptr C.CInt    -- N
      -> Ptr C.CInt    -- KD
      -> Ptr C.CInt    -- NRHS
      -> Ptr Double    -- AB (factor from dpbtrf)
      -> Ptr C.CInt    -- LDAB
      -> Ptr Double    -- B (overwritten with the solution)
      -> Ptr C.CInt    -- LDB
      -> Ptr C.CInt    -- INFO
      -> IO ()
-- | Element types for which the banded LAPACK routines are bound.
-- Each instance picks the precision-specific foreign function, so the
-- high-level functions can stay polymorphic in the element type.
class (Matrix.Element a) => C a where
   -- | Band Cholesky factorization (@?pbtrf@); argument order follows
   -- the Fortran routine: UPLO, N, KD, AB, LDAB, INFO.
   pbtrf
      :: Ptr C.CChar -> Ptr C.CInt -> Ptr C.CInt
      -> Ptr a -> Ptr C.CInt
      -> Ptr C.CInt
      -> IO ()

   -- | Band Cholesky solve (@?pbtrs@); argument order follows the
   -- Fortran routine: UPLO, N, KD, NRHS, AB, LDAB, B, LDB, INFO.
   pbtrs
      :: Ptr C.CChar -> Ptr C.CInt
      -> Ptr C.CInt -> Ptr C.CInt
      -> Ptr a -> Ptr C.CInt
      -> Ptr a -> Ptr C.CInt
      -> Ptr C.CInt
      -> IO ()
-- | Single precision: dispatch to the @s@-prefixed LAPACK routines.
instance C Float where
   pbtrs = spbtrs
   pbtrf = spbtrf
-- | Double precision: dispatch to the @d@-prefixed LAPACK routines.
instance C Double where
   pbtrs = dpbtrs
   pbtrf = dpbtrf
-- | Cholesky factor produced by 'choleskyDecompose', kept in the same
-- compact band layout as its input: one row per unknown and as many
-- columns as the input had.
-- NOTE(review): presumably row @i@ holds @U(i,i), U(i,i+1), ...@ of the
-- upper factor; confirm against the LAPACK band storage description.
newtype UpperMatrix a = UpperMatrix (Matrix a)
   deriving Show
-- | Symmetric band matrix in compact form, as consumed by
-- 'choleskyDecompose': the row count is the order of the system and the
-- column count is (presumably) the number of off-diagonals plus one.
-- NOTE(review): row @i@ is assumed to hold @A(i,i), A(i,i+1), ...@;
-- confirm with callers.
newtype SymmetricMatrix a = SymmetricMatrix (Matrix a)
   deriving Show
-- | Cholesky decomposition of a symmetric band matrix given in compact
-- band form (one row per unknown, @k+1@ columns for @k@ off-diagonals).
-- The result uses the same compact layout.
--
-- Failures reported by LAPACK through INFO, for example a matrix that is
-- not positive definite, are passed to @Dev.check@ via 'runChecked'
-- (presumably raising an error tagged with the function name).
--
-- 'unsafePerformIO' is acceptable here because the input is copied before
-- LAPACK touches it. The only mutated buffer is private and becomes the
-- result, so the function is observably pure.
choleskyDecompose :: (C a) => SymmetricMatrix a -> UpperMatrix a
choleskyDecompose (SymmetricMatrix a) = unsafePerformIO $ do
   -- pbtrf overwrites its band argument, and 'Matrix.flatten' may share
   -- storage with @a@, so work on a fresh copy.
   u <- cloneVector $ Matrix.flatten a
   evalContT $
      runChecked "Banded.choleskyDecompose" $
      pure pbtrf
      -- Fortran reads the row-major buffer column-major, so each Haskell
      -- row is one LAPACK band column. With UPLO = "L" such a column holds
      -- the diagonal followed by sub-diagonal entries, which by symmetry
      -- are the super-diagonal entries of the corresponding row.
      <*> string "L"
      <*> cint (Matrix.rows a)        -- N: order of the system
      <*> cint (Matrix.cols a - 1)    -- KD: number of off-diagonals
      <*> vector u                    -- AB: band data, factored in place
      <*> cint (Matrix.cols a)        -- LDAB = KD + 1
   return $ UpperMatrix $
      Dev.matrixFromVector Dev.RowMajor (Matrix.cols a) u
-- | Containers of right-hand sides that can be solved against a band
-- Cholesky factor: a single 'Vector', or every column of a 'Matrix'.
class CholeskySolve c where
   -- | @choleskySolve u b@ solves the linear system whose Cholesky factor
   -- is @u@ (as returned by 'choleskyDecompose') for the right-hand
   -- side(s) @b@.
   choleskySolve :: (C a) => UpperMatrix a -> c a -> c a
-- | A single right-hand side.
instance CholeskySolve Vector where
   choleskySolve u b = choleskySolveSingle u b
-- | Several right-hand sides, one per column.
instance CholeskySolve Matrix where
   choleskySolve u bs = choleskySolveMany u bs
-- | Solve for a single right-hand side by treating it as a one-column
-- matrix and flattening the one-column solution back into a vector.
choleskySolveSingle :: (C a) => UpperMatrix a -> Vector a -> Vector a
choleskySolveSingle u b =
   Matrix.flatten (choleskySolveMany u (Matrix.asColumn b))
-- | Solve the banded system described by the Cholesky factor @u@ for every
-- column of @rhs@ at once. Calls 'error' if the row counts disagree;
-- LAPACK failures (INFO /= 0) are reported through 'runChecked'.
--
-- 'unsafePerformIO' is acceptable because the right-hand sides are copied
-- into a private buffer before LAPACK overwrites them, and the factor is
-- only read. The result is therefore a pure function of the arguments.
choleskySolveMany :: (C a) => UpperMatrix a -> Matrix a -> Matrix a
choleskySolveMany (UpperMatrix u) rhs = unsafePerformIO $ do
   when (Matrix.rows u /= Matrix.rows rhs) $ error $
      printf "Banded.choleskySolve: number of rows mismatch (%i /= %i)"
         (Matrix.rows u) (Matrix.rows rhs)
   -- Flattening the transpose yields the column-major layout LAPACK
   -- expects for B, with leading dimension = number of rows of rhs.
   x <- cloneVector $ Matrix.flatten $ Matrix.trans rhs
   evalContT $
      runChecked "Banded.choleskySolveMany" $
      pure pbtrs
      <*> string "L"
      <*> cint (Matrix.rows u)        -- N: order of the system
      <*> cint (Matrix.cols u - 1)    -- KD: number of off-diagonals
      <*> cint (Matrix.cols rhs)      -- NRHS
      <*> vector (Matrix.flatten u)   -- AB: only read by pbtrs
      <*> cint (Matrix.cols u)        -- LDAB = KD + 1
      <*> vector x                    -- B: overwritten with the solution
      -- LAPACK demands LDB >= max(1,N); without the clamp an empty system
      -- (N = 0) would be rejected as an illegal argument.
      <*> cint (max 1 (Matrix.rows rhs))
   return $ Dev.matrixFromVector Dev.ColumnMajor (Matrix.cols rhs) x
-- | Make a fresh copy of a vector in newly allocated storage, so that the
-- copy can be handed to LAPACK routines that overwrite their arguments.
cloneVector :: (Storable a) => Vector a -> IO (Vector a)
cloneVector src = evalContT $ do
   (srcPtr, n) <- vectorLen src
   dst <- liftIO (Dev.createVector n)
   dstPtr <- vector dst
   liftIO (Array.copyArray dstPtr srcPtr n)
   pure dst
-- | Continuation-based IO used to marshal arguments for Fortran calls.
-- Each marshalling helper ('cint', 'string', 'vector', 'alloca') wraps a
-- bracket-style @with...@ function, so a resource stays valid for the rest
-- of the computation and is released when 'evalContT' finishes.
type FortranIO r = ContT r IO
-- | Execute an IO action that was assembled inside the marshalling
-- context, while all marshalled arguments are still alive.
run :: FortranIO r (IO a) -> FortranIO r a
run getAction = do
   action <- getAction
   liftIO action
-- | Supply a freshly allocated INFO slot as the final argument of a LAPACK
-- call, run it, and hand the resulting INFO value to @Dev.check@ together
-- with @name@. NOTE(review): @Dev.check@ presumably raises an error when
-- INFO is non-zero; it is defined in hmatrix, not here.
runChecked :: String -> FortranIO r (Ptr C.CInt -> IO a) -> FortranIO r a
runChecked name act = do
   infoPtr <- alloca
   result <- run (fmap (\call -> call infoPtr) act)
   liftIO (Dev.check name (peek infoPtr))
   pure result
-- | Pass an 'Int' to Fortran by reference. The value is converted with
-- 'fromIntegral' (out-of-range values wrap) and stored in a temporary
-- that lives for the rest of the marshalling context.
cint :: Int -> FortranIO r (Ptr C.CInt)
cint n = ContT (Marshal.with (fromIntegral n))
-- | Allocate uninitialised storage for one value, valid for the rest of
-- the marshalling context (used for output arguments such as INFO).
alloca :: (Storable a) => FortranIO r (Ptr a)
alloca = ContT $ \k -> Alloc.alloca k
-- | Pass a 'String' to Fortran as a NUL-terminated C string. Used for the
-- single-character option arguments such as UPLO.
string :: String -> FortranIO r (Ptr C.CChar)
string s = ContT (CStr.withCString s)
-- | Expose a vector's storage as a raw pointer (already advanced by the
-- vector's offset) together with its length. The underlying foreign
-- pointer is kept alive for the rest of the marshalling context.
vectorLen :: (Storable a) => Vector a -> FortranIO r (Ptr a, Int)
vectorLen v = ContT withBase
   where
      (fptr, offset, len) = Dev.unsafeToForeignPtr v
      withBase k =
         withForeignPtr fptr $ \base ->
            k (Array.advancePtr base offset, len)
-- | Like 'vectorLen', but only the data pointer is needed.
vector :: (Storable a) => Vector a -> FortranIO r (Ptr a)
vector v = fst <$> vectorLen v