{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE TypeApplications #-}
module Control.Distributed.MPI.Storable
(
MPIException(..)
, Comm(..)
, commSelf
, commWorld
, Count(..)
, fromCount
, toCount
, Rank(..)
, anySource
, commRank
, commSize
, fromRank
, rootRank
, toRank
, Status(..)
, Tag(..)
, anyTag
, fromTag
, toTag
, unitTag
, Request
, abort
, mainMPI
, recv
, recv_
, send
, sendrecv
, sendrecv_
, irecv
, isend
, test
, test_
, wait
, wait_
, barrier
, bcastRecv
, bcastSend
, ibarrier
, ibcastRecv
, ibcastSend
) where
import Prelude hiding (init)
import Control.Concurrent
import Control.Exception
import Control.Monad
import Control.Monad.Loops
import qualified Data.ByteString as B
import qualified Data.ByteString.Unsafe as B
import Data.Typeable
import Foreign
import Foreign.C.Types
import qualified Foreign.Storable as Storable
import qualified Control.Distributed.MPI as MPI
import Control.Distributed.MPI
( Comm(..)
, commSelf
, commWorld
, Count(..)
, fromCount
, toCount
, Rank(..)
, anySource
, commRank
, commSize
, fromRank
, rootRank
, toRank
, Tag(..)
, anyTag
, fromTag
, toTag
, unitTag
, abort
, barrier
)
-- | Constraint for values that can be sent through this module: anything
-- with a 'Storable.Storable' instance, whose raw memory image is what gets
-- transmitted.
type CanSerialize = Storable.Storable
-- | Copy the in-memory 'Storable.Storable' image of a value into a fresh,
-- malloc-backed 'B.ByteString' of exactly @sizeOf x@ bytes.
--
-- Ownership of the malloc'd buffer passes to the returned 'B.ByteString',
-- which frees it when garbage collected. If 'Storable.poke' throws, the
-- buffer is freed before the exception propagates instead of leaking.
serialize :: CanSerialize a => a -> IO B.ByteString
serialize x = do
  let len = Storable.sizeOf x
  ptr <- malloc
  Storable.poke ptr x `onException` free ptr
  B.unsafePackMallocCStringLen (castPtr ptr, len)
-- | Decode a value of type @a@ from bytes produced by 'serialize'.
--
-- The first @sizeOf a@ bytes are copied into a properly aligned temporary
-- buffer before 'Storable.peek', so the input need not be aligned for @a@.
-- If the input is shorter than @sizeOf a@ (for example, because the sender
-- used a different type), an 'MPIException' is thrown instead of reading
-- past the end of the buffer.
deserialize :: CanSerialize a => B.ByteString -> IO a
deserialize bs =
  alloca $ \dst -> do
    let need = Storable.sizeOf (pointee dst)
    B.unsafeUseAsCStringLen bs $ \(src, have) -> do
      when (have < need) $
        throwIO (MPIException ("deserialize: expected at least " ++
                               show need ++ " bytes, but got " ++ show have))
      copyBytes (castPtr dst) src need
    Storable.peek dst
  where
    -- Names the pointee type for 'Storable.sizeOf', which must not inspect
    -- its argument (the same requirement 'alloca' itself relies on).
    pointee :: Ptr b -> b
    pointee _ = undefined
-- | Repeatedly run @poll@ until it yields a 'Just', running @onMiss@ after
-- every 'Nothing'. Returns the payload of the first 'Just'.
whileNothing :: Monad m => m (Maybe a) -> m () -> m a
whileNothing poll onMiss =
  poll >>= maybe (onMiss >> whileNothing poll onMiss) return
-- | Error raised by this module when a precondition is violated, for
-- example insufficient MPI thread support or calling a broadcast helper on
-- the wrong rank. The 'String' is a human-readable description.
newtype MPIException = MPIException String
  deriving (Eq, Ord, Read, Show, Typeable)

-- | Uses the default 'Exception' methods.
instance Exception MPIException
-- | Throw an 'MPIException' carrying @msg@ when @cond@ is 'False';
-- otherwise do nothing.
--
-- Uses 'throwIO' rather than 'throw' so the exception is raised exactly
-- when this action runs.
mpiAssert :: Bool -> String -> IO ()
mpiAssert cond msg = unless cond $ throwIO (MPIException msg)
-- | Records whether 'initMPI' actually initialized MPI, and therefore
-- whether 'finalizeMPI' should finalize it.
data DidInit
  = DidInit    -- ^ MPI was initialized by 'initMPI'
  | DidNotInit -- ^ MPI was already initialized; finalization is left alone
-- | Initialize MPI requesting 'MPI.ThreadMultiple', unless MPI is already
-- initialized.
--
-- Returns 'DidInit' when this call performed the initialization and
-- 'DidNotInit' when MPI was already initialized. Throws 'MPIException' (via
-- 'mpiAssert') if the library provides less than 'MPI.ThreadMultiple'.
--
-- NOTE(review): on that failure MPI stays initialized. When this runs as
-- the acquire step of 'bracket' in 'mainMPI', the release step is not run
-- because acquisition threw, so nothing finalizes it.
initMPI :: IO DidInit
initMPI = do
  alreadyUp <- MPI.initialized
  if alreadyUp then pure DidNotInit else startMPI
  where
    startMPI = do
      provided <- MPI.initThread MPI.ThreadMultiple
      mpiAssert (provided >= MPI.ThreadMultiple) $
        "MPI.init: Insufficient thread support: requiring " ++
        show MPI.ThreadMultiple ++
        ", but MPI library provided only " ++ show provided
      pure DidInit
-- | Undo 'initMPI': finalize MPI only if 'initMPI' initialized it and it
-- has not already been finalized.
finalizeMPI :: DidInit -> IO ()
finalizeMPI DidNotInit = pure ()
finalizeMPI DidInit = do
  alreadyFinalized <- MPI.finalized
  unless alreadyFinalized MPI.finalize
-- | Run an action with MPI initialized (with 'MPI.ThreadMultiple' support).
--
-- MPI is finalized afterwards, even if the action throws, but only if this
-- call was the one that initialized it.
mainMPI :: IO () -- ^ action to run while MPI is initialized
        -> IO ()
mainMPI action = bracket initMPI finalizeMPI (const action)
-- | Handle for an operation started by 'irecv', 'isend', 'ibcastRecv',
-- 'ibcastSend' or 'ibarrier'. Each of these runs the operation in a
-- separate Haskell thread ('forkIO') and fills the 'MVar' with the
-- operation's 'Status' and result when it finishes.
newtype Request a
  = Request (MVar (Status, a))
-- | Metadata describing a completed operation.
data Status = Status
  { msgRank :: !Rank
    -- ^ source rank for receives, destination for 'isend', root for
    -- broadcasts, 'anySource' for 'ibarrier'
  , msgTag  :: !Tag
    -- ^ actual tag for receives, send tag for 'isend', 'anyTag' for
    -- broadcasts and 'ibarrier'
  }
  deriving (Eq, Ord, Read, Show)
-- | Blocking receive of one serialized value.
--
-- Polls 'MPI.iprobe' (yielding to other Haskell threads between attempts)
-- until a message matching @recvrank@/@recvtag@ is available. It then
-- allocates a buffer of exactly the probed byte count, receives into it
-- (again polling with yields), and decodes the buffer with 'deserialize'.
-- The returned 'Status' holds the message's actual source and tag.
--
-- NOTE(review): probing and receiving are separate steps. If another thread
-- of this process (e.g. a concurrent 'irecv') receives with an overlapping
-- source/tag pattern, it may consume the probed message first. MPI's
-- matched probe (MPI_Mprobe/MPI_Mrecv) closes that gap; confirm whether the
-- bindings expose it.
recv :: CanSerialize a
     => Rank -- ^ source rank (may be 'anySource')
     -> Tag  -- ^ message tag (may be 'anyTag')
     -> Comm
     -> IO (Status, a)
recv recvrank recvtag comm = do
  probed <- whileNothing (MPI.iprobe recvrank recvtag comm) yield
  src <- MPI.getSource probed
  tg <- MPI.getTag probed
  nbytes <- MPI.fromCount <$> MPI.getCount probed MPI.datatypeByte
  storage <- mallocBytes nbytes
  buffer <- B.unsafePackMallocCStringLen (storage, nbytes)
  rreq <- MPI.irecv buffer src tg comm
  let spin = do
        finished <- MPI.test_ rreq
        unless finished (yield >> spin)
  spin
  -- buffer stays reachable through this use, so it outlives the receive.
  payload <- deserialize buffer
  pure (Status src tg, payload)
-- | Like 'recv', but discards the 'Status' and returns only the value.
recv_ :: CanSerialize a
      => Rank -- ^ source rank (may be 'anySource')
      -> Tag  -- ^ message tag (may be 'anyTag')
      -> Comm
      -> IO a
recv_ recvrank recvtag comm = fmap snd (recv recvrank recvtag comm)
-- | Blocking send of one serialized value.
--
-- Serializes the value, starts a nonblocking 'MPI.isend', and polls it
-- (yielding to other Haskell threads) until it completes.
send :: CanSerialize a
     => a    -- ^ value to send
     -> Rank -- ^ destination rank
     -> Tag  -- ^ message tag
     -> Comm
     -> IO ()
send sendobj sendrank sendtag comm = do
  payload <- serialize sendobj
  -- The pointer is unused: the wrapper only keeps payload reachable until
  -- the nonblocking send below has completed.
  B.unsafeUseAsCStringLen payload $ \_ -> do
    sreq <- MPI.isend payload sendrank sendtag comm
    let spin = do
          finished <- MPI.test_ sreq
          unless finished (yield >> spin)
    spin
-- | Send one value and receive another in a single call.
--
-- The receive is started first with 'irecv' (in a background thread),
-- then 'send' runs to completion on the calling thread, and finally the
-- receive's result is awaited with 'wait'.
sendrecv :: (CanSerialize a, CanSerialize b)
         => a    -- ^ value to send
         -> Rank -- ^ destination rank
         -> Tag  -- ^ send tag
         -> Rank -- ^ source rank to receive from
         -> Tag  -- ^ receive tag
         -> Comm
         -> IO (Status, b)
sendrecv sendobj sendrank sendtag recvrank recvtag comm =
  irecv recvrank recvtag comm >>= \incoming ->
    send sendobj sendrank sendtag comm >> wait incoming
-- | Like 'sendrecv', but discards the 'Status' of the receive.
sendrecv_ :: (CanSerialize a, CanSerialize b)
          => a    -- ^ value to send
          -> Rank -- ^ destination rank
          -> Tag  -- ^ send tag
          -> Rank -- ^ source rank to receive from
          -> Tag  -- ^ receive tag
          -> Comm
          -> IO b
sendrecv_ sendobj sendrank sendtag recvrank recvtag comm =
  fmap snd (sendrecv sendobj sendrank sendtag recvrank recvtag comm)
-- | Start a receive in a background Haskell thread and return immediately.
--
-- The returned 'Request' is filled with the result of 'recv' once it
-- completes; query it with 'test' or 'wait'.
irecv :: CanSerialize a
      => Rank -- ^ source rank (may be 'anySource')
      -> Tag  -- ^ message tag (may be 'anyTag')
      -> Comm
      -> IO (Request a)
irecv recvrank recvtag comm = do
  slot <- newEmptyMVar
  -- NOTE(review): if 'recv' throws in the background thread, slot is never
  -- filled, so a later 'wait' on this request blocks.
  void . forkIO $ recv recvrank recvtag comm >>= putMVar slot
  pure (Request slot)
-- | Start a send in a background Haskell thread and return immediately.
--
-- Serialization and the send itself both happen in that thread (via
-- 'send'). When it finishes, the 'Request' holds a 'Status' with the
-- destination rank and the send tag.
isend :: CanSerialize a
      => a    -- ^ value to send
      -> Rank -- ^ destination rank
      -> Tag  -- ^ message tag
      -> Comm
      -> IO (Request ())
isend sendobj sendrank sendtag comm = do
  slot <- newEmptyMVar
  let job = send sendobj sendrank sendtag comm
              >> putMVar slot (Status sendrank sendtag, ())
  void (forkIO job)
  pure (Request slot)
-- | Non-blocking completion check for a 'Request'.
--
-- Returns 'Nothing' while the background operation is still running, and
-- @'Just' (status, result)@ once it has finished. The result is read, not
-- removed, so a later 'test' or 'wait' on the same request still sees it.
-- Taking it here would make a subsequent 'wait' block forever.
test :: Request a
     -> IO (Maybe (Status, a))
test (Request result) = tryReadMVar result
-- | Like 'test', but discards the 'Status'.
test_ :: Request a
      -> IO (Maybe a)
test_ = fmap (fmap snd) . test
-- | Block until the 'Request' has completed, then return its status and
-- result.
--
-- The result is read without being removed, so waiting on (or testing) the
-- same request again returns the same value instead of blocking forever.
wait :: Request a
     -> IO (Status, a)
wait (Request result) = readMVar result
-- | Like 'wait', but discards the 'Status'.
wait_ :: Request a
      -> IO a
wait_ = fmap snd . wait
-- | Receive a value broadcast by 'bcastSend' from the root rank.
--
-- Must be called on a non-root rank (checked with 'mpiAssert'). First
-- receives the payload length (a single 'CLong'), then allocates a buffer
-- of that many bytes, receives the payload into it, and decodes it with
-- 'deserialize'. Both broadcasts are polled, yielding to other Haskell
-- threads between attempts.
bcastRecv :: CanSerialize a
          => Rank -- ^ root rank of the broadcast
          -> Comm
          -> IO a
bcastRecv root comm = do
  rank <- MPI.commRank comm
  mpiAssert (rank /= root) "bcastRecv: expected rank /= root"
  lenbuf <- mallocForeignPtr @CLong
  lenreq <- MPI.ibcast (lenbuf, 1::Int) root comm
  pollUntilDone lenreq
  -- lenbuf is still used here, so it stays alive for the whole broadcast.
  len <- fromIntegral <$> withForeignPtr lenbuf peek
  ptr <- mallocBytes len
  recvbuf <- B.unsafePackMallocCStringLen (ptr, len)
  req <- MPI.ibcast recvbuf root comm
  pollUntilDone req
  -- recvbuf is used by this final step, so it also outlives its broadcast.
  deserialize recvbuf
  where
    -- Poll a nonblocking MPI request, yielding between attempts.
    pollUntilDone r = whileM_ (not <$> MPI.test_ r) yield
-- | Broadcast a value from the root rank to every rank of @comm@.
--
-- Must be called on the root (checked with 'mpiAssert'); the other ranks
-- call 'bcastRecv'. Two broadcasts are issued: first the payload length as
-- a single 'CLong', then the serialized payload. Each is polled, yielding
-- to other Haskell threads between attempts.
--
-- Both buffers are kept reachable until their nonblocking broadcast has
-- completed, as 'send' does for its buffer. Otherwise they are unreferenced
-- after 'MPI.ibcast' returns and could be collected while MPI is still
-- reading them (unless the binding's request retains them, which is not
-- visible here).
bcastSend :: CanSerialize a
          => a    -- ^ value to broadcast
          -> Rank -- ^ root rank (must be the calling rank)
          -> Comm
          -> IO ()
bcastSend sendobj root comm = do
  rank <- MPI.commRank comm
  mpiAssert (rank == root) "bcastSend: expected rank == root"
  sendbuf <- serialize sendobj
  lenbuf <- mallocForeignPtr @CLong
  -- withForeignPtr keeps lenbuf alive until the length broadcast finishes.
  withForeignPtr lenbuf $ \lenptr -> do
    poke lenptr (fromIntegral (B.length sendbuf))
    lenreq <- MPI.ibcast (lenbuf, 1::Int) root comm
    whileM_ (not <$> MPI.test_ lenreq) yield
  -- The pointer is unused: the wrapper keeps sendbuf alive until the
  -- payload broadcast finishes.
  B.unsafeUseAsCStringLen sendbuf $ \_ -> do
    req <- MPI.ibcast sendbuf root comm
    whileM_ (not <$> MPI.test_ req) yield
-- | Run 'bcastRecv' in a background Haskell thread and return immediately.
--
-- When it finishes, the 'Request' holds the received value and a 'Status'
-- with the root rank and 'MPI.anyTag'.
ibcastRecv :: CanSerialize a
           => Rank -- ^ root rank of the broadcast
           -> Comm
           -> IO (Request a)
ibcastRecv root comm = do
  slot <- newEmptyMVar
  let job = do
        value <- bcastRecv root comm
        putMVar slot (Status root MPI.anyTag, value)
  _ <- forkIO job
  pure (Request slot)
-- | Run 'bcastSend' in a background Haskell thread and return immediately.
--
-- When it finishes, the 'Request' holds a 'Status' with the root rank and
-- 'MPI.anyTag'.
ibcastSend :: CanSerialize a
           => a    -- ^ value to broadcast
           -> Rank -- ^ root rank (must be the calling rank)
           -> Comm
           -> IO (Request ())
ibcastSend sendobj root comm = do
  slot <- newEmptyMVar
  void . forkIO $
    bcastSend sendobj root comm >> putMVar slot (Status root MPI.anyTag, ())
  pure (Request slot)
-- | Start a barrier in a background Haskell thread and return immediately.
--
-- The thread starts 'MPI.ibarrier' and polls it, yielding between
-- attempts. When it completes, the 'Request' holds a 'Status' with
-- 'MPI.anySource' and 'MPI.anyTag'.
ibarrier :: Comm
         -> IO (Request ())
ibarrier comm = do
  slot <- newEmptyMVar
  void . forkIO $ do
    breq <- MPI.ibarrier comm
    let spin = do
          done <- MPI.test_ breq
          unless done (yield >> spin)
    spin
    putMVar slot (Status MPI.anySource MPI.anyTag, ())
  pure (Request slot)