module Control.Monad.MultiPass.Instrument.EmitST2ArrayFxp
( EmitST2ArrayFxp
, setBaseIndex, emit, emitList, getIndex, getResult
)
where
import Control.Exception ( assert )
import Control.Monad.ST2
import Control.Monad.Writer.Strict
import Control.Monad.MultiPass
import Control.Monad.MultiPass.Utils.UpdateCtx
import Control.Monad.MultiPass.ThreadContext.CounterTC
import Control.Monad.MultiPass.ThreadContext.MonoidTC
import Data.Ix
-- | Instrument that emits elements of type @a@ into an 'ST2Array'
-- indexed by @i@.
--
-- The phantom-like parameters @p1@, @p2@ and @p3@ are 'On' or 'Off' and
-- select which values exist in the current pass: @p1@ wraps the list
-- length lower bounds passed to 'emitList', @p2@ wraps array indices,
-- and @p3@ wraps the emitted elements and the resulting array.
data EmitST2ArrayFxp i a r w p1 p2 p3 tc
  = EmitST2ArrayFxp
      { setBaseInternal
          :: !(p2 i -> MultiPassPrologue r w tc ())
          -- ^ Implementation of 'setBaseIndex'.
      , emitInternal
          :: !(p3 a -> MultiPass r w tc ())
          -- ^ Implementation of 'emit'.
      , emitListInternal
          :: !(p1 Int -> p3 [a] -> MultiPass r w tc ())
          -- ^ Implementation of 'emitList'.
      , getIndexInternal
          :: !(forall w'. MultiPass r w' tc (p2 i))
          -- ^ Implementation of 'getIndex'; polymorphic in the
          -- writer type @w'@.
      , getResultInternal
          :: !(MultiPassEpilogue r w tc (p3 (ST2Array r w i a)))
          -- ^ Implementation of 'getResult'.
      }
-- | Set the base index, i.e. the value added to the element counter to
-- form array indices.  Runs in the prologue.
setBaseIndex
  :: (Ix i, Num i, Monad p1, Monad p2, Monad p3)
  => EmitST2ArrayFxp i a r w p1 p2 p3 tc
  -> p2 i
  -> MultiPassPrologue r w tc ()
setBaseIndex instr base = setBaseInternal instr base
-- | Emit a single element.
emit
  :: (Ix i, Num i, Monad p1, Monad p2, Monad p3)
  => EmitST2ArrayFxp i a r w p1 p2 p3 tc
  -> p3 a
  -> MultiPass r w tc ()
emit EmitST2ArrayFxp { emitInternal = emitter } = emitter
-- | Emit a list of elements.
--
-- The first argument is a lower bound on the list's length, which is
-- available from the first pass onwards; the instances assert that the
-- actual length is at least this bound.
emitList
  :: (Ix i, Num i, Monad p1, Monad p2, Monad p3)
  => EmitST2ArrayFxp i a r w p1 p2 p3 tc
  -> p1 Int
  -> p3 [a]
  -> MultiPass r w tc ()
emitList EmitST2ArrayFxp { emitListInternal = emitter } = emitter
-- | Read the current index.  In the passes where indices exist ('On'),
-- this is the base index plus the current value of the index counter.
--
-- The definition is deliberately eta-expanded: 'getIndexInternal' has a
-- rank-2 field type (@forall w'. ...@).  Since GHC 9.0's simplified
-- subsumption, the point-free @getIndex = getIndexInternal@ is rejected
-- unless DeepSubsumption is enabled, because the inner @forall@ is not
-- instantiated under the arrow.
getIndex
  :: (Ix i, Num i, Monad p1, Monad p2, Monad p3)
  => EmitST2ArrayFxp i a r w p1 p2 p3 tc
  -> MultiPass r w' tc (p2 i)
getIndex instr = getIndexInternal instr
-- | Retrieve the array of emitted elements.  Runs in the epilogue.
getResult
  :: (Ix i, Num i, Monad p1, Monad p2, Monad p3)
  => EmitST2ArrayFxp i a r w p1 p2 p3 tc
  -> MultiPassEpilogue r w tc (p3 (ST2Array r w i a))
getResult instr = getResultInternal instr
-- | Instance for passes in which nothing is available: every operation
-- is a no-op that only matches the 'Off' placeholders.
instance Instrument tc () ()
         (EmitST2ArrayFxp i a r w Off Off Off tc) where
  createInstrument _ _ () =
    wrapInstrument EmitST2ArrayFxp
      { setBaseInternal   = \Off -> return ()
      , emitInternal      = \Off -> return ()
      , emitListInternal  = \Off Off -> return ()
      , getIndexInternal  = return Off
      , getResultInternal = return Off
      }
-- | Thread context of the first pass: a counter advanced by emitted
-- elements, paired with a counter of 'emitList' calls.
type TC1 i r =
  ( CounterTC1 i r          -- element/index counter
  , CounterTC1 ListIndex r  -- 'emitList' call counter
  )
-- | Index into the array of recorded list lengths; one slot per
-- 'emitList' call.
newtype ListIndex = ListIndex Int
  deriving (Eq, Ord, Ix)
-- | Arithmetic on 'ListIndex' lifts the corresponding 'Int' operation
-- through the newtype.  Subtraction is needed, for example, to compute
-- the inclusive upper bound @count - 1@ of the list-length array.
instance Num ListIndex where
  (ListIndex x) + (ListIndex y) = ListIndex (x + y)
  -- Previously garbled into an invalid equation with no operator.
  (ListIndex x) - (ListIndex y) = ListIndex (x - y)
  (ListIndex x) * (ListIndex y) = ListIndex (x * y)
  negate (ListIndex x) = ListIndex (negate x)
  abs (ListIndex x) = ListIndex (abs x)
  signum (ListIndex x) = ListIndex (signum x)
  fromInteger x = ListIndex (fromInteger x)
-- | A 'ListIndex' is shown as its bare 'Int', without the constructor.
instance Show ListIndex where
  show (ListIndex n) = shows n ""
-- | First pass: only the list-length lower bounds exist.
--
-- 'emit' bumps the element counter by one; 'emitList' bumps it by the
-- lower bound and bumps the 'emitList' call counter by one.
instance Num i =>
         Instrument tc (TC1 i r) ()
           (EmitST2ArrayFxp i a r w On Off Off tc) where
  createInstrument _ updateCtx () =
    wrapInstrument EmitST2ArrayFxp
      { setBaseInternal = \Off -> return ()
      , emitInternal = \Off ->
          mkMultiPass $ void $ updateCtxFst updateCtx incrCounterTC1
      , emitListInternal = \(On lowerBound) Off ->
          mkMultiPass $ do
            void $ updateCtxFst updateCtx $
              addkCounterTC1 (fromIntegral lowerBound)
            void $ updateCtxSnd updateCtx incrCounterTC1
      , getIndexInternal = return Off
      , getResultInternal = return Off
      }
-- | Thread context of the passes in which indices exist: the element
-- index counter paired with the 'emitList' call counter.
type TC2 i r =
  ( CounterTC2 i r          -- element/index counter
  , CounterTC2 ListIndex r  -- 'emitList' call counter
  )
-- | Global context of the passes in which indices (but not elements)
-- exist.
data GC2 r w i
  = GC2
      { gc2_base :: !(ST2Ref r w i)
        -- ^ Base index, written by 'setBaseIndex'.
      , gc2_initialised :: !Bool
        -- ^ 'False' on the first visit, when each 'emitList' records its
        -- lower bound as the list length; 'True' afterwards, when the
        -- recorded lengths are read back.
      , gc2_length_array :: !(ST2Array r w ListIndex Int)
        -- ^ Recorded length of each 'emitList' call, by call number.
      , gc2_passnumber :: !PassNumber
        -- ^ Pass number this context was created for; later used as a
        -- backtracking target.
      }
-- | Passes in which lower bounds and indices exist but elements do not.
--
-- On the first visit ('gc2_initialised' is 'False') every 'emitList'
-- stores its lower bound in the length array and advances the index
-- counter by it.  On later visits the stored length is read back instead,
-- and it is asserted to be at least the lower bound.
instance (Ix i, Num i) =>
         Instrument tc (TC2 i r) (GC2 r w i)
           (EmitST2ArrayFxp i a r w On On Off tc) where
  createInstrument st2ToMP updateCtx gc =
    wrapInstrument EmitST2ArrayFxp
      { setBaseInternal = \(On base) ->
          mkMultiPassPrologue (st2ToMP (writeST2Ref baseRef base))
      , emitInternal = \Off ->
          void (mkMultiPass (updateCtxFst updateCtx incrCounterTC2))
      , emitListInternal =
          case gc2_initialised gc of
            True ->
              -- Later visits: reuse the length recorded earlier.
              \(On lowerBound) Off ->
                mkMultiPass $ do
                  listCount <- updateCtxSnd updateCtx incrCounterTC2
                  len <- st2ToMP $
                    readST2Array lenArray (counterVal2 listCount)
                  assert (lowerBound <= len) $ return ()
                  void $ updateCtxFst updateCtx $
                    addkCounterTC2 (fromIntegral len)
            False ->
              -- First visit: record the lower bound as the length.
              \(On lowerBound) Off ->
                mkMultiPass $ do
                  listCount <- updateCtxSnd updateCtx incrCounterTC2
                  st2ToMP $
                    writeST2Array lenArray (counterVal2 listCount) lowerBound
                  void $ updateCtxFst updateCtx $
                    addkCounterTC2 (fromIntegral lowerBound)
      , getIndexInternal =
          mkMultiPass $ do
            base <- st2ToMP $ readST2Ref baseRef
            counter <- updateCtxFst updateCtx id
            return (On (base + counterVal2 counter))
      , getResultInternal = return Off
      }
    where
      baseRef = gc2_base gc
      lenArray = gc2_length_array gc
-- | Thread context of the passes in which elements exist.
data TC3 i r
  = TC3
      { indexCounter :: CounterTC2 i r
        -- ^ Element counter advanced using the list lengths recorded in
        -- earlier passes; added to the base to form array indices.
      , listCounter :: CounterTC2 ListIndex r
        -- ^ Counter of 'emitList' calls.
      , newIndexCounter :: CounterTC1 i r
        -- ^ Element counter advanced using the list lengths actually
        -- observed in this pass.
      , indexChanged :: MonoidTC Any
        -- ^ Set when an observed list length differs from the recorded one.
      }
-- | Focus an update of the whole 'TC3' onto its 'indexCounter' field,
-- returning that field from the context that @updateCtx@ yields.
updateIndexCounter
  :: UpdateThreadContext rootTC (TC3 i r)
  -> UpdateThreadContext rootTC (CounterTC2 i r)
updateIndexCounter updateCtx f =
  fmap indexCounter $
    updateCtx (\ctx -> ctx { indexCounter = f (indexCounter ctx) })
-- | Focus an update of the whole 'TC3' onto its 'listCounter' field,
-- returning that field from the context that @updateCtx@ yields.
updateListCounter
  :: UpdateThreadContext rootTC (TC3 i r)
  -> UpdateThreadContext rootTC (CounterTC2 ListIndex r)
updateListCounter updateCtx f =
  fmap listCounter $
    updateCtx (\ctx -> ctx { listCounter = f (listCounter ctx) })
-- | Focus an update of the whole 'TC3' onto its 'newIndexCounter'
-- field, returning that field from the context that @updateCtx@ yields.
updateNewIndexCounter
  :: UpdateThreadContext rootTC (TC3 i r)
  -> UpdateThreadContext rootTC (CounterTC1 i r)
updateNewIndexCounter updateCtx f =
  fmap newIndexCounter $
    updateCtx (\ctx -> ctx { newIndexCounter = f (newIndexCounter ctx) })
-- | Focus an update of the whole 'TC3' onto its 'indexChanged' field,
-- returning that field from the context that @updateCtx@ yields.
updateIndexChanged
  :: UpdateThreadContext rootTC (TC3 i r)
  -> UpdateThreadContext rootTC (MonoidTC Any)
updateIndexChanged updateCtx f =
  fmap indexChanged $
    updateCtx (\ctx -> ctx { indexChanged = f (indexChanged ctx) })
-- | 'TC3' is split and merged field by field, delegating to the
-- 'ThreadContext' instance of each component.
instance Num i => ThreadContext r w (TC3 i r) where
  splitThreadContext m t (TC3 ic lc nic chg) =
    do ic' <- splitThreadContext m t ic
       lc' <- splitThreadContext m t lc
       nic' <- splitThreadContext m t nic
       chg' <- splitThreadContext m t chg
       return TC3
         { indexCounter = ic'
         , listCounter = lc'
         , newIndexCounter = nic'
         , indexChanged = chg'
         }

  mergeThreadContext m getSubContext (TC3 ic lc nic chg) =
    let -- Apply getSubContext, then select one field of its result.
        project sel sub =
          do subTC <- getSubContext sub
             return (sel subTC)
    in
    do ic' <- mergeThreadContext m (project indexCounter) ic
       lc' <- mergeThreadContext m (project listCounter) lc
       nic' <- mergeThreadContext m (project newIndexCounter) nic
       chg' <- mergeThreadContext m (project indexChanged) chg
       return (TC3 ic' lc' nic' chg')
-- | Global context of the passes in which elements exist.
data GC3 r w i a
  = GC3
      { gc3_base :: !(ST2Ref r w i)
        -- ^ Base index, written by 'setBaseIndex'.
      , gc3_length_array :: !(ST2Array r w ListIndex Int)
        -- ^ Recorded length of each 'emitList' call, by call number.
      , gc3_output_array :: !(ST2Array r w i a)
        -- ^ Output array: a placeholder with bounds (0, 0) until
        -- 'gc3_ready' is set, then reallocated to hold every element.
      , gc3_ready :: !Bool
        -- ^ 'True' once the list lengths are stable and the output array
        -- has been allocated.
      , gc3_passnumber2 :: !PassNumber
        -- ^ Pass to backtrack to when a list length changed.
      , gc3_passnumber3 :: !PassNumber
        -- ^ Pass to backtrack to once lengths are stable but the output
        -- array is not yet allocated.
      }
-- | Passes in which indices and elements both exist.
--
-- While 'gc3_ready' is 'False', this instance only advances the counters
-- and records the observed list lengths, flagging any change.  Once
-- 'gc3_ready' is 'True', the output array has been allocated and each
-- emitted element is written into it.
instance (Ix i, Num i) =>
         Instrument tc (TC3 i r) (GC3 r w i a)
           (EmitST2ArrayFxp i a r w On On On tc) where
  createInstrument st2ToMP updateCtx gc =
    let -- Account for one emitted element and return the array index
        -- computed for it from the base and the index counter.
        writeHelper =
          do void $ updateNewIndexCounter updateCtx incrCounterTC1
             base <- st2ToMP $ readST2Ref (gc3_base gc)
             counter <- updateIndexCounter updateCtx incrCounterTC2
             return $ base + counterVal2 counter

        -- Account for an emitted list: store its actual length, flag a
        -- change if it differs from the previously stored length, and
        -- return the array index computed for its first element.
        writeListHelper lowerBound ys =
          let newLen = length ys in
          do listCount <- updateListCounter updateCtx incrCounterTC2
             let i = counterVal2 listCount
             oldLen <- st2ToMP $ readST2Array (gc3_length_array gc) i
             st2ToMP $ writeST2Array (gc3_length_array gc) i newLen
             assert (newLen >= lowerBound) $ return ()
             assert (newLen >= oldLen) $ return ()
             let changed = MonoidTC $ Any $ newLen /= oldLen
             void $ updateIndexChanged updateCtx $ mappend changed
             void $ updateNewIndexCounter updateCtx
               (addkCounterTC1 (fromIntegral newLen))
             base <- st2ToMP $ readST2Ref (gc3_base gc)
             -- The index counter advances by the previously stored length
             -- so that indices agree with the earlier passes.
             indexCount <- updateIndexCounter updateCtx
               (addkCounterTC2 (fromIntegral oldLen))
             return (base + counterVal2 indexCount)

        setBaseHelper (On base) =
          mkMultiPassPrologue $
            st2ToMP $ writeST2Ref (gc3_base gc) base

        getIndexHelper =
          mkMultiPass $
            do base <- st2ToMP $ readST2Ref (gc3_base gc)
               indexCount <- updateIndexCounter updateCtx id
               return (On (base + counterVal2 indexCount))

        xs = gc3_output_array gc

        getResultHelper = return $ On $ xs
    in
    if gc3_ready gc
      then
        wrapInstrument $
          EmitST2ArrayFxp
            { setBaseInternal = setBaseHelper
            , emitInternal = \(On x) ->
                mkMultiPass $
                  do k <- writeHelper
                     st2ToMP $ writeST2Array xs k x
            , emitListInternal = \(On lowerBound) (On ys) ->
                mkMultiPass $
                  do j <- writeListHelper lowerBound ys
                     -- Element number k of the list is stored at j + k.
                     -- (The range was previously garbled as [0 .. n1].)
                     sequence_
                       [ st2ToMP $ writeST2Array xs (j + fromIntegral k) y
                       | (k, y) <- zip [(0 :: Int) ..] ys
                       ]
            , getIndexInternal = getIndexHelper
            , getResultInternal = getResultHelper
            }
      else
        wrapInstrument $
          EmitST2ArrayFxp
            { setBaseInternal = setBaseHelper
            , emitInternal = \(On _) ->
                void $ mkMultiPass $ writeHelper
            , emitListInternal = \(On lowerBound) (On ys) ->
                void $ mkMultiPass $ writeListHelper lowerBound ys
            , getIndexInternal = getIndexHelper
            , getResultInternal = getResultHelper
            }
-- | No methods are defined here, so the 'BackTrack' class defaults apply
-- to the 'TC2' passes.  NOTE(review): presumably the default requests no
-- backtracking; confirm against the class definition.
instance BackTrack r w (TC2 i r) (GC2 r w i)
-- | Decide where to go after a 'TC3' pass:
--
-- * a list length changed: back to the pass recorded in 'gc3_passnumber2';
-- * lengths stable but the output array is not allocated yet: back to the
--   pass recorded in 'gc3_passnumber3';
-- * lengths stable and the array is ready: no backtracking;
-- * a change after the array was allocated is asserted to be impossible.
instance BackTrack r w (TC3 i r) (GC3 r w i a) where
  backtrack tc gc
    | changed && ready = assert False $ return Nothing
    | changed = return $ Just $ gc3_passnumber2 gc
    | ready = return Nothing
    | otherwise = return $ Just $ gc3_passnumber3 gc
    where
      MonoidTC (Any changed) = indexChanged tc
      ready = gc3_ready gc
-- | Between two 'TC3' passes, the element count observed in the finished
-- pass ('newIndexCounter') seeds the next index counter, and the list
-- counter is reset.  The observed-count counter and the change flag
-- start fresh.
instance Num i =>
         NextThreadContext r w (TC3 i r) gc (TC3 i r) where
  nextThreadContext _ _ tc _ =
    do counter <- newCounterTC2 (newIndexCounter tc)
       return TC3
         { indexCounter = counter
         , listCounter = resetCounterTC2 (listCounter tc)
         , newIndexCounter = newCounterTC1
         , indexChanged = mempty
         }
-- | Build the first 'GC2' after the first pass.
--
-- The base index starts at 0.  The length array gets one slot per
-- 'emitList' call counted in the first pass.  It is not yet initialised,
-- so the next pass records lengths rather than reading them.
instance Num i =>
         NextGlobalContext r w (TC1 i r) () (GC2 r w i) where
  nextGlobalContext n _ (_, listCount) () =
    do base <- newST2Ref 0
       -- Bounds are inclusive (as with Data.Array's newArray_ -- assumed
       -- to hold for newST2Array_), hence the "- 1".  The upper bound was
       -- previously garbled as "counterVal1 listCount 1".
       xs <- newST2Array_ (0, counterVal1 listCount - 1)
       return $ GC2
         { gc2_base = base
         , gc2_initialised = False
         , gc2_length_array = xs
         , gc2_passnumber = n
         }
-- | Revisiting a 'GC2' pass: mark the lengths as recorded so that the
-- stored lengths are read back from now on.
instance NextGlobalContext r w (TC2 i r) (GC2 r w i) (GC2 r w i) where
  nextGlobalContext _ _ _ gc = return gc { gc2_initialised = True }
-- | Moving from a 'GC2' pass to the first 'GC3' pass.
--
-- The base and the recorded lengths carry over.  The output array is a
-- placeholder with bounds (0, 0) and is not ready.  Both pass numbers are
-- remembered as backtracking targets.
instance (Ix i, Num i) =>
         NextGlobalContext r w (TC2 i r) (GC2 r w i) (GC3 r w i a) where
  nextGlobalContext n _ _ gc =
    newST2Array_ (0, 0) >>= \placeholder ->
      return GC3
        { gc3_base = gc2_base gc
        , gc3_length_array = gc2_length_array gc
        , gc3_output_array = placeholder
        , gc3_ready = False
        , gc3_passnumber2 = gc2_passnumber gc
        , gc3_passnumber3 = n
        }
-- | Moving between 'GC3' passes.
--
-- Stepping forward or backward leaves the context unchanged.  On a reset
-- with stable list lengths and no output array yet, the output array is
-- allocated to cover every element and the context is marked ready.
instance (Ix i, Num i) =>
         NextGlobalContext r w (TC3 i r)
           (GC3 r w i a) (GC3 r w i a) where
  nextGlobalContext _ StepForward _ gc = return gc
  nextGlobalContext _ StepBackward _ gc = return gc
  nextGlobalContext _ StepReset tc gc =
    let MonoidTC (Any changed) = indexChanged tc in
    case (changed, gc3_ready gc) of
      -- Lengths stable: allocate indices base .. base + count - 1.  The
      -- upper bound was previously garbled as "n1".
      (False, False) ->
        do base <- readST2Ref (gc3_base gc)
           let n = base + counterVal2 (indexCounter tc)
           xs <- newST2Array_ (base, n - 1)
           return $ gc
             { gc3_output_array = xs
             , gc3_ready = True
             }
      (False, True) -> return gc
      (True, False) -> return gc
      -- A length change after allocation is not expected.
      (True, True) -> assert False $ return gc
-- | Backtracking from a 'GC3' pass to a 'GC2' pass.  The base, the
-- recorded lengths and the 'GC2' pass number are kept, and the lengths
-- are marked as already recorded.
instance NextGlobalContext r w (TC3 i r) (GC3 r w i a)
           (GC2 r w i) where
  nextGlobalContext _ _ _ gc =
    return (GC2 (gc3_base gc) True (gc3_length_array gc)
                (gc3_passnumber2 gc))
-- | Entering the 'TC3' passes: both 'TC2' counters are reset and carried
-- over, and the observed-count counter and the change flag start fresh.
instance Num i =>
         NextThreadContext r w (TC2 i r) gc (TC3 i r) where
  nextThreadContext _ _ (indexCount, listCount) _ =
    return (TC3 (resetCounterTC2 indexCount) (resetCounterTC2 listCount)
                newCounterTC1 mempty)
-- | Backtracking to a 'TC2' pass keeps only the index and list counters.
instance NextThreadContext r w (TC3 i r) gc (TC2 i r) where
  nextThreadContext _ _ tc _ = return (ic, lc)
    where
      ic = indexCounter tc
      lc = listCounter tc