{-# LANGUAGE FlexibleContexts #-}

module Control.Search.Combinator.And (andN,(<&>)) where

import Data.Maybe (fromMaybe, catMaybes, fromJust)

import Control.Search.Language
import Control.Search.GeneratorInfo
import Control.Search.Memo
import Control.Search.MemoReader
import Control.Search.Generator

import Control.Search.Combinator.Success

import Control.Monatron.Monatron hiding (Abort, L, state, cont)
import Control.Monatron.Zipper hiding (i,r)
import Control.Monatron.IdT

-- | Code-generating evaluator that runs the evaluators in @lst@ one after
-- another (N-ary sequential composition).
--
-- Every node's tree state records the index of the branch it belongs to
-- (@seqn_pos@, initially 0) and holds that branch's own tree state inside the
-- union @seqn_union@ (member @seqK@ for branch @K@).
--
-- When branch @n@ returns and is not the last branch, the generated code
-- does the following (see @returnH@ below):
--
-- * decrements branch @n@'s reference count;
-- * sets @seqn_pos@ to @n+1@;
-- * initialises branch @n+1@ and runs its @inits@ and @startTryE@.
--
-- The last branch's return goes to its own @returnE@.
--
-- Each branch also gets an eval-state struct (named from the branch index and
-- @uid@, e.g. Seq0EvalState3) holding a @ref_count@ plus the branch's
-- @evalState_@ fields. It is created with @New@ when the branch is
-- initialised. @dec_ref@ does three things:
--
-- * decrements the count;
-- * @Delete@s the struct when the count reaches 0;
-- * ANDs the branch's completeness into this evaluator's @complete@ flag.
--
-- @uid@ is appended to the generated struct names and to @toString@;
-- presumably this keeps several instances from clashing (confirm).
-- @lst@ must be non-empty: it is consumed with @foldr1@ and indexed with @!!@.
seqNLoop :: (ReaderM Int m, Evalable m) => Int -> [Eval m] -> Eval m
seqNLoop uid lst = commentEval $
  -- structs: all sub-evaluators' structs, followed by this evaluator's
  -- per-branch eval-state and tree-state structs (mystructs).
  Eval { structs     = (foldr1 (@++@) $ map (structs) lst) @++@ mystructs 
       -- Rendered as seqN<uid>(<sub0>,<sub1>,...).
       , toString = "seqN" ++ show uid ++ "(" ++ (foldr1 (\x y -> x ++ "," ++ y) $ map (toString) lst) ++ ")"
       -- Tree state: the active-branch index plus the per-branch union. The
       -- union's initialiser sets up branch 0's nested eval state.
       , treeState_  = [entry ("seqn_pos",Int,assign 0)                      -- index of the active branch (0 .. nbranches-1)
                       , ("seqn_union",Union [(SType (s3 i),"seq" ++ show i) | i <- [0..nbranches-1]], -- union of the per-branch tree states
				\i -> 						 -- init nested state of first search
                                   let j = xpath i 0
                                   in initSubEvalState j (s1 0) (fs1 0)
                         )]
       -- initH: run branch 0's inits on its path, with the reader index set to 0.
       , initH       = \i -> (local (const 0) $ inits (xsuper 0) (xpath i 0))
       -- evalState_: this evaluator's own eval state. "complete" starts as
       -- true and is AND-ed with each branch's completeness by dec_ref.
       , evalState_  = [("complete",Bool,const $ return true)] -- some global data
       -- pushLeftH / pushRightH: delegate to the active branch. The branch's
       -- commit action additionally copies seqn_pos and the evalState pointer
       -- and increments ref_count (see push below).
       , pushLeftH    = push pushLeft
       , pushRightH   = push pushRight
       -- nextSameH: compare seqn_pos of i with seqn_pos seen through base
       -- "popped_estate" (presumably the popped node -- confirm). If equal,
       -- use the active branch's nextSame, otherwise its nextDiff.
       , nextSameH    = \i -> let j = i `withBase` "popped_estate"
                             in do nd <- inSeq nextDiff i
                                   ns <- inSeq nextSame i
                                   return $ IfThenElse ((seq_pos i) @== (seq_pos j)) ns nd
       , nextDiffH    = inSeq $ nextDiff
       -- bodyH: generate every branch's body, each with the reader index set
       -- to its position. A body decrements ref_count via dec_ref if it
       -- aborts. The bodies are then chained into a dispatch on seqn_pos:
       --   cb !! n = 1 if branch n can branch, else 0
       --   cu n    = how many of branches n..nbranches-1 can branch
       -- While two or more of the remaining branches can branch, an
       -- IfThenElse on seqn_pos is emitted. Once at most one can
       -- (cu n <= 1), no test is emitted: the first remaining branch that
       -- can branch is used unconditionally, or Skip if none can.
       -- NOTE(review): in that case the bodies of the remaining branches
       -- that cannot branch are never emitted. Presumably such branches
       -- never reach bodyH -- confirm.
       , bodyH       = \i -> 
                                let seqBody super j pos = 
                                      do
                                        dr <- dec_ref "bodyE-stmt" j i pos
                                        bodyE super (j `onAbort` (comment "seqLoopN.bodyE" >>> dr))
                                    in do cb <- mapM (\x -> canBranch x >>= \b -> return (if b then 1 else 0)) {- (const $ return 1) -} lst
                                          let cu n | n==nbranches = 0
                                              cu n                = (cb!!n) + cu (n+1)
                                          ss <- mapM (\pos -> local (const $ fromIntegral pos) $ inSeq_ seqBody i) [0..nbranches-1]
                                          let cc n | n==nbranches = Skip
                                              cc n | cu n <= 1   = if ((cb !! n) == 1) then (ss !! n) else cc (n+1)
                                              cc n | otherwise      = IfThenElse (seq_pos i @== fromIntegral n) (ss !! n) (cc (n+1))
                                          return $ cc 0
       -- addH: delegate to the active branch.
       -- failH: the active branch's failE followed by dec_ref.
       , addH        = inSeq $ addE
       , failH       = \i -> inSeq_ (\super j pos -> failE super j @>>>@ (dec_ref "failE" j i pos)) i
       -- returnH, for branch n < nbranches-1: branch n's returnE is called
       -- with its commit action replaced (withCommit) by a switch to branch
       -- n+1. The switch does, in order:
       --   * dec_ref on branch n;
       --   * seqn_pos := n+1;
       --   * allocate/initialise branch n+1's eval state;
       --   * run branch n+1's inits and startTryE, both generated with the
       --     reader index set to n+1.
       -- For the last branch, returnE is called with dec_ref added to its
       -- commit action (onCommit).
       -- NOTE(review): the "returnE-j2A" dec_ref is attached with onCommit
       -- and resetCommit is then applied to the result. Whether it survives
       -- depends on resetCommit -- confirm.
       , returnH     = \i -> numSwitch (\n -> if (n<nbranches-1)
                                                    then do let j1 = xpath i n
                                                                j2o = xpath i (n+1)
                                                            dr <- dec_ref "returnE-j2A" j2o i (n+1)
                                                            let j2 = j2o `onCommit` dr
                                                                j2b = resetCommit j2
				 	                    action <- local (const $ n+1) $ do stmt1 <- inits (xsuper (n+1)) j2b
                                                                                               stmt2 <- startTryE (xsuper (n+1)) j2b
                                                                                               init <- initSubEvalState j2b (s1 $ n+1) (fs1 $ n+1)
                                                                                               dr2 <- dec_ref "returnE-j1" j1 i n
					                                                       return (    comment ("Switching from branch" ++ show n ++ " to branch" ++ show (n+1))
                                                                                                           >>> dr2
                                                                                                           >>> (seq_pos i <== fromIntegral (n+1))
                                                                                                           >>> init >>> stmt1 >>> stmt2)
                                                            returnE (xsuper n) $ j1 `withCommit` const action
                                                    else do let j2o  = xpath i n
                                                            dr3 <- dec_ref "returnE-j2B" j2o i n
                                                            let j2 = j2o `onCommit` dr3
                                                            returnE (xsuper n) j2
                                          )
--       , continue    = \_ -> return true
       -- tryH / startTryH: delegate to the active branch, decrementing
       -- ref_count (dec_ref) if it aborts. startTryH is generated with the
       -- reader index set to 0.
       -- tryLH: the active branch's tryE_ followed by dec_ref.
       , tryH        = \i -> inSeq_ (\super j pos -> do { dr <- dec_ref "tryE" j i pos; return (comment "seqLoop.tryE(a)") @>>>@ tryE  super (j `onAbort` (comment "seqLoop.tryE(b)" >>> dr))}) i
       , startTryH   = \i -> local (const 0) $ inSeq_ (\super j pos -> do { dr <- dec_ref "startTryE" j i pos; return (comment "seqLoop.startTryE(a)") @>>>@ startTryE super (j `onAbort` (comment "seqLoop.startTryE(b)" >>> dr))}) i
       , tryLH       = \i -> inSeq_ (\super j pos -> tryE_ super j @>>>@ (dec_ref "tryE_" j i pos)) i
       -- Arrays and variables: concatenation over all branches.
       -- foldr1 requires lst to be non-empty.
       , intArraysE  = foldr1 (++) $ map (intArraysE) lst
       , boolArraysE  = foldr1 (++) $ map (boolArraysE) lst
       , intVarsE    = foldr1 (++) $ map (intVarsE) lst
       , deleteH     = deleteMe
       -- canBranch: true if any branch can branch.
       , canBranch   = do res <- mapM (canBranch) lst
                          return $ or res
       -- complete: this evaluator's accumulated "complete" flag.
       , complete = \i -> return $ estate i @=> "complete"
--       , complete = const $ return false
       }
  -- Helpers. numSwitch (defined elsewhere) presumably selects branch n from
  -- the Int provided through the ReaderM Int context, which is set with
  -- local above -- confirm.
  where nbranches = length lst
        xsuper i = lst !! i
        -- mystructs: (the eval-state structs that exist, the tree-state
        -- structs of all branches).
        mystructs = (catMaybes (map s1 [0..nbranches-1]),map s3 [0..nbranches-1])
        -- evalStruct side super: the eval-state struct of one branch, i.e.
        -- ref_count followed by the branch's evalState_ fields. It is always
        -- Just; the commented-out test would give Nothing for an empty
        -- evalState_.
	evalStruct side super = Just $ -- if (length (evalState_ super) == 0) then Nothing else Just $
			Struct (side ++ "EvalState"  ++ show uid) $ 
--				(Bool, "cont") :				-- continue or not with this search 
				(Int, "ref_count") : 				-- how many active nodes of this search
				[(ty, field) | (field,ty,_) <- evalState_ super] -- fields of this search
        -- needSide is currently the identity. The commented-out version
        -- skipped stm for branches without eval-state fields.
--        needSide = \pos stm -> if (length (evalState_ (xsuper pos)) == 0) then Skip else stm
        needSide pos stm = stm
        -- s1 i / et i: branch i's eval-state struct and its pointer type
        -- (THook "void" when there is none).
        -- s3 i / st i: branch i's tree-state struct (an evalState pointer,
        -- if s1 i exists, plus the branch's treeState_ fields) and its
        -- pointer type.
        s1 i      = evalStruct ("Seq" ++ show i) (xsuper i)
        et i      = maybe (THook "void") (Pointer . SType) $ s1 i
        s3 i      = Struct ("Seq" ++ show i ++ "TreeState" ++ show uid) $ (case s1 i of { Nothing -> id; Just s -> ((Pointer $ SType s, "evalState"):) }) [(ty, field) | (field,ty,_) <- treeState_ (xsuper i)]
        st i      = Pointer . SType $ s3 i
        -- xpath i n: view of node i through branch n's union member
        -- (seqn_union.seqN). Its clone hook increments ref_count.
        xpath i n = flip withClone (\i -> inc (ref_count i)) $ withPath i (inN n) (et n) (st n)
        -- fs1 / fs3: field initialisers of branch n's eval state and tree
        -- state; the lambda argument is ignored. fs3 and withSeq are not
        -- referenced within this definition.
        fs1 n     = \i -> [(field,init) | (field,_ty,init) <- evalState_ (xsuper n) ]
        fs3 n     = \i -> [(field,init) | (field,_ty,init) <- treeState_ (xsuper n) ]
        withSeq f = numSwitch (\n -> f (xsuper n) (inN n))
        -- inSeq / inSeq_: apply f to the sub-evaluator and path of the
        -- branch chosen by numSwitch; inSeq_ also passes the branch index.
        inSeq f   = \i -> numSwitch (\n -> f (xsuper n) (xpath i n))
        inSeq_ f  = \i -> numSwitch (\n -> f (xsuper n) (xpath i n) n)
        -- push dir: delegate pushLeft/pushRight to the active branch. Its
        -- commit action is extended to copy seqn_pos and the evalState
        -- pointer and to increment the branch's ref_count.
        push dir  = \i -> inSeq_ ( \super j pos -> dir super (j `onCommit` (mkCopy i "seqn_pos"
                                                                            >>> needSide pos (mkCopy j "evalState")
                                                                            >>> needSide pos (inc (ref_count j))
                                                                           )
                                                             )
                                 ) i
        -- initSubEvalState j s fs: if the branch has an eval-state struct s,
        -- allocate it (New) into estate j and set its ref_count to 1. Then
        -- run the branch's field initialisers fs on j.
        initSubEvalState = \j s fs -> (case s of { Nothing -> return Skip; Just ss -> return (    (estate j <== New ss)
				              >>> (ref_count j <== 1)
--			                      >>> (cont j <== true)
                                             )})
                                        @>>>@ inite (fs j) j
        -- deleteMe: the active branch's deleteE followed by dec_ref.
	deleteMe = \i -> inSeq_ (\super j pos -> do delrest <- deleteE super j
                                                    dr <- dec_ref "deleteMe" j i pos
                                                    return (delrest >>> dr)) i
--        dec_ref :: String -> Info -> Info -> Int -> Statement
        -- dec_ref s j i pos: release one reference to branch pos's eval
        -- state (seen through j) on behalf of this evaluator (seen through i).
        -- The label s is not used.
        dec_ref s j i pos = complete (xsuper pos) j >>= \compl -> decrefx j pos (estate_type i,estate i) (estate_type j,estate j) (ref_count_type, ref_count j) (THook "bool", compl)
        -- decrefx generates: complete := complete && branch-complete; then
        -- decrement ref_count and Delete the branch eval state when it hits 0.
        -- It is wrapped in memo under the key "dec_ref_and"; presumably this
        -- shares the generated code between call sites -- confirm.
        decrefx j pos = memo "dec_ref_and" j (\(_,esti) (_,estj) (_,rcj) (_,xcl) -> return $ ((assign ((esti @=> "complete") &&& (xcl))) (esti @=> "complete") >>> 
                            needSide pos (dec (rcj) >>> ifthen (rcj @== 0) (Delete (estj)))) {- >>> DebugValue ("completeness and" ++ show uid) (esti @=> "complete") -})
        -- inN n: field path to branch n's union member; seq_pos i: the
        -- active-branch index stored in node i's tree state.
	inN n     = \state -> state @-> "seqn_union" @-> ("seq" ++ show n)
	seq_pos   = \i -> tstate i @-> "seqn_pos"


-- | Combine a list of searches so that they run one after another:
-- search @k+1@ is started from the return of search @k@ (see 'seqNLoop').
--
-- * An empty list yields 'dummy'.
-- * A single search is returned unchanged.
-- * Otherwise the searches are merged with 'buildCombiner'.
--
-- In the merged case, each run of the resulting @mkeval@ takes a fresh id
-- from the state counter and increments it; 'seqNLoop' uses that id in the
-- names of the structs it generates.
andN [] = dummy
andN [s] = s
andN s =
  case buildCombiner s of
    SearchCombiner { runner = combRunner, elems = combElems } ->
      Search
        { mkeval = \super -> do
            subEvals <- extractCombiners combElems (mapE (L . mmap runL . runL) super)
            freshId <- get
            put (freshId + 1)
            return . mapE (L . mmap L . runL) . memoLoop $ seqNLoop freshId subEvals
        , runsearch = combRunner . rReaderT 0 . runL
        }

-- | Binary form of 'andN': run the left search, then the right one.
(<&>) lhs rhs = andN [lhs, rhs]