module Numeric.LAPACK.Matrix.Triangular.Private where

import Numeric.LAPACK.Matrix.Shape.Private
         (Order(RowMajor,ColumnMajor), flipOrder, uploFromOrder, triangleSize)
import Numeric.LAPACK.Private (pointerSeq, copyToTemp, lacgv, fill, zero)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import Foreign.Marshal.Alloc (alloca)
import Foreign.Marshal.Array (advancePtr)
import Foreign.C.Types (CInt)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable, peek, poke)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)

import Data.Foldable (forM_)


-- | Pair the first @n@ elements of the vector at @xPtr@ with @n@ pointers
-- into the packed storage at @aPtr@.
--
-- The packed pointers start at @aPtr@ and advance by a running sum of
-- steps:
--
-- * 'RowMajor' uses steps @n, n-1, n-2, …@.
-- * 'ColumnMajor' uses steps @2, 3, 4, …@.
--
-- These are the offsets of the diagonal entries of an @n×n@ packed
-- triangle in the respective layout.
-- NOTE(review): assumes @pointerSeq k p@ yields @p, p+k, p+2k, …@.
diagonalPointers ::
   (Storable a, Storable ar) =>
   Order -> Int -> Ptr ar -> Ptr a -> [(Ptr ar, Ptr a)]
diagonalPointers order n xPtr aPtr =
   take n $ zip (pointerSeq 1 xPtr) $ scanl advancePtr aPtr $
   case order of
      RowMajor -> iterate pred n
      ColumnMajor -> iterate succ 2

-- | For @j = 0 … n-1@, yield @(j+1, ((fullPtr+j, fullPtr+j*n), packedJ))@.
--
-- @packedJ@ starts at @packedPtr@ and advances by @1, 2, 3, …@, i.e. it is
-- the start of the @j@-th run of length @j+1@ in the packed storage.
columnMajorPointers ::
   (Storable a) =>
   Int -> Ptr a -> Ptr a -> [(Int, ((Ptr a, Ptr a), Ptr a))]
columnMajorPointers n fullPtr packedPtr =
   let ds = iterate succ 1
   in  take n $ zip ds $
       zip
          (zip (pointerSeq 1 fullPtr) (pointerSeq n fullPtr))
          (scanl advancePtr packedPtr ds)

-- | For @j = 0 … n-1@, yield @(n-j, (fullPtr + j*(n+1), packedJ))@.
--
-- @fullPtr + j*(n+1)@ is the @j@-th diagonal entry of the @n×n@ full
-- matrix. @packedJ@ starts at @packedPtr@ and advances by @n, n-1, …@, i.e.
-- it is the start of the @j@-th run of length @n-j@ in the packed storage.
rowMajorPointers ::
   (Storable a) =>
   Int -> Ptr a -> Ptr a -> [(Int, (Ptr a, Ptr a))]
rowMajorPointers n fullPtr packedPtr =
   let ds = iterate pred n
   in  take n $ zip ds $
       zip (pointerSeq (n+1) fullPtr) (scanl advancePtr packedPtr ds)

-- | Run @act@ once per list element.
--
-- A single 'CInt' cell is allocated for the whole loop. Before each call
-- the element's count @d@ is written into it, so @act@ receives a
-- pointer-to-length suitable for passing to an FFI routine.
forPointers :: [(Int, a)] -> (Ptr CInt -> a -> IO ()) -> IO ()
forPointers xs act =
   alloca $ \nPtr ->
   forM_ xs $ \(d,ptrs) -> do
      poke nPtr $ fromIntegral d
      act nPtr ptrs

-- | Copy the @triangleSize n@ packed elements of @a@ into temporary
-- storage that is valid for the rest of the 'ContT' scope.
--
-- For 'RowMajor' the copy is additionally passed through 'lacgv'. That is
-- LAPACK's vector conjugation, presumably a no-op for real element types.
-- The size and increment cells are allocated in an inner 'evalContT', so
-- they are released before the temporary copy is returned.
copyTriangleToTemp ::
   Class.Floating a => Order -> Int -> ForeignPtr a -> ContT r IO (Ptr a)
copyTriangleToTemp order n a = do
   let aSize = triangleSize n
   apPtr <- copyToTemp aSize a
   liftIO $ evalContT $ do
      aSizePtr <- Call.cint aSize
      incPtr <- Call.cint 1
      case order of
         RowMajor -> liftIO $ lacgv aSizePtr apPtr incPtr
         ColumnMajor -> return ()
   return apPtr

-- | Allocate an @n*n@ array and fill it from the packed data in @a@.
--
-- The fill is done by calling @f n packedPtr fullPtr@. The returned
-- pointer is valid for the rest of the 'ContT' scope.
unpackToTemp ::
   Storable a =>
   (Int -> Ptr a -> Ptr a -> IO ()) -> Int -> ForeignPtr a -> ContT r IO (Ptr a)
unpackToTemp f n a = do
   apPtr <- ContT $ withForeignPtr a
   aPtr <- Call.allocaArray (n*n)
   liftIO $ f n apPtr aPtr
   return aPtr

-- | Copy a packed triangle into an @n×n@ full matrix with leading
-- dimension @n@, using LAPACK @?tpttr@.
--
-- The triangle is chosen by @uploFromOrder order@. Entries outside that
-- triangle are not written.
unpack :: Class.Floating a => Order -> Int -> Ptr a -> Ptr a -> IO ()
unpack order n packedPtr fullPtr = evalContT $ do
   uploPtr <- Call.char $ uploFromOrder order
   nPtr <- Call.cint n
   ldaPtr <- Call.cint n
   liftIO $ withInfoNamed "tpttr" $
      LapackGen.tpttr uploPtr nPtr packedPtr fullPtr ldaPtr

-- | Copy one triangle of an @n×n@ full matrix (leading dimension @n@) into
-- packed storage, using LAPACK @?trttp@.
--
-- The triangle is chosen by @uploFromOrder order@.
pack :: Class.Floating a => Order -> Int -> Ptr a -> Ptr a -> IO ()
pack order n fullPtr packedPtr = evalContT $ do
   uploPtr <- Call.char $ uploFromOrder order
   nPtr <- Call.cint n
   ldaPtr <- Call.cint n
   liftIO $ withInfoNamed "trttp" $
      LapackGen.trttp uploPtr nPtr fullPtr ldaPtr packedPtr

-- | Like 'unpack', but the rest of the full matrix is set to zero.
--
-- 'unpackZero' zeroes only the opposite triangle, including the diagonal.
-- The subsequent 'unpack' then overwrites the diagonal.
-- '_unpackZero' is the simpler variant that zeroes all @n*n@ entries first.
unpackZero, _unpackZero ::
   Class.Floating a => Order -> Int -> Ptr a -> Ptr a -> IO ()
_unpackZero order n packedPtr fullPtr = do
   fill zero (n*n) fullPtr
   unpack order n packedPtr fullPtr

unpackZero order n packedPtr fullPtr = do
   fillTriangle zero (flipOrder order) n fullPtr
   unpack order n packedPtr fullPtr

-- | Set one triangle of the @n×n@ matrix at @aPtr@ (leading dimension @n@),
-- including its diagonal, to @z@.
--
-- The triangle is chosen by @uploFromOrder order@. This uses LAPACK
-- @?laset@ with @alpha = beta = z@.
fillTriangle :: Class.Floating a => a -> Order -> Int -> Ptr a -> IO ()
fillTriangle z order n aPtr = evalContT $ do
   uploPtr <- Call.char $ uploFromOrder order
   nPtr <- Call.cint n
   zPtr <- Call.number z
   liftIO $ LapackGen.laset uploPtr nPtr nPtr zPtr zPtr aPtr nPtr

-- | Run a LAPACK routine that reports its status through a trailing
-- @INFO@ out-parameter, and raise an 'IOError' if @INFO@ is non-zero.
--
-- Negative @INFO = -i@ means argument @i@ had an illegal value. Positive
-- values are routine-specific failures. @INFO@ is zero-initialised, so a
-- computation that never writes it is treated as success.
withInfoNamed :: String -> (Ptr CInt -> IO ()) -> IO ()
withInfoNamed name computation =
   alloca $ \infoPtr -> do
      poke infoPtr 0
      computation infoPtr
      info <- peek infoPtr
      case compare info 0 of
         EQ -> return ()
         LT ->
            ioError $ userError $
               "Numeric.LAPACK.Matrix.Triangular.Private: " ++ name ++
               ": illegal value in argument " ++ show (negate info)
         GT ->
            ioError $ userError $
               "Numeric.LAPACK.Matrix.Triangular.Private: " ++ name ++
               ": failed with INFO = " ++ show info

-- | 'withInfoNamed' with a generic routine name in the error message.
--
-- Previously this only allocated the @INFO@ cell and silently ignored the
-- status that LAPACK reported.
withInfo :: (Ptr CInt -> IO ()) -> IO ()
withInfo = withInfoNamed "LAPACK routine"