{-# LANGUAGE FlexibleContexts #-}

-- | Conditional search combinator.
--
-- @if' cond s1 s2@ builds a 'Search' whose evaluator stores the value of
-- @cond@ in an @if_true@ flag when a tree state is initialised and from then
-- on dispatches every hook to the evaluator of @s1@ (flag set) or of @s2@
-- (flag clear).
module Control.Search.Combinator.If (if') where

import Control.Search.Language
import Control.Search.GeneratorInfo
import Control.Search.MemoReader
import Control.Search.Generator
import Control.Search.Stat

import Control.Monatron.Monatron hiding (Abort, L, state, cont)
import Control.Monatron.Zipper hiding (i,r)

-- | Struct describing the eval state of the left ("then") branch: a
-- @ref_count@ field followed by the eval-state fields of @lsuper@.
xs1 uid lsuper rsuper =
  Struct ("LeftEvalState" ++ show uid) $
    {- (Bool, "cont") : -}
    (Int, "ref_count") : [(ty, field) | (field,ty,_) <- evalState_ lsuper]

-- | Field initialisers matching 'xs1'.  They are taken from @lsuper@ so that
-- the initialised fields are exactly the ones declared in the left struct.
xfs1 uid lsuper rsuper = [(field,ini) | (field,_,ini) <- evalState_ lsuper ]

-- | Struct describing the eval state of the right ("else") branch: a
-- @ref_count@ field followed by the eval-state fields of @rsuper@.
xs2 uid lsuper rsuper =
  Struct ("RightEvalState" ++ show uid) $
    {- (Bool, "cont") : -}
    (Int, "ref_count") : [(ty, field) | (field,ty,_) <- evalState_ rsuper]

-- | Field initialisers matching 'xs2'.
xfs2 uid lsuper rsuper = [(field,ini) | (field,_,ini) <- evalState_ rsuper ]

-- | Struct describing the tree state of the left branch: a pointer to the
-- left eval state ('xs1') followed by the tree-state fields of @lsuper@.
xs3 uid lsuper rsuper =
  Struct ("LeftTreeState" ++ show uid) $
    (Pointer $ SType $ xs1 uid lsuper rsuper, "evalState") : [(ty, field) | (field,ty,_) <- treeState_ lsuper]

-- | Field initialisers matching 'xs3'.
xfs3 uid lsuper rsuper = [(field,ini) | (field,_,ini) <- treeState_ lsuper]

-- | Struct describing the tree state of the right branch: a pointer to the
-- right eval state ('xs2') followed by the tree-state fields of @rsuper@.
xs4 uid lsuper rsuper =
  Struct ("RightTreeState" ++ show uid) $
    (Pointer $ SType $ xs2 uid lsuper rsuper, "evalState") : [(ty, field) | (field,ty,_) <- treeState_ rsuper]

-- | Field initialisers matching 'xs4'.
xfs4 uid lsuper rsuper = [(field,ini) | (field,_,ini) <- treeState_ rsuper]

-- | Selects the @if_then@ member of the @if_union@ field of a state.
in1 = \state -> state @-> "if_union" @-> "if_then"

-- | Selects the @if_else@ member of the @if_union@ field of a state.
in2 = \state -> state @-> "if_union" @-> "if_else"

-- | Re-roots @i@ at the sub-state of the requested branch, using that
-- branch's eval-state and tree-state struct types.
xpath uid lsuper rsuper i FirstS  = withPath i in1 (SType $ xs1 uid lsuper rsuper) (SType $ xs3 uid lsuper rsuper)
xpath uid lsuper rsuper i SecondS = withPath i in2 (SType $ xs2 uid lsuper rsuper) (SType $ xs4 uid lsuper rsuper)

-- | Evaluator of the conditional combinator.
--
-- Arguments: the condition, a unique id used to name the generated structs,
-- the evaluator of the "then" branch and the evaluator of the "else" branch.
--
-- The tree state consists of the Bool field @if_true@ and the union
-- @if_union@ with members @if_then@ / @if_else@ holding the tree state of
-- the respective branch.
ifLoop :: (Evalable m, ReaderM SeqPos m) => Stat -> Int -> Eval m -> Eval m -> Eval m
ifLoop cond uid lsuper rsuper = commentEval $
  Eval
    { structs = structs lsuper @++@ structs rsuper @++@ mystructs
    , toString = "if" ++ show uid ++ "(" ++ show cond ++ "," ++ toString lsuper ++ "," ++ toString rsuper ++ ")"
      -- The union member names must be the ones 'in1' / 'in2' select
      -- ("if_then" / "if_else"); "if_true" is reserved for the Bool flag.
    , treeState_ =
        [ ("if_true", Bool, const $ return Skip)
        , ("if_union", Union [(SType s3,"if_then"),(SType s4,"if_else")], const $ return Skip)
        ]
      -- Store the condition's value in "if_true", then initialise the state
      -- of the branch that flag selects.
    , initH = \i -> (readStat cond >>= \r -> return (assign (r i) (tstate i @-> "if_true"))) @>>>@ initstate i
    , evalState_ = []
    , pushLeftH = push pushLeft
    , pushRightH = push pushRight
      -- Compares the flag of the current state with the flag of the state
      -- based at "popped_estate": equal flags use the branch's nextSame code,
      -- different flags its nextDiff code.
    , nextSameH = \i ->
        let j = i `withBase` "popped_estate"
        in do
          nS1 <- local (const FirstS) $ inSeq nextSame i
          nS2 <- local (const SecondS) $ inSeq nextSame i
          nD1 <- local (const FirstS) $ inSeq nextDiff i
          nD2 <- local (const SecondS) $ inSeq nextDiff i
          return $ IfThenElse (is_fst i) (IfThenElse (is_fst j) nS1 nD1) (IfThenElse (is_fst j) nD2 nS2)
    , nextDiffH = \i -> inSeq nextDiff i
      -- Runs the selected branch's body; on abort the branch's eval state
      -- reference is released via dec_ref.
    , bodyH = \i ->
        let f _y z p =
              let j = mpath i p
              {-
              in do cond <- continue z (estate j)
                    deref <- dec_ref i
                    stmt <- bodyE z (j `onAbort` deref)
                    return $ IfThenElse (cont j) (IfThenElse cond stmt ( (cont j <== false) >>> deref >>> abort j)) (deref >>> abort j)
              -}
              in dec_ref i >>= \deref -> bodyE z (j `onAbort` deref)
        in IfThenElse (is_fst i)
             @$ local (const FirstS) (f in1 lsuper FirstS)
             @. local (const SecondS) (f in2 rsuper SecondS)
    , addH = inSeq $ addE
    , failH = \i -> inSeq failE i @>>>@ dec_ref i
    , returnH = \i ->
        let j1 deref = mpath i FirstS `onCommit` deref
            j2 deref = mpath i SecondS `onCommit` deref
        in IfThenElse (is_fst i)
             @$ (dec_refx (j1 Skip) >>= returnE lsuper . j1)
             @. (dec_refx (j2 Skip) >>= returnE rsuper . j2)
    -- , continue = \_ -> return true
    , tryH = \i ->
        IfThenElse (is_fst i)
          @$ tryE lsuper (mpath i FirstS)
          @. tryE rsuper (mpath i SecondS)
    , startTryH = \i ->
        IfThenElse (is_fst i)
          @$ startTryE lsuper (mpath i FirstS)
          @. startTryE rsuper (mpath i SecondS)
    , tryLH = \i ->
        IfThenElse (is_fst i)
          @$ tryE_ lsuper (mpath i FirstS)
          @. tryE_ rsuper (mpath i SecondS)
    , boolArraysE = boolArraysE lsuper ++ boolArraysE rsuper
    , intArraysE = intArraysE lsuper ++ intArraysE rsuper
    , intVarsE = intVarsE lsuper ++ intVarsE rsuper
    , deleteH = deleteMe
    , canBranch = canBranch lsuper >>= \l -> canBranch rsuper >>= \r -> return (l || r)
      -- Branch on the same "if_true" flag every other hook uses; this
      -- combinator declares no "is_fst" field.
    , complete = \i -> do
        sid1 <- complete lsuper (mpath i FirstS)
        sid2 <- complete rsuper (mpath i SecondS)
        return $ Cond (is_fst i) sid1 sid2
    }
  where
    -- Structs contributed by this combinator: eval-state structs first,
    -- tree-state structs second.
    mystructs = ([s1,s2],[s3,s4])
    s1 = xs1 uid lsuper rsuper
    s2 = xs2 uid lsuper rsuper
    s3 = xs3 uid lsuper rsuper
    s4 = xs4 uid lsuper rsuper
    fs1 = xfs1 uid lsuper rsuper
    fs2 = xfs2 uid lsuper rsuper
    fs3 = xfs3 uid lsuper rsuper
    fs4 = xfs4 uid lsuper rsuper
    mpath = xpath uid lsuper rsuper
    withSeq f = seqSwitch (f lsuper in1) (f rsuper in2)
    withSeq_ f = seqSwitch (f lsuper in1 FirstS) (f rsuper in2 SecondS)
    -- Applies a per-branch hook to the branch chosen by seqSwitch, with the
    -- info re-rooted at that branch's sub-state.
    inSeq f = \i -> withSeq_ $ \super _ins pos -> f super (mpath i pos)
    dec_ref = \i -> seqSwitch (dec_refx $ mpath i FirstS) (dec_refx $ mpath i SecondS)
    -- Decrement the reference count and delete the eval state once it
    -- reaches zero.
    dec_refx = \j -> return $ dec (ref_count j) >>> ifthen (ref_count j @== 0) (comment "ifLoop-dec_refx" >>> Delete (estate j))
    push dir = \i -> seqSwitch (push1 dir i) (push2 dir i)
    -- On commit the flag and the eval-state pointer are copied and the
    -- reference count of the shared eval state is incremented.
    push1 dir = \i ->
      let j = mpath i FirstS
      in dir lsuper (j `onCommit` ( mkCopy i "if_true" >>> mkCopy j "evalState" >>> inc (ref_count j) ))
    push2 dir = \i ->
      let j = mpath i SecondS
      in dir rsuper (j `onCommit` ( mkCopy i "if_true" >>> mkCopy j "evalState" >>> inc (ref_count j) ))
    -- Allocates the eval state of a branch with reference count 1, runs its
    -- field initialisers and the branch's own initialisation; which branch
    -- is initialised is decided by the "if_true" flag.
    initstate = \i ->
      let f d =
            let j = mpath i (if d then FirstS else SecondS)
            in return ( (estate j <== New (if d then s1 else s2)) >>> (ref_count j <== 1) )
                 @>>>@ inite (if d then fs1 else fs2) j
                 @>>>@ inits (if d then lsuper else rsuper) j
      in do thenP <- f True
            elseP <- f False
            return $ IfThenElse (tstate i @-> "if_true") thenP elseP
    -- True when the "then" branch is the active one.
    is_fst = \i -> tstate i @-> "if_true"
    deleteMe = \i -> seqSwitch (deleteE lsuper (mpath i FirstS)) (deleteE rsuper (mpath i SecondS)) @>>>@ dec_ref i

-- | @if' cond s1 s2@ combines two searches into one that runs @s1@ when
-- @cond@ holds and @s2@ otherwise (the decision is taken by 'ifLoop').
--
-- Each use draws a fresh unique id from the state and increments it.
-- NOTE(review): the evaluator of @s2@ is built before that of @s1@, and the
-- 'L' / 'runL' towers below differ by one layer between the two; presumably
-- this mirrors the monad-transformer layers each search adds -- confirm
-- before reordering.
if' :: Stat -> Search -> Search -> Search
if' cond s1 s2 =
  case s1 of
    Search { mkeval = evals1, runsearch = runs1 } ->
      case s2 of
        Search { mkeval = evals2, runsearch = runs2 } ->
          Search
            { mkeval = \super -> do
                { s2' <- evals2 $ mapE (L . L . L . mmap (mmap runL . runL) . runL) super
                ; s1' <- evals1 $ mapE (L . L . mmap (mmap runL . runL) . runL) super
                ; uid <- get
                ; put (uid + 1)
                ; return $ mapE (L . mmap L . runL) $
                    ifLoop cond uid
                      (mapE (L . mmap (mmap L) . runL . runL) s1')
                      (mapE (L . mmap (mmap L) . runL . runL . runL) s2')
                }
            , runsearch = runs2 . runs1 . runL . rReaderT FirstS . runL
            }