{-# LANGUAGE FlexibleContexts #-}

module Control.Search.Combinator.Or ((<|>)) where

import Control.Search.Language
import Control.Search.GeneratorInfo
import Control.Search.Generator
import Control.Search.MemoReader
import Control.Search.Memo

import Control.Monatron.Monatron hiding (Abort, L, state, cont)
import Control.Monatron.Zipper hiding (i,r)

xs1 uid lsuper rsuper       = Struct ("LeftEvalState"  ++ show uid)  $ (THook "TreeState", "parent") : {- (Bool, "cont") : -} (Int, "ref_count") : [(ty, field) | (field,ty,_) <- evalState_ lsuper]
xfs1 uid lsuper rsuper       = [(field,init) | (field,ty,init) <- evalState_ lsuper ]
xs2 uid lsuper rsuper        = Struct ("RightEvalState" ++ show uid) $ xneedSide uid lsuper rsuper SecondS $ {- (Bool, "cont") : -} (Int, "ref_count") : [(ty, field) | (field,ty,_) <- evalState_ rsuper]
xfs2 uid lsuper rsuper       = [(field,init) | (field,ty,init) <- evalState_ rsuper ]
xet uid lsuper rsuper FirstS = SType $ xs1 uid lsuper rsuper
xet uid lsuper rsuper SecondS = SType $ xs2 uid lsuper rsuper
xs3 uid lsuper rsuper        = Struct ("LeftTreeState"  ++ show uid) $ (Pointer $ SType $ xs1 uid lsuper rsuper, "evalState") : [(ty, field) | (field,ty,_) <- treeState_ lsuper]
xfs3 uid lsuper rsuper       = [(field,init) | (field,ty,init) <- treeState_ lsuper]
xs4 uid lsuper rsuper        = Struct ("RightTreeState" ++ show uid) $ xneedSide uid lsuper rsuper SecondS [(Pointer $ SType $ xs2 uid lsuper rsuper, "evalState")] ++ [(ty, field) | (field,ty,_) <- treeState_ rsuper]
xst uid lsuper rsuper FirstS = SType $ xs3 uid lsuper rsuper
xst uid lsuper rsuper SecondS = SType $ xs4 uid lsuper rsuper
xneedSide :: Monoid m => Int -> Eval n -> Eval n -> SeqPos -> m -> m
xneedSide uid lsuper rsuper = \pos stm -> case pos of { FirstS -> stm;
                                                       SecondS -> if (length (evalState_ rsuper) == 0) then mempty else stm;
                                                     }

orLoop :: (ReaderM SeqPos m, Evalable m) => Int -> (Eval m) -> (Eval m) -> Eval m
orLoop uid (lsuper) (rsuper) = commentEval $
  Eval { structs     = structs lsuper @++@ structs rsuper @++@ mystructs 
       , toString    = "or" ++ show uid ++ "(" ++ toString lsuper ++ "," ++ toString rsuper ++ ")"
       , treeState_   = [entry ("is_fst",Bool,assign true)
                       , ("or_union",Union [(SType s3,"fst"),(SType s4,"snd")], 
				\i -> 
                                   let j = withPath i in1 (et FirstS) (st FirstS)
                                   in        do cc <- cachedClone i (cloneBase j)
                                                return (    (estate j <== New s1)
				                        >>> (ref_count j <== 1)
--				                        >>> (cont j <== true)
                                                        >>> (parent j <== baseTstate j)
                                                        >>> cc
                                                       )
                                       @>>>@ mseqs [init (j `withClone` (\k -> inc $ ref_count k)) | (f,init) <- fs3]
                                       @>>>@ inite fs1 j
                         )]
       , initH       = \i -> initE lsuper (withPath i in1 (et FirstS) (st FirstS))
       , evalState_  = []
       , pushLeftH    = push pushLeft
       , pushRightH   = push pushRight
       , nextSameH    = \i -> let j = i `withBase` "popped_estate"
                             in do nS1 <- local (const FirstS)  $ inSeq nextSame i
                                   nS2 <- local (const SecondS) $ inSeq nextSame i
                                   nD1 <- local (const FirstS)  $ inSeq nextDiff i
                                   nD2 <- local (const SecondS) $ inSeq nextDiff i
                                   return $ IfThenElse (is_fst i) 
                                                       (IfThenElse (is_fst j) nS1 nD1)
                                                       (IfThenElse (is_fst j) nD2 nS2) 
       , nextDiffH    = \i -> inSeq nextDiff i
       , bodyH       = \i ->
                         let f y z p = 
                               let j = withPath i y (et p) (st p)
                                 in dec_ref i >>= \deref -> bodyE z (j `onAbort` deref)
			 in IfThenElse (is_fst i) @$ local (const FirstS)  (f in1 lsuper FirstS)
                                                  @. local (const SecondS) (f in2 rsuper SecondS)
       , addH        = inSeq $ addE
       , failH       = \i -> inSeq failE i @>>>@ dec_ref i
       , returnH     = \i -> 
			     let j1 deref = (withPath i in1 (et FirstS) (st FirstS)) `onCommit` (comment "returnE-deref-j1" >>> deref >>> comment "end returnE-deref-j1")
                                 j2 deref = (withPath i in2 (et SecondS) (st SecondS)) `onCommit` (comment "returnE-deref-j2" >>> deref >>> comment "end returnE-deref-j2")
                             in seqSwitch (dec_ref1 i >>= returnE lsuper . j1)
                                          (dec_ref2 (j2 Skip) >>= returnE rsuper . j2) 
       , tryH        = \i -> 
			  do  dr <- dec_ref i
                              inSeq (\super j -> tryE super (j `onAbort` (comment "Combinator/Or tryH onAbort" >>> dr ))) i
       , startTryH   = \i -> local (const FirstS) $ inSeq startTryE i
       , tryLH       = \i -> inSeq tryE_ i @>>>@ dec_ref i
       , boolArraysE  = boolArraysE lsuper ++ boolArraysE rsuper
       , intArraysE  = intArraysE lsuper ++ intArraysE rsuper
       , intVarsE    = intVarsE lsuper ++ intVarsE rsuper
       , deleteH     = deleteMe
       , canBranch   = return True
       , complete    = \i -> do sid1 <- complete lsuper (withPath i in1 (et FirstS) (st FirstS))
                                sid2 <- complete rsuper (withPath i in2 (et SecondS) (st SecondS))
                                return $ (Cond (tstate i @-> "is_fst") sid1 sid2)

--       , complete = const $ return false
       }
  where mystructs = ([s1,s2],[s3,s4])
        s1 = xs1 uid lsuper rsuper
        s2 = xs2 uid lsuper rsuper
        s3 = xs3 uid lsuper rsuper
        s4 = xs4 uid lsuper rsuper
        fs1 = xfs1 uid lsuper rsuper
        fs2 = xfs2 uid lsuper rsuper
        fs3 = xfs3 uid lsuper rsuper
        et = xet uid lsuper rsuper
        st = xst uid lsuper rsuper
        needSide = xneedSide uid lsuper rsuper
        parent    = \i -> estate i @=> "parent"
        withSeq f = seqSwitch (f lsuper in1 FirstS) (f rsuper in2 SecondS)
        withSeq_ f = seqSwitch (f lsuper in1 FirstS) (f rsuper in2 SecondS)
        inSeq f   = \i     -> withSeq_ $ \super ins pos -> f super (withPath i ins (et pos) (st pos))
        dec_ref    = \i -> seqSwitch (dec_ref1 i) (dec_ref2 $ withPath i in2 (et SecondS) (st SecondS))
        dec_ref1   = \i ->      let j1     = withPath i in1 (et FirstS) (st FirstS)
                                    i'     = resetClone $ resetAbort $ resetCommit $ i `withBase` ("or_tstate" ++ show uid)
                                    j2     = withPath i' in2 (et SecondS) (st SecondS)
                                in (local (const SecondS) $
                                    do stmt1 <- inits rsuper j2
                                       stmt2 <- startTryE rsuper j2
                                       ini <- inite fs2 j2
                                       compl <- complete lsuper j1
				       return (    dec (ref_count j1) 
                                               >>> (ifthen (ref_count j1 @== 0) $
                                                      (
                                                      {- DebugValue ("or" ++ show uid ++ ": left finished with complete") (compl)
                                                      >>> -} (ifthen (Not compl) $
				                            (   SHook ("TreeState or_tstate" ++ show uid ++ ";")
							    >>> (baseTstate j2 <== parent j1)
                                                            >>> (is_fst i' <== false)
                                                            >>> comment "orLoop-dec_ref1-Delete" >>> Delete (estate j1)
                                                            >>> needSide SecondS (estate j2 <== New s2)  
				                            >>> needSide SecondS (ref_count j2 <== 1)
--				                            >>> (cont j2 <== true)
  				                            >>> ini
                                                            >>> stmt1 >>> stmt2
                                                            )
                                                          )
                                                      )
                                                   )
                                              )
                                   )
        dec_ref2  = \j -> {- return (DebugValue ("or" ++ show uid ++ ": right dec_ref from") (ref_count j)) @>>>@ -} (complete rsuper (withPath (resetClone $ resetAbort $ resetCommit $ j `withBase` ("or_tstate" ++ show uid)) in2 (et SecondS) (st SecondS)) >>= \compl -> (return $ needSide SecondS $ dec (ref_count j) >>> ifthen (ref_count j @== 0) ({- DebugValue ("or" ++ show uid ++ ": right finished with complete") compl >>> -} comment "orLoop-dec_ref2-Delete" >>> Delete (estate j))))
        push dir  = \i -> seqSwitch (push1 dir i) (push2 dir i)
        push1 dir = \i -> 
                           let j = withPath i in1 (et FirstS) (st FirstS)
                           in  dir lsuper (j `onCommit` (   mkCopy i "is_fst"
                                                        >>> mkCopy j "evalState"
                                                        >>> inc (ref_count j)
                                                        ))
        push2 dir = \i -> 
                           let j = withPath i in2 (et SecondS) (st SecondS)
                           in  dir rsuper (j `onCommit` (   mkCopy i "is_fst"
                                                        >>> needSide SecondS (mkCopy j "evalState")
                                                        >>> needSide SecondS (inc (ref_count j))
                                                       ))
	in1       = \state -> state @-> "or_union" @-> "fst"
	in2       = \state -> state @-> "or_union" @-> "snd"
	is_fst    = \i -> tstate i @-> "is_fst"
	deleteMe  = \i -> seqSwitch (deleteE lsuper (withPath i in1 (et FirstS) (st FirstS))) (deleteE rsuper (withPath i in2 (et SecondS) (st SecondS))) @>>>@ dec_ref i

(<|>)
  :: Search
  -> Search
  -> Search
s1 <|> s2 = 
  case s1 of
    Search { mkeval = evals1, runsearch = runs1 } ->
      case s2 of
        Search { mkeval = evals2, runsearch = runs2 } ->
	  Search {mkeval =
	          \super -> do { s2' <- evals2 $ mapE (L . L . L . mmap (mmap runL . runL) . runL)  super
	                       ; s1' <- evals1 $ mapE (L . L . mmap (mmap runL . runL) . runL) super
			       ; uid <- get
			       ; put (uid + 1)
	                       ; return $ mapE (L . mmap L . runL) $ 
			           	orLoop uid (mapE (L . mmap (mmap L) . runL . runL) s1')
	                                               (mapE (L . mmap (mmap L) . runL . runL . runL) s2')
	                       }
	         , runsearch  = runs2 . runs1 . runL . rReaderT FirstS . runL
	         }
 where 	in1       = \state -> state @-> "or_union" @-> "fst"
	in2       = \state -> state @-> "or_union" @-> "snd"