module Control.Monad.MultiPass
(
MultiPass
, MultiPassPrologue
, MultiPassEpilogue
, MultiPassMain, mkMultiPassMain
, PassS(..), PassZ(..)
, MultiPassAlgorithm(..)
, run
, NumThreads(..)
, parallelMP, parallelMP_
, readOnlyST2ToMP
, On(..), Off(..)
, MultiPassBase
, mkMultiPass, mkMultiPassPrologue, mkMultiPassEpilogue
, WrapInstrument, wrapInstrument
, PassNumber
, StepDirection(..)
, ST2ToMP
, UpdateThreadContext
, Instrument(..)
, ThreadContext(..)
, NextThreadContext(..)
, NextGlobalContext(..)
, BackTrack(..)
)
where
import Control.Applicative ( Applicative(..) )
import Control.Exception ( assert )
import Control.Monad.State.Strict
import Control.Monad.ST2
import Data.Ix
-- | One link in a chain of passes.  The payload is polymorphic in a
-- monad @p@; within this module it is instantiated at 'Off' and at 'On'.
newtype PassS cont m = PassS (forall p. Monad p => cont (m p))
-- | End of a chain of passes.  The payload must work for every
-- choice of the type argument @tc@ (of kind @*@).
newtype PassZ f = PassZ (forall (tc :: *). f tc)
-- | Relates a type @a@ to a type @b@ that it can be converted to.
-- The functional dependency makes @b@ uniquely determined by @a@.
class MultiPassAlgorithm a b | a -> b where
  -- | Convert a value of type @a@ into its associated @b@.
  unwrapMultiPassAlgorithm :: a -> b
-- | Wrapper that holds exactly one value.  Used as the monad argument
-- of a pass whose result is actually present (contrast with 'Off').
newtype On a = On a deriving Functor

-- | Required because 'Applicative' is a superclass of 'Monad'
-- (GHC 7.10 and later reject a 'Monad' instance without it).
instance Applicative On where
  pure = On
  On f <*> On x = On (f x)

-- | Binding simply applies the continuation to the wrapped value.
instance Monad On where
  return = pure
  On x >>= f = f x
-- | Wrapper that holds no value at all.  Every computation in this
-- monad collapses to 'Off', discarding its continuation.
data Off (a :: *) = Off deriving Functor

-- | Required because 'Applicative' is a superclass of 'Monad'
-- (GHC 7.10 and later reject a 'Monad' instance without it).
instance Applicative Off where
  pure _ = Off
  -- Only the left operand is inspected, matching what '>>=' does.
  Off <*> _ = Off

-- | The continuation is never run: there is no value to pass to it.
instance Monad Off where
  return = pure
  Off >>= _ = Off
-- | Strict pair used to build heterogeneous lists of contexts.
data ArgCons a b = ArgCons !a !b

-- | Terminator for a list built from 'ArgCons'.
data ArgNil = ArgNil

-- | Transform both components of an 'ArgCons' cell independently.
mapArgCons :: (a -> a') -> (b -> b') -> ArgCons a b -> ArgCons a' b'
mapArgCons onHead onTail cell =
  case cell of
    ArgCons hd tl -> ArgCons (onHead hd) (onTail tl)
-- | Strict pair of a parameter value @i@ and a value @f@.  In this
-- module @f@ is a function from an instrument (see the 'ApplyArgs'
-- instance for @Param@).
data Param i f = Param !i !f
-- | Tag passed to 'nextThreadContext' / 'nextGlobalContext' describing
-- which kind of transition between passes is taking place.
data StepDirection
  = StepForward
  | StepReset
  | StepBackward
  deriving (Eq)
-- | Index of a pass, wrapped so it cannot be confused with other 'Int's.
newtype PassNumber = PassNumber {unwrapPassNumber :: Int}

-- | The pass number immediately after the given one.
incrPassNumber :: PassNumber -> PassNumber
incrPassNumber = PassNumber . (+ 1) . unwrapPassNumber

-- | The smaller of two pass numbers.
minPassNumber :: PassNumber -> PassNumber -> PassNumber
minPassNumber a b =
  PassNumber (unwrapPassNumber a `min` unwrapPassNumber b)
-- | Underlying monad of the library: a strict 'StateT' over 'ST2'
-- whose state is the thread context @tc@.  The wrapped computation
-- may only be run when a @ThreadContext r w tc@ instance is available.
newtype MultiPassBase r w tc a
  = MultiPassBase
      { unwrapMultiPassBase
          :: ThreadContext r w tc => StateT tc (ST2 r w) a
      }
  deriving Functor

-- | Required because 'Applicative' is a superclass of 'Monad'
-- (GHC 7.10 and later reject a 'Monad' instance without it).
-- '<*>' is expressed through '>>=' so both instances agree.
instance Applicative (MultiPassBase r w tc) where
  pure x = MultiPassBase (return x)
  mf <*> mx = mf >>= \f -> mx >>= \x -> pure (f x)

-- | Sequencing delegates to the underlying 'StateT' computation.
instance Monad (MultiPassBase r w tc) where
  return = pure
  MultiPassBase m >>= f =
    MultiPassBase (m >>= \x -> unwrapMultiPassBase (f x))
-- | Monad for the main body of a multi-pass algorithm; a thin
-- wrapper around 'MultiPassBase'.
newtype MultiPass r w tc a
  = MultiPass
      { unwrapMultiPass :: MultiPassBase r w tc a
      }
  deriving Functor

-- | Required because 'Applicative' is a superclass of 'Monad'
-- (GHC 7.10 and later reject a 'Monad' instance without it).
instance Applicative (MultiPass r w tc) where
  pure x = MultiPass (return x)
  mf <*> mx = mf >>= \f -> mx >>= \x -> pure (f x)

-- | Sequencing delegates to the wrapped 'MultiPassBase' computation.
instance Monad (MultiPass r w tc) where
  return = pure
  MultiPass m >>= f =
    MultiPass (m >>= \x -> unwrapMultiPass (f x))

-- | Wrap a 'MultiPassBase' computation as a 'MultiPass' computation.
mkMultiPass :: MultiPassBase r w tc a -> MultiPass r w tc a
mkMultiPass =
  MultiPass
-- | Monad for the prologue stage of a 'MultiPassMain'; a thin
-- wrapper around 'MultiPassBase'.
newtype MultiPassPrologue r w tc a
  = MultiPassPrologue
      { unwrapMultiPassPrologue :: MultiPassBase r w tc a
      }
  deriving Functor

-- | Required because 'Applicative' is a superclass of 'Monad'
-- (GHC 7.10 and later reject a 'Monad' instance without it).
instance Applicative (MultiPassPrologue r w tc) where
  pure x = MultiPassPrologue (return x)
  mf <*> mx = mf >>= \f -> mx >>= \x -> pure (f x)

-- | Sequencing delegates to the wrapped 'MultiPassBase' computation.
instance Monad (MultiPassPrologue r w tc) where
  return = pure
  MultiPassPrologue m >>= f =
    MultiPassPrologue (m >>= \x -> unwrapMultiPassPrologue (f x))

-- | Wrap a 'MultiPassBase' computation as a prologue computation.
mkMultiPassPrologue
  :: MultiPassBase r w tc a -> MultiPassPrologue r w tc a
mkMultiPassPrologue =
  MultiPassPrologue
-- | Monad for the epilogue stage of a 'MultiPassMain'; a thin
-- wrapper around 'MultiPassBase'.
newtype MultiPassEpilogue r w tc a
  = MultiPassEpilogue
      { unwrapMultiPassEpilogue :: MultiPassBase r w tc a
      }
  deriving Functor

-- | Required because 'Applicative' is a superclass of 'Monad'
-- (GHC 7.10 and later reject a 'Monad' instance without it).
instance Applicative (MultiPassEpilogue r w tc) where
  pure x = MultiPassEpilogue (return x)
  mf <*> mx = mf >>= \f -> mx >>= \x -> pure (f x)

-- | Sequencing delegates to the wrapped 'MultiPassBase' computation.
instance Monad (MultiPassEpilogue r w tc) where
  return = pure
  MultiPassEpilogue m >>= f =
    MultiPassEpilogue (m >>= \x -> unwrapMultiPassEpilogue (f x))

-- | Wrap a 'MultiPassBase' computation as an epilogue computation.
mkMultiPassEpilogue
  :: MultiPassBase r w tc a -> MultiPassEpilogue r w tc a
mkMultiPassEpilogue =
  MultiPassEpilogue
-- | A complete three-stage computation: a prologue producing an @a@,
-- a body turning that into a @b@, and an epilogue producing the
-- final @c@.  The intermediate types @a@ and @b@ are hidden.
data MultiPassMain r w tc c
  = forall a b.
    MultiPassMain
      !(MultiPassPrologue r w tc a)
      !(a -> MultiPass r w tc b)
      !(b -> MultiPassEpilogue r w tc c)

-- | Assemble a 'MultiPassMain' from its prologue, body and epilogue.
mkMultiPassMain
  :: MultiPassPrologue r w tc a
  -> (a -> MultiPass r w tc b)
  -> (b -> MultiPassEpilogue r w tc c)
  -> MultiPassMain r w tc c
mkMultiPassMain = MultiPassMain
-- | Run the prologue, body and epilogue of a 'MultiPassMain' in
-- order, threading the thread context through as 'StateT' state.
-- Returns the epilogue's result paired with the final context.
runMultiPassMain
  :: ThreadContext r w tc
  => MultiPassMain r w tc a
  -> tc
  -> ST2 r w (a, tc)
runMultiPassMain (MultiPassMain prologue body epilogue) =
  runStateT
    ( unwrapMultiPassBase (unwrapMultiPassPrologue prologue) >>= \fromPrologue ->
      unwrapMultiPassBase (unwrapMultiPass (body fromPrologue)) >>= \fromBody ->
      unwrapMultiPassBase (unwrapMultiPassEpilogue (epilogue fromBody))
    )
-- | Contexts that can be divided among worker threads and then
-- recombined.  Both operations are used by the parallel helpers in
-- this module.
class ThreadContext r w tc where
  -- | @splitThreadContext m t tc@ is called by the parallel helpers
  -- with the number of threads @m@, a thread index @t@ and the current
  -- context; the result is stored as the context for thread @t@.
  splitThreadContext
    :: Int
    -> Int
    -> tc
    -> ST2 r w tc

  -- | @mergeThreadContext m getSubContext tc@ is called by the
  -- parallel helpers with the number of threads @m@, an action that
  -- reads the context stored for a given thread index, and the context
  -- from before the split; the result becomes the new context.
  mergeThreadContext
    :: Int
    -> (Int -> ST2 r w tc)
    -> tc
    -> ST2 r w tc
-- | The unit context carries no data: splitting and merging both
-- return it unchanged.
instance ThreadContext r w () where
  splitThreadContext _ _ () = return ()
  mergeThreadContext _ _ () = return ()

-- | The empty context list carries no data: splitting and merging
-- both return it unchanged.
instance ThreadContext r w ArgNil where
  splitThreadContext _ _ ArgNil = return ArgNil
  mergeThreadContext _ _ ArgNil = return ArgNil
-- | An 'ArgCons' cell is split and merged component-wise, head
-- first and then tail.
instance
  (ThreadContext r w x, ThreadContext r w y) =>
  ThreadContext r w (ArgCons x y)
  where
  splitThreadContext nThreads tIdx (ArgCons hd tl) =
    splitThreadContext nThreads tIdx hd >>= \hd' ->
    splitThreadContext nThreads tIdx tl >>= \tl' ->
    return (ArgCons hd' tl')

  mergeThreadContext nThreads getSubContext (ArgCons hd tl) =
    mergeThreadContext nThreads headOf hd >>= \hd' ->
    mergeThreadContext nThreads tailOf tl >>= \tl' ->
    return (ArgCons hd' tl')
    where
      -- Project one component out of a thread's stored context.
      headOf i = getSubContext i >>= \(ArgCons h _) -> return h
      tailOf i = getSubContext i >>= \(ArgCons _ t) -> return t
-- | A pair of contexts is split and merged component-wise, first
-- component before second.
instance
  (ThreadContext r w x, ThreadContext r w y) =>
  ThreadContext r w (x, y)
  where
  splitThreadContext nThreads tIdx (a, b) =
    splitThreadContext nThreads tIdx a >>= \a' ->
    splitThreadContext nThreads tIdx b >>= \b' ->
    return (a', b')

  mergeThreadContext nThreads getSubContext (a, b) =
    mergeThreadContext nThreads firstOf a >>= \a' ->
    mergeThreadContext nThreads secondOf b >>= \b' ->
    return (a', b')
    where
      -- Project one component out of a thread's stored context.
      firstOf i = getSubContext i >>= \(p, _) -> return p
      secondOf i = getSubContext i >>= \(_, q) -> return q
-- | A triple of contexts is split and merged component-wise, in
-- left-to-right order.
instance
  ( ThreadContext r w x
  , ThreadContext r w y
  , ThreadContext r w z
  ) =>
  ThreadContext r w (x, y, z)
  where
  splitThreadContext nThreads tIdx (a, b, c) =
    splitThreadContext nThreads tIdx a >>= \a' ->
    splitThreadContext nThreads tIdx b >>= \b' ->
    splitThreadContext nThreads tIdx c >>= \c' ->
    return (a', b', c')

  mergeThreadContext nThreads getSubContext (a, b, c) =
    mergeThreadContext nThreads firstOf a >>= \a' ->
    mergeThreadContext nThreads secondOf b >>= \b' ->
    mergeThreadContext nThreads thirdOf c >>= \c' ->
    return (a', b', c')
    where
      -- Project one component out of a thread's stored context.
      firstOf i = getSubContext i >>= \(p, _, _) -> return p
      secondOf i = getSubContext i >>= \(_, q, _) -> return q
      thirdOf i = getSubContext i >>= \(_, _, s) -> return s
-- | An 'Either' context is split and merged on whichever side it
-- currently holds; the constructor is preserved.
instance
  (ThreadContext r w x, ThreadContext r w y) =>
  ThreadContext r w (Either x y)
  where
  splitThreadContext nThreads tIdx (Left l) =
    splitThreadContext nThreads tIdx l >>= return . Left
  splitThreadContext nThreads tIdx (Right r) =
    splitThreadContext nThreads tIdx r >>= return . Right

  mergeThreadContext nThreads getSubContext e =
    case e of
      Left l ->
        mergeThreadContext nThreads getSubContextL l >>= return . Left
      Right r ->
        mergeThreadContext nThreads getSubContextR r >>= return . Right
    where
      -- Each reader expects the thread's stored context to be on the
      -- same side as @e@; a mismatch fails the pattern match in the
      -- monad.
      getSubContextL tc =
        do Left tc' <- getSubContext tc
           return tc'
      getSubContextR tc =
        do Right tc' <- getSubContext tc
           return tc'
-- | Instruments that can be built from a global context @gc@.  The
-- instrument type determines both its thread context @tc@ and its
-- global context @gc@.
class Instrument rootTC tc gc instr | instr -> tc gc where
  -- | Build the instrument from a way to lift 'ST2' actions, a way to
  -- update this instrument's part of the root thread context, and the
  -- instrument's global context.
  createInstrument
    :: ST2ToMP rootTC
    -> UpdateThreadContext rootTC tc
    -> gc
    -> WrapInstrument instr
-- | Wrapper around a freshly created instrument; this is the result
-- type of 'createInstrument'.
newtype WrapInstrument instr
  = WrapInstrument instr
  deriving Functor

-- | Required because 'Applicative' is a superclass of 'Monad'
-- (GHC 7.10 and later reject a 'Monad' instance without it).
instance Applicative WrapInstrument where
  pure = WrapInstrument
  WrapInstrument f <*> WrapInstrument x = WrapInstrument (f x)

-- | Binding simply applies the continuation to the wrapped value.
instance Monad WrapInstrument where
  return = pure
  WrapInstrument x >>= f = f x

-- | Wrap an instrument.
wrapInstrument :: instr -> WrapInstrument instr
wrapInstrument = WrapInstrument
-- | A way of lifting any 'ST2' action into 'MultiPassBase' over the
-- thread context @tc@.
type ST2ToMP tc =
  forall r w a. ST2 r w a -> MultiPassBase r w tc a

-- | A way of applying an update function to a @tc'@ value from inside
-- a computation whose state is @tc@, yielding a @tc'@.  (The top-level
-- implementation in this module returns the value from before the
-- update.)
type UpdateThreadContext tc tc' =
  forall r w. (tc' -> tc') -> MultiPassBase r w tc tc'
-- | Narrow an updater for an 'ArgCons' context to one that only
-- touches (and returns) the head component.
updateCtxArgL
  :: UpdateThreadContext rootTC (ArgCons tc tcs)
  -> UpdateThreadContext rootTC tc
updateCtxArgL updateCtx onHead =
  updateCtx (mapArgCons onHead id) >>= \(ArgCons hd _) -> return hd

-- | Narrow an updater for an 'ArgCons' context to one that only
-- touches (and returns) the tail component.
updateCtxArgR
  :: UpdateThreadContext rootTC (ArgCons tc tcs)
  -> UpdateThreadContext rootTC tcs
updateCtxArgR updateCtx onTail =
  updateCtx (mapArgCons id onTail) >>= \(ArgCons _ tl) -> return tl
-- | Supplies a single instrument argument to a function expecting
-- one, producing the fully applied result @f'@ together with the new
-- thread and global contexts.  @f@ determines @f'@, @tc@ and @gc@.
class
  ApplyArg r w param instr f oldTC oldGC tc gc rootTC f'
    | f -> f' tc gc
  where
  applyArg
    :: PassNumber
    -> StepDirection
    -> param
    -> (instr -> f)
    -> UpdateThreadContext rootTC tc
    -> oldTC
    -> oldGC
    -> ST2 r w (f', tc, gc)
-- | Handle the head of the context lists: derive the new global and
-- thread context for one instrument, create the instrument, apply the
-- function to it, and recurse on the remaining arguments via
-- 'applyArgs'.  The @param@ argument is ignored by this instance.
instance ( ApplyArgs r w f oldTCs oldGCs tcs gcs rootTC f'
, NextThreadContext r w oldTC oldGC tc
, NextGlobalContext r w oldTC oldGC gc
, Instrument rootTC tc gc instr
) =>
ApplyArg r w param instr f
(ArgCons oldTC oldTCs) (ArgCons oldGC oldGCs)
(ArgCons tc tcs) (ArgCons gc gcs)
rootTC f' where
applyArg n d _ f updateCtx
(ArgCons oldTC oldTCs) (ArgCons oldGC oldGCs) =
-- The global context is computed before the thread context; both are
-- derived from the same old thread/global context pair.
do gc <- nextGlobalContext n d oldTC oldGC
tc <- nextThreadContext n d oldTC oldGC
-- Lift a raw ST2 action into MultiPassBase for the instrument's use.
let st2ToMP m = MultiPassBase $ lift m
-- The instrument only sees the head slot of the thread context.
let WrapInstrument instr =
createInstrument st2ToMP (updateCtxArgL updateCtx) gc
-- Remaining arguments are resolved against the tail of each list.
(f', tcs, gcs) <-
applyArgs n d (f instr) (updateCtxArgR updateCtx) oldTCs oldGCs
return (f', ArgCons tc tcs, ArgCons gc gcs)
-- | Supplies every instrument argument that @f@ expects, producing
-- the fully applied result @f'@ together with the new thread and
-- global contexts.  @f@ determines @f'@, @tc@ and @gc@.
class
  ApplyArgs r w f oldTC oldGC tc gc rootTC f'
    | f -> f' tc gc
  where
  applyArgs
    :: PassNumber
    -> StepDirection
    -> f
    -> UpdateThreadContext rootTC tc
    -> oldTC
    -> oldGC
    -> ST2 r w (f', tc, gc)
-- | A plain function from an instrument: delegate to 'applyArg' with
-- @()@ as the parameter.
instance
  ApplyArg r w () instr f oldTC oldGC tc gc rootTC f' =>
  ApplyArgs r w (instr -> f) oldTC oldGC tc gc rootTC f'
  where
  applyArgs passNo dir fn updateCtx prevTC prevGC =
    applyArg passNo dir () fn updateCtx prevTC prevGC

-- | A function paired with an explicit parameter: unpack the 'Param'
-- and delegate to 'applyArg' with that parameter.
instance
  ApplyArg r w param instr f oldTC oldGC tc gc rootTC f' =>
  ApplyArgs r w (Param param (instr -> f)) oldTC oldGC tc gc rootTC f'
  where
  applyArgs passNo dir (Param prm fn) updateCtx prevTC prevGC =
    applyArg passNo dir prm fn updateCtx prevTC prevGC

-- | No arguments left: the 'MultiPassMain' is returned as is, with
-- empty context lists.
instance
  ApplyArgs r w (MultiPassMain r w rootTC a)
    ArgNil ArgNil ArgNil ArgNil
    rootTC (MultiPassMain r w rootTC a)
  where
  applyArgs _ _ mainAlg _ ArgNil ArgNil =
    return (mainAlg, ArgNil, ArgNil)
-- | Contexts that have a canonical initial value.
class InitCtx ctx where
  -- | The initial value of the context.
  initCtx :: ctx

-- | The unit context starts as @()@.
instance InitCtx () where
  initCtx = ()

-- | The empty context list starts as 'ArgNil'.
instance InitCtx ArgNil where
  initCtx = ArgNil

-- | A context list starts with every element at its initial value.
instance (InitCtx a, InitCtx b) => InitCtx (ArgCons a b) where
  initCtx = ArgCons initCtx initCtx
-- | Computes the thread context @tc'@ to use for a pass from the
-- thread context @tc@ and global context @gc@ it is stepping from.
class NextThreadContext r w tc gc tc' where
  nextThreadContext
    :: PassNumber
    -> StepDirection
    -> tc
    -> gc
    -> ST2 r w tc'

-- | A unit thread context can be produced from anything; all
-- arguments are ignored.
instance NextThreadContext r w tc gc () where
  nextThreadContext _ _ _ _ = return ()
-- | Pair to pair: each component is stepped independently against
-- the same global context, first component before second.
instance
  ( NextThreadContext r w x gc x'
  , NextThreadContext r w y gc y'
  ) =>
  NextThreadContext r w (x, y) gc (x', y')
  where
  nextThreadContext passNo dir (a, b) gc =
    nextThreadContext passNo dir a gc >>= \a' ->
    nextThreadContext passNo dir b gc >>= \b' ->
    return (a', b')

-- | Unit to pair: each component is produced from @()@, first
-- component before second.
instance
  ( NextThreadContext r w () gc x
  , NextThreadContext r w () gc y
  ) =>
  NextThreadContext r w () gc (x, y)
  where
  nextThreadContext passNo dir () gc =
    nextThreadContext passNo dir () gc >>= \a ->
    nextThreadContext passNo dir () gc >>= \b ->
    return (a, b)
-- | Triple to triple: each component is stepped independently
-- against the same global context, in left-to-right order.
instance
  ( NextThreadContext r w x gc x'
  , NextThreadContext r w y gc y'
  , NextThreadContext r w z gc z'
  ) =>
  NextThreadContext r w (x, y, z) gc (x', y', z')
  where
  nextThreadContext passNo dir (a, b, c) gc =
    nextThreadContext passNo dir a gc >>= \a' ->
    nextThreadContext passNo dir b gc >>= \b' ->
    nextThreadContext passNo dir c gc >>= \c' ->
    return (a', b', c')

-- | Unit to triple: each component is produced from @()@, in
-- left-to-right order.
instance
  ( NextThreadContext r w () gc x
  , NextThreadContext r w () gc y
  , NextThreadContext r w () gc z
  ) =>
  NextThreadContext r w () gc (x, y, z)
  where
  nextThreadContext passNo dir () gc =
    nextThreadContext passNo dir () gc >>= \a ->
    nextThreadContext passNo dir () gc >>= \b ->
    nextThreadContext passNo dir () gc >>= \c ->
    return (a, b, c)
-- | 'Either' to 'Either': step whichever side is present and keep
-- the same constructor.
instance
  ( NextThreadContext r w x gc x'
  , NextThreadContext r w y gc y'
  ) =>
  NextThreadContext r w (Either x y) gc (Either x' y')
  where
  nextThreadContext passNo dir (Left l) gc =
    nextThreadContext passNo dir l gc >>= return . Left
  nextThreadContext passNo dir (Right r) gc =
    nextThreadContext passNo dir r gc >>= return . Right
-- | Computes the global context @gc'@ to use for a pass from the
-- thread context @tc@ and global context @gc@ it is stepping from.
class NextGlobalContext r w tc gc gc' where
  nextGlobalContext
    :: PassNumber
    -> StepDirection
    -> tc
    -> gc
    -> ST2 r w gc'

-- | A unit global context can be produced from anything; all
-- arguments are ignored.
instance NextGlobalContext r w tc gc () where
  nextGlobalContext _ _ _ _ = return ()
-- | Pair to pair: each component of the global context is stepped
-- independently with the same thread context, first before second.
instance
  ( NextGlobalContext r w tc x x'
  , NextGlobalContext r w tc y y'
  ) =>
  NextGlobalContext r w tc (x, y) (x', y')
  where
  nextGlobalContext passNo dir tc (a, b) =
    nextGlobalContext passNo dir tc a >>= \a' ->
    nextGlobalContext passNo dir tc b >>= \b' ->
    return (a', b')
-- | Triple to triple: each component of the global context is
-- stepped independently with the same thread context, left to right.
instance
  ( NextGlobalContext r w tc x x'
  , NextGlobalContext r w tc y y'
  , NextGlobalContext r w tc z z'
  ) =>
  NextGlobalContext r w tc (x, y, z) (x', y', z')
  where
  nextGlobalContext passNo dir tc (a, b, c) =
    nextGlobalContext passNo dir tc a >>= \a' ->
    nextGlobalContext passNo dir tc b >>= \b' ->
    nextGlobalContext passNo dir tc c >>= \c' ->
    return (a', b', c')
-- | 'Either' to 'Either': step whichever side of the global context
-- is present and keep the same constructor.
instance
  ( NextGlobalContext r w tc x x'
  , NextGlobalContext r w tc y y'
  ) =>
  NextGlobalContext r w tc (Either x y) (Either x' y')
  where
  nextGlobalContext passNo dir tc (Left l) =
    nextGlobalContext passNo dir tc l >>= return . Left
  nextGlobalContext passNo dir tc (Right r) =
    nextGlobalContext passNo dir tc r >>= return . Right
-- | Collapses a chain of 'PassS' links down to the 'PassZ' at its
-- end.  The chain type @a@ determines the payload type @b@.
class InstantiatePasses a b | a -> b where
  instantiatePasses :: a -> PassZ b

-- | End of the chain: a 'PassZ' is returned unchanged.
instance InstantiatePasses (PassZ a) a where
  instantiatePasses passes = passes

-- | One link: instantiate the link's monad at 'Off' and continue
-- with the rest of the chain.
instance
  InstantiatePasses (cont (m Off)) b =>
  InstantiatePasses (PassS cont m) b
  where
  instantiatePasses (PassS rest) =
    instantiatePasses (rest :: cont (m Off))
-- | Lets a context request that execution go back to an earlier
-- pass after the current pass has run.
class BackTrack r w tc gc where
  -- | 'Nothing' means no request; @'Just' n@ names the pass to go
  -- back to.  The default never requests backtracking.
  backtrack :: tc -> gc -> ST2 r w (Maybe PassNumber)
  backtrack _ _ = return Nothing

-- | Uses the default: never requests backtracking.
instance BackTrack r w tc ()

-- | Uses the default: never requests backtracking.
instance BackTrack r w ArgNil ArgNil
-- | Query the head and then the tail of the context lists.  When
-- both request backtracking the earlier (smaller) pass number wins;
-- when only one does, its request is returned.
instance
  (BackTrack r w tc gc, BackTrack r w tcs gcs) =>
  BackTrack r w (ArgCons tc tcs) (ArgCons gc gcs)
  where
  backtrack (ArgCons tc tcs) (ArgCons gc gcs) =
    backtrack tc gc >>= \fromHead ->
    backtrack tcs gcs >>= \fromTail ->
    case (fromHead, fromTail) of
      (Just x, Just y) -> return (Just (minPassNumber x y))
      (Just x, Nothing) -> return (Just x)
      (Nothing, Just y) -> return (Just y)
      (Nothing, Nothing) -> return Nothing
-- | Runs the passes described by @f@, starting from the given pass
-- number, the result of the preceding pass, and the thread and global
-- contexts.  The result is either the final output ('Right') or a
-- request to resume at an earlier pass ('Left'), carrying the target
-- pass number, a 'MultiPassMain' to run, and the contexts to run it
-- with.
class RunPasses r w f tc gc p out where
  runPasses
    :: PassNumber
    -> f
    -> p out
    -> tc
    -> gc
    -> ST2 r w
         ( Either
             ( PassNumber
             , MultiPassMain r w tc (p out)
             , tc
             , gc
             )
             out
         )

-- | No passes left: the value produced by the last pass (wrapped in
-- 'On') is the final output.
instance RunPasses r w (PassZ f) tc gc On out where
  runPasses _ _ (On result) _ _ =
    return (Right result)
-- | Run one pass and then the remaining passes, with support for
-- backtracking.  The pass description is instantiated twice: with its
-- monad at 'Off' (@fPrev@, used when stepping back to the caller's
-- pass) and at 'On' (@fCurr@, used for this pass and the ones after).
-- The result of the preceding pass (third argument) is ignored.
instance ( InstantiatePasses (cont (f Off)) fPrev
, MultiPassAlgorithm (fPrev tc0) gPrev
, InstantiatePasses (cont (f On)) fCurr
, MultiPassAlgorithm (fCurr tc1) gCurr
, ApplyArgs r w gCurr tc0 gc0 tc1 gc1 tc1
(MultiPassMain r w tc1 (p out))
, ApplyArgs r w gCurr tc1 gc1 tc1 gc1 tc1
(MultiPassMain r w tc1 (p out))
, ApplyArgs r w gPrev tc1 gc1 tc0 gc0 tc0
(MultiPassMain r w tc0 (q out))
, ThreadContext r w tc1
, BackTrack r w tc1 gc1
, RunPasses r w (cont (f On)) tc1 gc1 p out
) =>
RunPasses r w (PassS cont f) tc0 gc0 q out where
runPasses n fBox _ =
let PassS (fPrev :: cont (f Off)) = fBox in
let PassS (fCurr :: cont (f On)) = fBox in
let
-- Run this pass once, then either continue with the later passes or
-- handle a backtracking request raised by the contexts.
loop g tc gc =
do (result, tc') <- runMultiPassMain g tc
mb <- backtrack tc' gc
case mb of
Nothing
->
-- No backtracking requested: hand the result to the next pass.
let n' = incrPassNumber n in
do e <- runPasses n' fCurr result tc' gc
case e of
-- A later pass asked to go back; see whether it means this pass.
Left info -> rewind info
Right out -> return (Right out)
Just m
-- Backtracking requested: re-derive the contexts with StepReset.
-> stepReset m tc' gc
-- Decide where a request to return to pass m is handled: here if m is
-- this pass, otherwise in the caller (an earlier pass).
rewind (m,g,tc,gc) =
assert (unwrapPassNumber m <= unwrapPassNumber n) $
if unwrapPassNumber m == unwrapPassNumber n
then loop g tc gc
else stepBackward m tc gc
-- Rebuild this pass's algorithm and contexts in place, then rewind.
stepReset m tc gc =
let PassZ f' = instantiatePasses fCurr in
let g = unwrapMultiPassAlgorithm (f' :: fCurr tc1) in
do (g', tc', gc') <-
applyArgs n StepReset g updateThreadContextTop tc gc
rewind (m,g',tc',gc')
-- Build the algorithm and contexts at the caller's types (tc0/gc0)
-- and report the request upwards as a Left.
stepBackward m tc gc =
let PassZ f' = instantiatePasses fPrev in
let g = unwrapMultiPassAlgorithm (f' :: fPrev tc0) in
do (g', tc', gc') <-
applyArgs n StepBackward g updateThreadContextTop tc gc
return (Left (m,g',tc',gc'))
in
-- Entry point: derive this pass's contexts from the caller's with
-- StepForward, then run the pass.
let loopStart tc gc =
let PassZ f' = instantiatePasses fCurr in
let g = unwrapMultiPassAlgorithm (f' :: fCurr tc1) in
do (g', tc', gc') <-
applyArgs n StepForward g updateThreadContextTop tc gc
loop g' tc' gc'
in
loopStart
-- | Updater for the whole thread context: replaces the state with
-- the result of applying the function to it, and returns the state
-- as it was before the update.
updateThreadContextTop :: UpdateThreadContext tc tc
updateThreadContextTop f =
  MultiPassBase (get >>= \old -> put (f old) >> return old)
-- | Run a multi-pass algorithm from pass 0 with freshly initialised
-- thread and global contexts, returning its final output.  A 'Left'
-- from 'runPasses' (a request to go back before pass 0) is treated as
-- an internal error.
run
  :: forall r w f f' g tc gc out.
     ( InstantiatePasses f f'
     , MultiPassAlgorithm (f' tc) g
     , ApplyArgs r w g tc gc tc gc tc
                 (MultiPassMain r w tc (Off out))
     , InitCtx tc
     , InitCtx gc
     , RunPasses r w f tc gc Off out
     )
  => f
  -> ST2 r w out
run f =
  runPasses (PassNumber 0) f Off (initCtx :: tc) (initCtx :: gc) >>= \outcome ->
    case outcome of
      Right result -> return result
      Left _ -> assert False $ error "run"
-- | Number of threads to use in 'parallelMP' and 'parallelMP_'.
-- Both assert that the value is positive.
newtype NumThreads = NumThreads Int
-- | Apply a function to every index in the given bounds and collect
-- the results in a new array with those bounds.  With one thread, or
-- at most one index, the indices are processed sequentially in 'range'
-- order; otherwise the work is handed to 'parallelHelper' with at most
-- one thread per index.  Asserts that the thread count is positive.
parallelMP
  :: (Ix i, Num i)
  => NumThreads
  -> (i,i)
  -> (i -> MultiPass r w tc a)
  -> MultiPass r w tc (ST2Array r w i a)
parallelMP (NumThreads m) bnds f =
  assert (m > 0) chosen
  where
    n = rangeSize bnds

    chosen
      | m == 1 || n <= 1 = sequentialFill
      | otherwise =
          assert (m > 1)
            (assert (n > 1) (parallelHelper (min m n) n bnds f))

    -- Allocate the result array, then fill it one index at a time.
    sequentialFill =
      MultiPass (MultiPassBase (lift (newST2Array_ bnds))) >>= \xs ->
      mapM_
        ( \i ->
            f i >>= \x ->
            MultiPass (MultiPassBase (lift (writeST2Array xs i x)))
        )
        (range bnds)
        >> return xs
-- | Parallel worker behind 'parallelMP'.
--
-- @parallelHelper m n bnds f@ runs @f@ over the @n@ indices of @bnds@
-- using @m@ threads and returns an array of the results:
--
-- 1. the current thread context is split into @m@ per-thread contexts;
-- 2. each thread processes one contiguous block of indices, threading
--    its own context through the calls to @f@;
-- 3. the per-thread contexts are merged back into the caller's context.
--
-- Index @j@ (counted from 0) is mapped to @fst bnds + fromIntegral j@.
parallelHelper
  :: (Ix i, Num i)
  => Int
  -> Int
  -> (i,i)
  -> (i -> MultiPass r w tc a)
  -> MultiPass r w tc (ST2Array r w i a)
parallelHelper m n bnds f =
  MultiPass $ MultiPassBase $
  do tc <- get
     -- One slot per thread, indexed 0 .. m-1.  (The original text had
     -- the unbound identifier "m1" here.)
     let tBnds = (0, m - 1)
     tcs <- lift $ newST2Array_ tBnds
     lift $ sequence_
       [ do tci <- splitThreadContext m t tc
            writeST2Array tcs t tci
       | t <- range tBnds
       ]
     xs <- lift $ newST2Array_ bnds
     let base = fst bnds
     -- Ceiling of n / m, so that m blocks cover all n indices.
     let blockSize = (n + m - 1) `div` m
     lift $ parallelST2 tBnds $ \i ->
       do tci <- readST2Array tcs i
          let start = i * blockSize
          -- Exclusive upper end of this thread's block, clipped to n.
          let end = min n (start + blockSize)
          tci' <-
            flip execStateT tci $
            sequence_
              [ let j' = base + fromIntegral j in
                do x <- unwrapMultiPassBase $ unwrapMultiPass $ f j'
                   lift $ writeST2Array xs j' x
              | j <- [start .. end - 1]
              ]
          -- Publish the thread's final context for the merge below.
          writeST2Array tcs i tci'
     tc' <- lift $ mergeThreadContext m (readST2Array tcs) tc
     put tc'
     return xs
-- | Like 'parallelMP' but discards the results.  With one thread, or
-- at most one index, the indices are processed sequentially in 'range'
-- order; otherwise the work is handed to 'parallelHelper_' with at
-- most one thread per index.  Asserts that the thread count is
-- positive.
parallelMP_
  :: (Ix i, Num i)
  => NumThreads
  -> (i,i)
  -> (i -> MultiPass r w tc a)
  -> MultiPass r w tc ()
parallelMP_ (NumThreads m) bnds f =
  assert (m > 0) chosen
  where
    n = rangeSize bnds

    chosen
      | m == 1 || n <= 1 = mapM_ f (range bnds)
      | otherwise =
          assert (m > 1)
            (assert (n > 1) (parallelHelper_ (min m n) n bnds f))
-- | Parallel worker behind 'parallelMP_'.
--
-- @parallelHelper_ m n bnds f@ runs @f@ over the @n@ indices of @bnds@
-- using @m@ threads, discarding the results:
--
-- 1. the current thread context is split into @m@ per-thread contexts;
-- 2. each thread processes one contiguous block of indices, threading
--    its own context through the calls to @f@;
-- 3. the per-thread contexts are merged back into the caller's context.
--
-- Index @j@ (counted from 0) is mapped to @fst bnds + fromIntegral j@.
parallelHelper_
  :: (Ix i, Num i)
  => Int
  -> Int
  -> (i,i)
  -> (i -> MultiPass r w tc a)
  -> MultiPass r w tc ()
parallelHelper_ m n bnds f =
  MultiPass $ MultiPassBase $
  do tc <- get
     -- One slot per thread, indexed 0 .. m-1.  (The original text had
     -- the unbound identifier "m1" here.)
     let tBnds = (0, m - 1)
     tcs <- lift $ newST2Array_ tBnds
     lift $ sequence_
       [ do tci <- splitThreadContext m t tc
            writeST2Array tcs t tci
       | t <- range tBnds
       ]
     let base = fst bnds
     -- Ceiling of n / m, so that m blocks cover all n indices.
     let blockSize = (n + m - 1) `div` m
     lift $ parallelST2 tBnds $ \i ->
       do tci <- readST2Array tcs i
          let start = i * blockSize
          -- Exclusive upper end of this thread's block, clipped to n.
          let end = min n (start + blockSize)
          tci' <-
            flip execStateT tci $
            sequence_
              [ let j' = base + fromIntegral j in
                unwrapMultiPassBase $ unwrapMultiPass $ f j'
              | j <- [start .. end - 1]
              ]
          -- Publish the thread's final context for the merge below.
          writeST2Array tcs i tci'
     tc' <- lift $ mergeThreadContext m (readST2Array tcs) tc
     put tc'
-- | Lift an 'ST2' action into 'MultiPass'.  The action must be
-- polymorphic in its @w@ parameter.
readOnlyST2ToMP :: (forall w. ST2 r w a) -> MultiPass r w' tc a
readOnlyST2ToMP action = MultiPass (MultiPassBase (lift action))