-- Copyright 2013 Kevin Backhouse.

{-# OPTIONS_GHC -XPolyKinds -XKindSignatures -XScopedTypeVariables #-}

{-|

This module implements the core functions, datatypes, and classes of
the MultiPass library. Its export list is divided into two halves. The
first half contains the declarations which are relevant to anyone who
wants to use the MultiPass library. The second contains which are only
relevant to people who want to implement new instruments.

-}

module Control.Monad.MultiPass
  ( -- * Users
    MultiPass
  , MultiPassPrologue
  , MultiPassEpilogue
  , MultiPassMain, mkMultiPassMain
  , PassS(..), PassZ(..)
  , MultiPassAlgorithm(..)
  , run
  , NumThreads(..)
  , parallelMP, parallelMP_
  , readOnlyST2ToMP

    -- * Instrument Authors
  , On(..), Off(..)
  , MultiPassBase
  , mkMultiPass, mkMultiPassPrologue, mkMultiPassEpilogue
  , WrapInstrument, wrapInstrument
  , PassNumber
  , StepDirection(..)
  , ST2ToMP
  , UpdateThreadContext
  , Instrument(..)
  , ThreadContext(..)
  , NextThreadContext(..)
  , NextGlobalContext(..)
  , BackTrack(..)
  )
where

import Control.Exception ( assert )
import Control.Monad.State.Strict
import Control.Monad.ST2
import Data.Ix

-- | This datatype is used in conjunction with 'PassZ' to package the
-- main function of the multi-pass algorithm. For an example of how
-- they are used, see the implementation of
-- 'Control.Monad.MultiPass.Example.Repmin.repminMP' or any of the
-- other examples in the Example directory.
newtype PassS cont m
  = PassS (forall p. Monad p => cont (m p))

-- | Used in conjunction with 'PassS' to build a Peano number
-- corresponding to the number of passes.
newtype PassZ f
  = PassZ (forall (tc :: *). f tc)

-- | The main function of a multi-pass algorithm needs to be wrapped
-- in a newtype so that it can be packaged with 'PassS' and
-- 'PassZ'. The newtype needs to be made an instance of
-- 'MultiPassAlgorithm' so that it can unwrapped by the
-- implementation.
class MultiPassAlgorithm a b | a -> b where
  unwrapMultiPassAlgorithm :: a -> b

-- | Trivial monad, equivalent to 'Data.Functor.Identity.Identity'.
-- Used to switch on a pass of a multi-pass algorithm.
newtype On a = On a deriving Functor

instance Monad On where
  return x = On x
  On x >>= f = f x

-- | Trivial monad which computes absolutely nothing. It is used to
-- switch off a pass of a multi-pass algorithm.
data Off (a :: *) = Off deriving Functor

instance Monad Off where
  return _ = Off
  Off >>= _ = Off

-- ArgCons and ArgNil are used to uncurry the main function of the
-- multi-pass algorithm. For example, a function of the following
-- type:
--
--     Instrument1 -> Instrument2 -> MultiPass r w tc a
--
-- gets converted to a function of type:
--
--     ArgCons Instrument1 (ArgCons Instrument2 ArgNil) ->
--     MultiPass r w tc a
--
-- The uncurrying is implemented in the ApplyArg and ApplyArgs
-- classes.
--
-- ArgCons and ArgNil are not exported from this module.
data ArgCons a b
  = ArgCons !a !b

data ArgNil
  = ArgNil

mapArgCons :: (a -> a') -> (b -> b') -> (ArgCons a b) -> (ArgCons a' b')
mapArgCons f g (ArgCons x y) =
  ArgCons (f x) (g y)

-- The Param type is the old solution to the problem of passing
-- initial parameters to instruments. The MultiPassPrologue seems to
-- be a better solution to this problem, so the Param type has been
-- removed from the external interface. However, all the internal
-- plumbing is still there (in ApplyArg and ApplyArgs), so it would be
-- easy to resurrect if necessary. The comments below are the old
-- comments explaining how to use Param.
--
-- This type is used by instruments that are parameterised by an
-- initial value. It is used in the main function of the algorithm as
-- follows:
--
--    mainFcn =
--      Param initVal1 $ \instr1 ->
--      Param initVal2 $ \instr2 ->
--      do ...
--
-- The initial values are passed to the createInstrument method of the
-- Instrument class so that they can be used during the construction
-- of the instrument. This is implemented in the ApplyArg and
-- ApplyArgs classes.
data Param i f
  = Param !i !f

-- | This datatype is used by the 'NextThreadContext' and
-- 'NextGlobalContext' classes to specify whether the algorithm is
-- progressing to the next pass or back-tracking to a previous
-- pass. When back-tracking occurs, the current thread and global
-- contexts are first passed the 'StepReset' command. Then they are
-- passed the 'StepBackward' command @N@ times, where @N@ is the
-- number of passes that need to be revisited. Note that @N@ can be
-- zero if only the current pass needs to be revisited, so the
-- 'StepBackward' command may not be used. This is the reason why the
-- 'StepReset' command is always issued first.
data StepDirection
  = StepForward
  | StepReset
  | StepBackward
    deriving Eq

-- | This datatype is used by the back-tracking mechanism. Instruments
-- can request that the evaluator back-tracks to a specific pass
-- number. Instruments which use back-tracking store the relevant
-- PassNumbers in their global context. The current 'PassNumber' is
-- the first argument of 'nextGlobalContext' for this
-- purpose. 'PassNumber' is an abstract datatype. Instruments should
-- never need to create a new 'PassNumber' or modify an existing one,
-- so no functions that operate on 'PassNumber' are exported from this
-- module.
newtype PassNumber = PassNumber { unwrapPassNumber :: Int }

-- Increment a PassNumber. This function is not exported.
incrPassNumber :: PassNumber -> PassNumber
incrPassNumber (PassNumber k) =
  PassNumber (k+1)

-- Compute the minimum of two PassNumbers. This function is not
-- exported.
minPassNumber :: PassNumber -> PassNumber -> PassNumber
minPassNumber (PassNumber x) (PassNumber y) =
  PassNumber (min x y)

-- | 'MultiPass', 'MultiPassPrologue', and 'MultiPassEpilogue' are
-- trivial newtype wrappers around this monad. Instruments can
-- construct computations in the 'MultiPassBase' monad, but then use
-- 'mkMultiPass', 'mkMultiPassPrologue', and 'mkMultiPassEpilogue' to
-- restrict which of the three stages it is allowed to be used in.
newtype MultiPassBase r w tc a
  = MultiPassBase
      { unwrapMultiPassBase
          :: ThreadContext r w tc => StateT tc (ST2 r w) a
      }
    deriving Functor

instance Monad (MultiPassBase r w tc) where
  return x = MultiPassBase $ return x

  MultiPassBase m >>= f =
    MultiPassBase $
    do x <- m
       unwrapMultiPassBase (f x)

-- | This monad is used to implement the body of a multi-pass
-- algorithm.
newtype MultiPass r w tc a
  = MultiPass
      { unwrapMultiPass :: MultiPassBase r w tc a
      }
    deriving Functor

instance Monad (MultiPass r w tc) where
  return x = MultiPass $ return x

  MultiPass m >>= f =
    MultiPass $
    do x <- m
       unwrapMultiPass (f x)

-- | Restrict a computation so that it can only be executed during the
-- body of the algorithm (not the prologue or epilogue).
mkMultiPass :: MultiPassBase r w tc a -> MultiPass r w tc a
mkMultiPass =
  MultiPass

-- | This monad is used to implement the prologue of a multi-pass
-- algorithm.
newtype MultiPassPrologue r w tc a
  = MultiPassPrologue
      { unwrapMultiPassPrologue :: MultiPassBase r w tc a
      }
    deriving Functor

instance Monad (MultiPassPrologue r w tc) where
  return x = MultiPassPrologue $ return x

  MultiPassPrologue m >>= f =
    MultiPassPrologue $
    do x <- m
       unwrapMultiPassPrologue (f x)

-- | Restrict a computation so that it can only be executed during the
-- prologue.
mkMultiPassPrologue
  :: MultiPassBase r w tc a -> MultiPassPrologue r w tc a
mkMultiPassPrologue =
  MultiPassPrologue

-- | This monad is used to implement the epilogue of a multi-pass
-- algorithm.
newtype MultiPassEpilogue r w tc a
  = MultiPassEpilogue
      { unwrapMultiPassEpilogue :: MultiPassBase r w tc a
      }
    deriving Functor

instance Monad (MultiPassEpilogue r w tc) where
  return x = MultiPassEpilogue $ return x

  MultiPassEpilogue m >>= f =
    MultiPassEpilogue $
    do x <- m
       unwrapMultiPassEpilogue (f x)

-- | Restrict a computation so that it can only be executed during the
-- epilogue.
mkMultiPassEpilogue
  :: MultiPassBase r w tc a -> MultiPassEpilogue r w tc a
mkMultiPassEpilogue =
  MultiPassEpilogue

-- | 'MultiPassMain' is an abstract datatype containing the prologue,
-- body, and epilogue of a multi-pass algorithm. Use
-- 'mkMultiPassMain' to construct an object of type 'MultiPassMain'.
data MultiPassMain r w tc c =
  forall a b.
  MultiPassMain
    !(MultiPassPrologue r w tc a)
    !(a -> MultiPass r w tc b)
    !(b -> MultiPassEpilogue r w tc c)

-- | Combine the prologue, body, and epilogue of a multi-pass
-- algorithm to create the 'MultiPassMain' object which is required by
-- the 'run' function.
mkMultiPassMain
  :: MultiPassPrologue r w tc a           -- ^ Prologue
  -> (a -> MultiPass r w tc b)            -- ^ Algorithm body
  -> (b -> MultiPassEpilogue r w tc c)    -- ^ Epilogue
  -> MultiPassMain r w tc c
mkMultiPassMain prologue body epilogue =
  MultiPassMain prologue body epilogue

-- Run the prologue, body, and epilogue of a multi-pass algorithm.
runMultiPassMain
  :: ThreadContext r w tc
  => MultiPassMain r w tc a
  -> tc
  -> ST2 r w (a, tc)
runMultiPassMain (MultiPassMain prologue body epilogue) =
  runStateT $
  do x <- unwrapMultiPassBase $ unwrapMultiPassPrologue $ prologue
     y <- unwrapMultiPassBase $ unwrapMultiPass $ body x
     unwrapMultiPassBase $ unwrapMultiPassEpilogue $ epilogue y

-- | This class is used when multiple threads are
-- spawned. 'splitThreadContext' is used to create a new thread
-- context for each of the new threads and 'mergeThreadContext' is
-- used to merge them back together when the parallel region ends.
class ThreadContext r w tc where
  splitThreadContext
    :: Int                 -- Number of threads being created
    -> Int                 -- Index of current thread
    -> tc                  -- Current thread context
    -> ST2 r w tc          -- New sub-context

  mergeThreadContext
    :: Int                 -- Number of threads being merged
    -> (Int -> ST2 r w tc) -- Function to get the i'th sub-context
    -> tc                  -- Previous merged context
    -> ST2 r w tc          -- New merged context

instance ThreadContext r w () where
  splitThreadContext _ _ () = return ()
  mergeThreadContext _ _ () = return ()

instance ThreadContext r w ArgNil where
  splitThreadContext _ _ ArgNil = return ArgNil
  mergeThreadContext _ _ ArgNil = return ArgNil

instance (ThreadContext r w x, ThreadContext r w y) =>
         ThreadContext r w (ArgCons x y) where
  splitThreadContext m t (ArgCons x y) =
    do x' <- splitThreadContext m t x
       y' <- splitThreadContext m t y
       return (ArgCons x' y')

  mergeThreadContext m getSubContext (ArgCons x y) =
    let getSubContextL tc =
          do ArgCons tc' _ <- getSubContext tc
             return tc'
    in
    let getSubContextR tc =
          do ArgCons _ tc' <- getSubContext tc
             return tc'
    in
    do x' <- mergeThreadContext m getSubContextL x
       y' <- mergeThreadContext m getSubContextR y
       return (ArgCons x' y')

instance (ThreadContext r w x, ThreadContext r w y) =>
         ThreadContext r w (x,y) where
  splitThreadContext m t (x,y) =
    do x' <- splitThreadContext m t x
       y' <- splitThreadContext m t y
       return (x', y')

  mergeThreadContext m getSubContext (x,y) =
    let getSubContextL tc =
          do (tc',_) <- getSubContext tc
             return tc'
    in
    let getSubContextR tc =
          do (_,tc') <- getSubContext tc
             return tc'
    in
    do x' <- mergeThreadContext m getSubContextL x
       y' <- mergeThreadContext m getSubContextR y
       return (x',y')

instance ( ThreadContext r w x
         , ThreadContext r w y
         , ThreadContext r w z
         ) =>
         ThreadContext r w (x,y,z) where
  splitThreadContext m t (x,y,z) =
    do x' <- splitThreadContext m t x
       y' <- splitThreadContext m t y
       z' <- splitThreadContext m t z
       return (x', y', z')

  mergeThreadContext m getSubContext (x,y,z) =
    let getSubContext1 tc =
          do (tc',_,_) <- getSubContext tc
             return tc'
    in
    let getSubContext2 tc =
          do (_,tc',_) <- getSubContext tc
             return tc'
    in
    let getSubContext3 tc =
          do (_,_,tc') <- getSubContext tc
             return tc'
    in
    do x' <- mergeThreadContext m getSubContext1 x
       y' <- mergeThreadContext m getSubContext2 y
       z' <- mergeThreadContext m getSubContext3 z
       return (x',y',z')

-- If the initial thread context is Left then splitThreadContext
-- creates only Left thread contexts. Similarly, mergeThreadContext
-- expects all the sub-contexts to match each other.
instance (ThreadContext r w x, ThreadContext r w y) =>
         ThreadContext r w (Either x y) where
  splitThreadContext m t e =
    case e of
      Left x
        -> do x' <- splitThreadContext m t x
              return (Left x')

      Right y
        -> do y' <- splitThreadContext m t y
              return (Right y')

  mergeThreadContext m getSubContext e =
    let getSubContextL tc =
          do Left tc' <- getSubContext tc
             return tc'
    in
    let getSubContextR tc =
          do Right tc' <- getSubContext tc
             return tc'
    in
    case e of
      Left tc
        -> do tc' <- mergeThreadContext m getSubContextL tc
              return (Left tc')

      Right tc
        -> do tc' <- mergeThreadContext m getSubContextR tc
              return (Right tc')

{-|

Every instrument must define an instance of this class for each of its
passes. For example, the
'Control.Monad.MultiPass.Instrument.Counter.Counter' instrument
defines the following instances:

> instance Instrument tc () () () (Counter i r w Off Off tc)
>
> instance Num i =>
>          Instrument tc (CounterTC1 i r) () (Counter i r w On Off tc)
>
> instance Num i =>
>          Instrument tc (CounterTC2 i r) () (Counter i r w On On tc)

The functional dependency from @instr@ to @tc@ and @gc@ enables the
'run' function to automatically deduce the type of the thread context
and global context for each pass.
-}
class Instrument rootTC tc gc instr | instr -> tc gc where
  createInstrument
    :: ST2ToMP rootTC
    -> UpdateThreadContext rootTC tc
    -> gc                          -- ^ Global context
    -> WrapInstrument instr        -- ^ Instrument

-- | This abstract datatype is used as the result type of
-- createInstrument. Instrument authors can create it using the
-- 'wrapInstrument' function, but cannot unwrap it. This ensures that
-- instruments can only be constructed by the "Control.Monad.MultiPass"
-- library.
newtype WrapInstrument instr
  = WrapInstrument instr
    deriving Functor

instance Monad WrapInstrument where
  return x = WrapInstrument x
  WrapInstrument x >>= f = f x

-- | Create an object of type 'WrapInstrument'. It is needed when
-- defining a new instance of the 'Instrument' class.
wrapInstrument :: instr -> WrapInstrument instr
wrapInstrument = WrapInstrument

-- | The type of the first argument of 'createInstrument'. It enables
-- instruments to run 'ST2' in the 'MultiPassBase' monad. (Clearly the
-- @st2ToMP@ argument needs to be used with care.)
type ST2ToMP tc
  = forall r w a. ST2 r w a -> MultiPassBase r w tc a

-- | The type of the first argument of 'createInstrument'. It used to
-- read and write the thread context.
type UpdateThreadContext tc tc'
  = forall r w. (tc' -> tc') -> MultiPassBase r w tc tc'

updateCtxArgL
  :: UpdateThreadContext rootTC (ArgCons tc tcs)
  -> UpdateThreadContext rootTC tc
updateCtxArgL updateCtx h =
  do ArgCons x _ <- updateCtx (mapArgCons h id)
     return x

updateCtxArgR
  :: UpdateThreadContext rootTC (ArgCons tc tcs)
  -> UpdateThreadContext rootTC tcs
updateCtxArgR updateCtx h =
  do ArgCons _ y <- updateCtx (mapArgCons id h)
     return y

class ApplyArg r w param instr f oldTC oldGC tc gc rootTC f'
             | f -> f' tc gc where
  applyArg
    :: PassNumber
    -> StepDirection
    -> param
    -> (instr -> f)
    -> UpdateThreadContext rootTC tc
    -> oldTC
    -> oldGC
    -> ST2 r w (f', tc, gc)

instance ( ApplyArgs r w f oldTCs oldGCs tcs gcs rootTC f'
         , NextThreadContext r w oldTC oldGC tc
         , NextGlobalContext r w oldTC oldGC gc
         , Instrument rootTC tc gc instr
         ) =>
         ApplyArg r w param instr f
                  (ArgCons oldTC oldTCs) (ArgCons oldGC oldGCs)
                  (ArgCons tc tcs) (ArgCons gc gcs)
                  rootTC f' where
  applyArg n d _ f updateCtx
           (ArgCons oldTC oldTCs) (ArgCons oldGC oldGCs) =
    do gc <- nextGlobalContext n d oldTC oldGC
       tc <- nextThreadContext n d oldTC oldGC
       let st2ToMP m = MultiPassBase $ lift m
       let WrapInstrument instr =
             createInstrument st2ToMP (updateCtxArgL updateCtx) gc
       (f', tcs, gcs) <-
         applyArgs n d (f instr) (updateCtxArgR updateCtx) oldTCs oldGCs
       return (f', ArgCons tc tcs, ArgCons gc gcs)

class ApplyArgs r w f oldTC oldGC tc gc rootTC f' | f -> f' tc gc where
  applyArgs
    :: PassNumber
    -> StepDirection
    -> f
    -> UpdateThreadContext rootTC tc
    -> oldTC
    -> oldGC
    -> ST2 r w (f', tc, gc)

instance ApplyArg r w () instr f oldTC oldGC tc gc rootTC f' =>
         ApplyArgs r w (instr -> f) oldTC oldGC tc gc rootTC f' where
  applyArgs n d f updateCtx oldTC oldGC =
    applyArg n d () f updateCtx oldTC oldGC

instance ApplyArg r w param instr f oldTC oldGC tc gc rootTC f' =>
         ApplyArgs r w (Param param (instr -> f)) oldTC oldGC
                   tc gc rootTC f' where
  applyArgs n d (Param param f) updateCtx oldTC oldGC =
    applyArg n d param f updateCtx oldTC oldGC

instance ApplyArgs r w (MultiPassMain r w rootTC a)
                   ArgNil ArgNil ArgNil ArgNil
                   rootTC (MultiPassMain r w rootTC a) where
  applyArgs _ _ f _ ArgNil ArgNil =
    return (f, ArgNil, ArgNil)

class InitCtx ctx where
  initCtx :: ctx

instance InitCtx () where
  initCtx = ()

instance InitCtx ArgNil where
  initCtx = ArgNil

instance (InitCtx a , InitCtx b) =>
         InitCtx (ArgCons a b) where
  initCtx = ArgCons initCtx initCtx

-- | This class is used to create the next thread context when the
-- multi-pass algorithm proceeds to the next pass or back-tracks to
-- the previous pass.
class NextThreadContext r w tc gc tc' where
  nextThreadContext
    :: PassNumber
    -> StepDirection  -- Stepping forwards or backwards?
    -> tc             -- Old thread context
    -> gc             -- Old global context
    -> ST2 r w tc'    -- New thread context

instance NextThreadContext r w tc gc () where
  nextThreadContext _ _ _ _ = return ()

instance ( NextThreadContext r w x gc x'
         , NextThreadContext r w y gc y'
         ) =>
         NextThreadContext r w (x,y) gc (x',y') where
  nextThreadContext n d (x,y) gc =
    do x' <- nextThreadContext n d x gc
       y' <- nextThreadContext n d y gc
       return (x',y')

instance ( NextThreadContext r w () gc x
         , NextThreadContext r w () gc y
         ) =>
         NextThreadContext r w () gc (x,y) where
  nextThreadContext n d () gc =
    do x <- nextThreadContext n d () gc
       y <- nextThreadContext n d () gc
       return (x,y)

instance ( NextThreadContext r w x gc x'
         , NextThreadContext r w y gc y'
         , NextThreadContext r w z gc z'
         ) =>
         NextThreadContext r w (x,y,z) gc (x',y',z') where
  nextThreadContext n d (x,y,z) gc =
    do x' <- nextThreadContext n d x gc
       y' <- nextThreadContext n d y gc
       z' <- nextThreadContext n d z gc
       return (x',y',z')

instance ( NextThreadContext r w () gc x
         , NextThreadContext r w () gc y
         , NextThreadContext r w () gc z
         ) =>
         NextThreadContext r w () gc (x,y,z) where
  nextThreadContext n d () gc =
    do x <- nextThreadContext n d () gc
       y <- nextThreadContext n d () gc
       z <- nextThreadContext n d () gc
       return (x,y,z)

instance ( NextThreadContext r w x gc x'
         , NextThreadContext r w y gc y'
         ) =>
         NextThreadContext r w (Either x y) gc (Either x' y') where
  nextThreadContext n d e gc =
    case e of
      Left x
        -> do x' <- nextThreadContext n d x gc
              return (Left x')

      Right y
        -> do y' <- nextThreadContext n d y gc
              return (Right y')


-- | This class is used to create the next global context when the
-- multi-pass algorithm proceeds to the next pass or back-tracks to
-- the previous pass.
class NextGlobalContext r w tc gc gc' where
  nextGlobalContext
    :: PassNumber
    -> StepDirection  -- Stepping forwards or backwards?
    -> tc             -- Old thread context
    -> gc             -- Old global context
    -> ST2 r w gc'    -- New global context

instance NextGlobalContext r w tc gc () where
  nextGlobalContext _ _ _ _ = return ()

instance ( NextGlobalContext r w tc x x'
         , NextGlobalContext r w tc y y'
         ) =>
         NextGlobalContext r w tc (x,y) (x',y') where
  nextGlobalContext n d tc (x,y) =
    do x' <- nextGlobalContext n d tc x
       y' <- nextGlobalContext n d tc y
       return (x',y')

instance ( NextGlobalContext r w tc x x'
         , NextGlobalContext r w tc y y'
         , NextGlobalContext r w tc z z'
         ) =>
         NextGlobalContext r w tc (x,y,z) (x',y',z') where
  nextGlobalContext n d tc (x,y,z) =
    do x' <- nextGlobalContext n d tc x
       y' <- nextGlobalContext n d tc y
       z' <- nextGlobalContext n d tc z
       return (x',y',z')

instance ( NextGlobalContext r w tc x x'
         , NextGlobalContext r w tc y y'
         ) =>
         NextGlobalContext r w tc (Either x y) (Either x' y') where
  nextGlobalContext n d tc e =
    case e of
      Left x
        -> do x' <- nextGlobalContext n d tc x
              return (Left x')

      Right y
        -> do y' <- nextGlobalContext n d tc y
              return (Right y')

class InstantiatePasses a b | a -> b where
  instantiatePasses :: a -> PassZ b

instance InstantiatePasses (PassZ a) a where
  instantiatePasses (PassZ x) = PassZ x

instance InstantiatePasses (cont (m Off)) b =>
         InstantiatePasses (PassS cont m) b where
  instantiatePasses (PassS f) =
    instantiatePasses (f :: cont (m Off))

-- | Every instrument must define an instance of this class for each
-- of its passes. It is used to tell the evaluator whether it needs to
-- back-track. Instruments which do not back-track should use the
-- default implementation of backtrack which returns 'Nothing' (which
-- means that no back-tracking is necessary.) If more than one
-- instrument requests that the evaluator back-tracks then the
-- evaluator will back-track to the earliest of the requested passes.
class BackTrack r w tc gc where
  backtrack :: tc -> gc -> ST2 r w (Maybe PassNumber)
  backtrack _ _ = return Nothing

-- If the global context is the unit type then the instrument does not
-- back-track.
instance BackTrack r w tc ()

instance BackTrack r w ArgNil ArgNil

instance (BackTrack r w tc gc, BackTrack r w tcs gcs) =>
         BackTrack r w (ArgCons tc tcs) (ArgCons gc gcs) where
  backtrack (ArgCons tc tcs) (ArgCons gc gcs) =
    do mx <- backtrack tc gc
       my <- backtrack tcs gcs
       case (mx,my) of
         (Nothing, Nothing) -> return Nothing
         (Nothing, Just y)  -> return (Just y)
         (Just x, Nothing)  -> return (Just x)
         (Just x, Just y)   -> return (Just (minPassNumber x y))

class RunPasses r w f tc gc p out where
  runPasses
    :: PassNumber -> f -> p out -> tc -> gc
    -> ST2 r w
        (Either
           ( PassNumber
           , MultiPassMain r w tc (p out)
           , tc
           , gc
           )
           out)

instance RunPasses r w (PassZ f) tc gc On out where
  runPasses _ _ (On out) _ _ =
    return (Right out)

instance ( InstantiatePasses (cont (f Off)) fPrev
         , MultiPassAlgorithm (fPrev tc0) gPrev
         , InstantiatePasses (cont (f On)) fCurr
         , MultiPassAlgorithm (fCurr tc1) gCurr
         , ApplyArgs r w gCurr tc0 gc0 tc1 gc1 tc1
                     (MultiPassMain r w tc1 (p out))
         , ApplyArgs r w gCurr tc1 gc1 tc1 gc1 tc1
                     (MultiPassMain r w tc1 (p out))
         , ApplyArgs r w gPrev tc1 gc1 tc0 gc0 tc0
                     (MultiPassMain r w tc0 (q out))
         , ThreadContext r w tc1
         , BackTrack r w tc1 gc1
         , RunPasses r w (cont (f On)) tc1 gc1 p out
         ) =>
         RunPasses r w (PassS cont f) tc0 gc0 q out where
  runPasses n fBox _ =
    let PassS (fPrev :: cont (f Off)) = fBox in
    let PassS (fCurr :: cont (f On)) = fBox in
    let -- Loop header. Run the current pass and check whether
        -- back-tracking is necessary.
        loop g tc gc =
          do (result, tc') <- runMultiPassMain g tc
             mb <- backtrack tc' gc
             case mb of
               Nothing
                 -> -- Current pass is successful, so continue to
                    -- the next pass.
                    let n' = incrPassNumber n in
                    do e <- runPasses n' fCurr result tc' gc
                       case e of
                         Left info -> rewind info
                         Right out -> return (Right out)

               Just m
                 -> stepReset m tc' gc

        -- Call either loop or stepBackward, depending on the
        -- PassNumber.
        rewind (m,g,tc,gc) =
          assert (unwrapPassNumber m <= unwrapPassNumber n) $
          if unwrapPassNumber m == unwrapPassNumber n
             then loop g tc gc
             else stepBackward m tc gc

        -- Reset the contexts and rewind to the requested pass number.
        stepReset m tc gc =
          let PassZ f' = instantiatePasses fCurr in
          let g = unwrapMultiPassAlgorithm (f' :: fCurr tc1) in
          do (g', tc', gc') <-
               applyArgs n StepReset g updateThreadContextTop tc gc
             rewind (m,g',tc',gc')

        -- Return to the previous pass.
        stepBackward m tc gc =
          let PassZ f' = instantiatePasses fPrev in
          let g = unwrapMultiPassAlgorithm (f' :: fPrev tc0) in
          do (g', tc', gc') <-
               applyArgs n StepBackward g updateThreadContextTop tc gc
             return (Left (m,g',tc',gc'))
    in
    let loopStart tc gc =
          let PassZ f' = instantiatePasses fCurr in
          let g = unwrapMultiPassAlgorithm (f' :: fCurr tc1) in
          do (g', tc', gc') <-
               applyArgs n StepForward g updateThreadContextTop tc gc
             loop g' tc' gc'
    in
    loopStart

updateThreadContextTop :: UpdateThreadContext tc tc
updateThreadContextTop f =
  MultiPassBase $
  do tc <- get
     put (f tc)
     return tc

-- | This function is used to run a multi-pass algorithm. Its
-- complicated type is mostly an artifact of the internal
-- implementation, which uses type classes to generate the code for
-- each pass of the algorithm. Therefore, the recommended way to learn
-- how to use 'run' is to look at some of the examples in the
-- @Example@ sub-directory.
run
  :: forall r w f f' g tc gc out.
     ( InstantiatePasses f f'
     , MultiPassAlgorithm (f' tc) g
     , ApplyArgs r w g tc gc tc gc tc
                 (MultiPassMain r w tc (Off out))
     , InitCtx tc
     , InitCtx gc
     , RunPasses r w f tc gc Off out
     )
  => f
  -> ST2 r w out
run f =
  let tc = initCtx :: tc in
  let gc = initCtx :: gc in
  do e <- runPasses (PassNumber 0) f Off tc gc
     case e of
       Left _
         -> -- This is impossible, because it would imply that the
            -- back-tracking mechanism is attempting to back-track to
            -- a negative PassNumber.
            assert False $ error "run"

       Right result
         -> return result

-- | 'NumThreads' is used to specify the number of threads in
-- 'parallelMP' and 'parallelMP_'.
newtype NumThreads
  = NumThreads Int

-- | Use @m@ threads to run @n@ instances of the function @f@. The
-- results are returned in an array of length @n@.
parallelMP
  :: (Ix i, Num i)
  => NumThreads                 -- ^ Number of threads to spawn
  -> (i,i)                      -- ^ Element range
  -> (i -> MultiPass r w tc a)
  -> MultiPass r w tc (ST2Array r w i a)
parallelMP (NumThreads m) bnds f =
  let n = rangeSize bnds in
  assert (m > 0) $
  if m == 1 || n <= 1
     then -- Do not use parallelism.
          do xs <- MultiPass $ MultiPassBase $ lift $ newST2Array_ bnds
             sequence_
               [ do x <- f i
                    MultiPass $ MultiPassBase $ lift $
                      writeST2Array xs i x
               | i <- range bnds
               ]
             return xs
     else assert (m > 1) $
          assert (n > 1) $
          parallelHelper (min m n) n bnds f

parallelHelper
  :: (Ix i, Num i)
  => Int                        -- Number of threads
  -> Int                        -- Number of elements
  -> (i,i)                      -- Element range
  -> (i -> MultiPass r w tc a)
  -> MultiPass r w tc (ST2Array r w i a)
parallelHelper m n bnds f =
  MultiPass $ MultiPassBase $
  do tc <- get
     -- Split the thread state into m sub-states.
     let tBnds = (0,m-1)
     tcs <- lift $ newST2Array_ tBnds
     lift $ sequence_
       [ do tci <- splitThreadContext m t tc
            writeST2Array tcs t tci
       | t <- range tBnds
       ]
     -- Create an array for the results.
     xs <- lift $ newST2Array_ bnds
     let base = fst bnds
     let blockSize = (n+m-1) `div` m
     lift $ parallelST2 tBnds $ \i ->
       do tci <- readST2Array tcs i
          let start = i * blockSize
          let end = min n (start + blockSize)
          tci' <-
            flip execStateT tci $
            sequence_
              [ let j' = base + fromIntegral j in
                do x <- unwrapMultiPassBase $ unwrapMultiPass $ f j'
                   lift $ writeST2Array xs j' x
              | j <- [start .. end-1]
              ]
          writeST2Array tcs i tci'
     -- Create the new merged state.
     tc' <- lift $ mergeThreadContext m (readST2Array tcs) tc
     put tc'
     return xs

-- | Modified version of 'parallelMP' which discards the result of the
-- function, rather than writing it to an array.
parallelMP_
  :: (Ix i, Num i)
  => NumThreads                 -- ^ Number of threads to spawn
  -> (i,i)                      -- ^ Element range
  -> (i -> MultiPass r w tc a)
  -> MultiPass r w tc ()
parallelMP_ (NumThreads m) bnds f =
  let n = rangeSize bnds in
  assert (m > 0) $
  if m == 1 || n <= 1
     then -- Do not use parallelism.
          sequence_ [ f i | i <- range bnds ]
     else assert (m > 1) $
          assert (n > 1) $
          parallelHelper_ (min m n) n bnds f

parallelHelper_
  :: (Ix i, Num i)
  => Int                        -- Number of threads
  -> Int                        -- Number of elements
  -> (i,i)                      -- Element range
  -> (i -> MultiPass r w tc a)
  -> MultiPass r w tc ()
parallelHelper_ m n bnds f =
  MultiPass $ MultiPassBase $
  do tc <- get
     -- Split the thread state into m sub-states.
     let tBnds = (0,m-1)
     tcs <- lift $ newST2Array_ tBnds
     lift $ sequence_
       [ do tci <- splitThreadContext m t tc
            writeST2Array tcs t tci
       | t <- range tBnds
       ]
     let base = fst bnds
     let blockSize = (n+m-1) `div` m
     lift $ parallelST2 tBnds $ \i ->
       do tci <- readST2Array tcs i
          let start = i * blockSize
          let end = min n (start + blockSize)
          tci' <-
            flip execStateT tci $
            sequence_
              [ let j' = base + fromIntegral j in
                unwrapMultiPassBase $ unwrapMultiPass $ f j'
              | j <- [start .. end-1]
              ]
          writeST2Array tcs i tci'
     -- Create the new merged state.
     tc' <- lift $ mergeThreadContext m (readST2Array tcs) tc
     put tc'

-- | Read-only ST2 computations are allowed to be executed in the
-- MultiPass monad.
readOnlyST2ToMP :: (forall w. ST2 r w a) -> MultiPass r w' tc a
readOnlyST2ToMP m =
  MultiPass $ MultiPassBase $
  lift m