-- | Marshalling helpers for calling Fortran routines on 'CArray's.
--
-- Every argument helper lives in 'FortranIO', a continuation monad over
-- 'IO'. Temporary cells created with 'Marshal.with', 'Alloc.alloca',
-- 'Array.allocaArray' etc. therefore stay valid for the rest of the
-- continuation and are released when it returns.
module Numeric.Netlib.CArray.Utility where

import qualified Data.Array.CArray as CArray
import Data.Array.IOCArray (IOCArray, withIOCArray)
import Data.Array.CArray (CArray, withCArray, Ix)

import qualified Foreign.Marshal.Utils as Marshal
import qualified Foreign.Marshal.Array as Array
import qualified Foreign.Marshal.Alloc as Alloc
import qualified Foreign.C.String as CStr
import qualified Foreign.C.Types as C
import Foreign.Storable.Complex ()
import Foreign.Storable (Storable, peek)
import Foreign.Ptr (Ptr)

import Control.Monad.Trans.Cont (ContT(ContT))
import Control.Monad.IO.Class (liftIO)
import Control.Monad (when, unless)
import Control.Applicative ((<$>))

import Data.Complex (Complex)


-- | Continuation monad over 'IO' that scopes marshalled arguments so they
-- live for the duration of a foreign call.
type FortranIO r = ContT r IO


-- | Execute an 'IO' action that was assembled inside 'FortranIO' and return
-- its result.
run :: FortranIO r (IO a) -> FortranIO r a
run act = act >>= liftIO

-- | Like 'run', for an action that still expects a pointer to a 'C.CInt'
-- status cell. Presumably this is a LAPACK-style INFO argument; confirm per
-- caller.
--
-- The cell is allocated here (not initialised) and passed to the action.
-- After the action has run, a non-zero value in the cell aborts with
-- 'error', using @name@ as the message prefix.
runChecked :: String -> FortranIO r (Ptr C.CInt -> IO a) -> FortranIO r a
runChecked name act = do
   info <- alloca
   a <- run $ fmap ($ info) act
   liftIO $ check name (peek info)
   return a

-- | Run @f@; if the returned code is non-zero, call 'error' with
-- @msg ++ ": " ++ show code@.
check :: String -> IO C.CInt -> IO ()
check msg f = do
   err <- f
   when (err /= 0) $ error $ msg ++ ": " ++ show err

-- | Call 'error' with @"assertion failed: " ++ msg@ unless the condition
-- holds.
assert :: String -> Bool -> IO ()
assert msg success =
   unless success $ error $ "assertion failed: " ++ msg

-- | Accept a named dimension and do nothing with it.
ignore :: String -> Int -> IO ()
ignore _msg _dim = return ()


-- | Create a 'CArray' with the given bounds.
--
-- The initialiser handed to 'CArray.createCArray' is a no-op, so the
-- element contents are not set by this code. Presumably the array is meant
-- as an output buffer that the foreign routine fills.
newArray :: (Ix i, Storable e) => (i, i) -> IO (CArray i e)
newArray bnds = CArray.createCArray bnds (\_ -> return ())

-- | Zero-based one-dimensional array with @m@ elements.
newArray1 :: (Storable e) => Int -> IO (CArray Int e)
newArray1 m = newArray (0, m-1)

-- | Zero-based @m@ by @n@ array.
newArray2 :: (Storable e) => Int -> Int -> IO (CArray (Int,Int) e)
newArray2 m n = newArray ((0,0), (m-1,n-1))

-- | Zero-based @m@ by @n@ by @k@ array.
newArray3 :: (Storable e) => Int -> Int -> Int -> IO (CArray (Int,Int,Int) e)
newArray3 m n k = newArray ((0,0,0), (m-1,n-1,k-1))


-- | Number of indices in a one-dimensional range.
sizes1 :: (Ix i) => (i,i) -> Int
sizes1 = CArray.rangeSize

-- | Extent along each axis of two-dimensional bounds.
sizes2 :: (Ix i, Ix j) => ((i,j),(i,j)) -> (Int,Int)
sizes2 ((i0,j0), (i1,j1)) =
   (CArray.rangeSize (i0,i1), CArray.rangeSize (j0,j1))

-- | Extent along each axis of three-dimensional bounds.
sizes3 :: (Ix i, Ix j, Ix k) => ((i,j,k),(i,j,k)) -> (Int,Int,Int)
sizes3 ((i0,j0,k0), (i1,j1,k1)) =
   (CArray.rangeSize (i0,i1),
    CArray.rangeSize (j0,j1),
    CArray.rangeSize (k0,k1))


-- | Marshal an 'Int' into a temporary 'C.CInt' cell.
--
-- A value that does not fit into 'C.CInt' is rejected with 'error' instead
-- of being silently wrapped by 'fromIntegral'. A wrapped dimension or
-- leading-dimension argument could otherwise make the foreign routine
-- address the wrong memory.
cint :: Int -> FortranIO r (Ptr C.CInt)
cint n = ContT $ \k -> do
   let c = fromIntegral n :: C.CInt
   -- The round trip back to Int differs exactly when the conversion wrapped.
   when (fromIntegral c /= n) $
      error $ "cint: " ++ show n ++ " does not fit into a C int"
   Marshal.with c k

-- | Marshal the number of indices in a range as a 'C.CInt' cell.
range :: (Int,Int) -> FortranIO r (Ptr C.CInt)
range = cint . CArray.rangeSize

-- | Scoped, uninitialised storage for a single 'Storable' value.
alloca :: (Storable a) => FortranIO r (Ptr a)
alloca = ContT Alloc.alloca

-- | Scoped, uninitialised storage for @n@ 'Storable' values.
allocaArray :: (Storable a) => Int -> FortranIO r (Ptr a)
allocaArray = ContT . Array.allocaArray

-- | Marshal a 'Bool' through its 'Storable' instance.
bool :: Bool -> FortranIO r (Ptr Bool)
bool = ContT . Marshal.with

-- | Marshal a single character, converted with 'CStr.castCharToCChar'.
char :: Char -> FortranIO r (Ptr C.CChar)
char = ContT . Marshal.with . CStr.castCharToCChar

-- | Marshal a 'String' as a NUL-terminated C string via 'CStr.withCString'.
string :: String -> FortranIO r (Ptr C.CChar)
string = ContT . CStr.withCString

-- | Marshal a 'Float' into a temporary cell.
float :: Float -> FortranIO r (Ptr Float)
float = ContT . Marshal.with

-- | Marshal a 'Double' into a temporary cell.
double :: Double -> FortranIO r (Ptr Double)
double = ContT . Marshal.with

-- | Marshal a single-precision complex number into a temporary cell.
complexFloat :: Complex Float -> FortranIO r (Ptr (Complex Float))
complexFloat = ContT . Marshal.with

-- | Marshal a double-precision complex number into a temporary cell.
complexDouble :: Complex Double -> FortranIO r (Ptr (Complex Double))
complexDouble = ContT . Marshal.with

-- | Expose the underlying buffer of an immutable 'CArray' for the call.
array :: (Storable a) => CArray i a -> FortranIO r (Ptr a)
array = ContT . withCArray

-- | Like 'array', additionally returning the array's bounds.
arrayBounds :: (Storable a, Ix i) => CArray i a -> FortranIO r (Ptr a, (i,i))
arrayBounds v = (\ptr -> (ptr, CArray.bounds v)) <$> array v

-- | Expose the underlying buffer of a mutable 'IOCArray' for the call.
ioarray :: (Storable a) => IOCArray i a -> FortranIO r (Ptr a)
ioarray = ContT . withIOCArray


-- | Split two-dimensional bounds into per-axis bounds.
unzipBounds :: ((i,j),(i,j)) -> ((i,i), (j,j))
unzipBounds ((i0,j0), (i1,j1)) = ((i0,i1), (j0,j1))

-- | '(^)' with the exponent type fixed to 'Int'. Like '(^)', it calls
-- 'error' for a negative exponent.
(^!) :: (Num a) => a -> Int -> a
x ^! n = x ^ n