module Data.BffMono.Base where 
import Data.Traversable hiding (mapM)
import Data.Foldable (Foldable)
import qualified Data.Foldable as Foldable 
import Data.BffMono.CheckHistory 
import Data.BffMono.EquivMap (EquivMap) 
import qualified Data.BffMono.EquivMap as EM 
import Data.BffMono.EquivWitness (EquivWitness)
import qualified Data.BffMono.EquivWitness as EW
import Control.Monad.Error 
import Control.Monad.State 
import Control.Monad.Identity 
-- | Abstract values of type @abs@ that can be built from concrete values of
-- type @conc@.  The functional dependency states that the abstract type
-- determines the concrete one.
class Pack conc abs | abs -> conc where 
    -- | Wrap a concrete value as an abstract value.
    new :: conc -> abs 
-- | Abstract values of type @abs@ whose underlying concrete values can be
-- observed inside a monad @m@.
class (Pack conc abs, Monad m, Functor m) => 
       PackM conc abs m where
    -- | Lift an observation on a list of concrete values to an observation
    -- on a list of abstract values, delivering its result in @m@.
    liftO :: Eq r => ([conc] -> r) -> ([abs] -> m r)
    
    -- | Equality test on two abstract values, performed in @m@.
    eqSync :: Eq conc => abs -> abs -> m Bool 
    
    -- | Ordering comparison of two abstract values, performed in @m@.
    compareSync :: Ord conc => abs -> abs -> m Ordering 
    
    
-- | Lift a unary observation: 'liftO' applied to a singleton argument list.
liftO1 :: (PackM conc abs m, Eq r) => (conc -> r) -> abs -> m r 
liftO1 f x = liftO unary [x]
    where
      -- Matches the one-element list handed to 'liftO' above.
      unary [a] = f a
-- | Lift a binary observation: 'liftO' applied to a two-element argument
-- list.
liftO2 :: (PackM conc abs m, Eq r) 
          => (conc -> conc -> r) -> abs -> abs -> m r 
liftO2 f x y = liftO binary [x, y]
    where
      -- Matches the two-element list handed to 'liftO' above.
      binary [a, b] = f a b
-- | Origin tag carried by a 'Loc'.
--
-- * @InSource i@: carries an 'Int' index; 'assignIDs' numbers source
--   elements this way, counting from 0 in traversal order.
-- * @InTrans@: no index; this is the tag 'new' uses for 'Loc' values.
data Location = InSource Int | InTrans
                deriving (Show, Eq, Ord)
-- | A value ('body') paired with the 'Location' it came from ('location').
-- The derived 'Eq' and 'Ord' instances take both fields into account.
data Loc a = Loc { body :: a, location :: Location }
           deriving (Show, Eq, Ord)
-- | An update: an 'EquivMap' from 'InSource' indices to replacement values.
type Update a = EquivMap Int a 
-- | Apply the update held in the state to one located value.
--
-- 'InTrans' values are returned untouched.  For @InSource i@ the index is
-- looked up with 'EM.lookupM'; a hit replaces the body (keeping the
-- location), a miss leaves the value as it was.
update :: MonadState (Update a) m => Loc a -> m (Loc a)
update loc = case location loc of
    InTrans    -> return loc
    InSource i -> EM.lookupM i >>= maybe (return loc) (return . withBody)
  where
    -- Same location, new body.
    withBody b = loc { body = b }
-- | Tag every element of a structure with @InSource i@, where @i@ counts
-- up from 0 in traversal order.
assignIDs :: Traversable f => f a -> f (Loc a) 
assignIDs = snd . mapAccumL label 0
    where
      -- Thread the running index left to right, tagging each element with
      -- the index seen on entry.
      label i x = (i + 1, Loc x (InSource i))
-- | Error raised by 'matchViews' when a changed view element disagrees with
-- the value the update map already holds for its source index.
errMsgInconsistent :: Error e => e 
errMsgInconsistent = strMsg "Inconsistent Update!"
-- | Error raised by 'matchViews' when a changed view element is tagged
-- 'InTrans', i.e. it has no source index to write the change back to.
errMsgConstant :: Error e => e 
errMsgConstant = strMsg "Update on Constant!"
 
-- | Derive an 'Update' from the located view @xview@ and the edited view
-- @view@.
--
-- The two views must have the same shape (they are compared after erasing
-- their elements); otherwise a @"Shape Mismatch!"@ error is thrown.  The
-- elements are then paired up positionally and every pair whose values
-- differ is folded into an update map that starts from
-- @EW.emptyMap equiv@:
--
-- * a changed element tagged 'InTrans' throws 'errMsgConstant';
-- * a changed element tagged @InSource i@ is inserted at @i@, unless the
--   map already yields a value for @i@, in which case the two values must
--   agree or 'errMsgInconsistent' is thrown.
--
-- NOTE(review): unchanged positions are dropped before the consistency
-- check, so a source index that occurs at several view positions of which
-- only some were edited is not reported here -- confirm this is intended.
matchViews :: (Eq a,Functor f,Foldable f, Eq (f ()), MonadError e m, Error e)
              => f (Loc a) -> f a -> EquivWitness Int -> m (Update a) 
matchViews xview view equiv
    | isShapeEqual xview view =
        makeUpd (EW.emptyMap equiv) $ filter hasUpdated
                    $ zip (Foldable.toList xview) (Foldable.toList view)
    | otherwise =
        throwError $ strMsg "Shape Mismatch!"
    where
      -- A position counts as updated when its new value differs from the
      -- body computed from the original source.
      hasUpdated (Loc x _, y) = x /= y 
      makeUpd upd [] = return upd 
      makeUpd _ ((Loc _ InTrans, _):_) = throwError errMsgConstant 
      makeUpd upd ((Loc _ (InSource i), y):ps) =
          case EM.lookup i upd of 
            (Just z, upd')
                | z == y    -> makeUpd upd' ps 
                | otherwise -> throwError errMsgInconsistent 
            -- Continue from the map returned by the lookup, exactly as the
            -- branch above does, rather than from the map that was queried.
            (Nothing, upd') -> 
                makeUpd (EM.insert i y upd') ps 
      -- Shapes are compared by replacing every element with ().
      isShapeEqual :: (Functor f, Eq (f ())) => f a -> f b -> Bool 
      isShapeEqual x y =  void x == void y 
          
          
-- | A plain wrapper around a value; 'runNew' unwraps it again.
newtype New a = New { runNew :: a }

-- | Mapping applies the function to the wrapped value.
instance Functor New where
    fmap f = New . f . runNew
-- | Any value can be packed by wrapping it in 'New'.
instance Pack a (New a) where 
    new = New 
-- | Direct interpretation in 'Identity': every operation unwraps its
-- arguments with 'runNew' and applies the plain function; nothing else is
-- tracked.
instance PackM a (New a) Identity where 
    liftO obs = Identity . obs . map runNew
    eqSync x y = Identity (runNew x == runNew y)
    compareSync x y = Identity (compare (runNew x) (runNew y))
                 
-- | Packing a value into a 'Loc' tags it as 'InTrans'.
instance Pack a (Loc a) where 
    new a = Loc { body = a, location = InTrans }
-- | State monad used when running a transformation on located values.  The
-- state pairs a list of 'CheckResult's with an 'EquivWitness' over source
-- indices.
type B a = State ([CheckResult (Loc a)], EquivWitness Int)
-- | Run a 'B' computation from an empty history and 'EW.empty', returning
-- the result together with the final history and equivalence witness.
unB :: B a b -> (b, [CheckResult (Loc a)], EquivWitness Int)
unB m = (result, history, witness)
    where
      -- Lazy pattern binding: nothing is forced until a component is used.
      (result, (history, witness)) = runState m ([], EW.empty)
-- | Recording interpretation: observations are evaluated on the bodies of
-- located values while the 'B' state collects what was observed.
instance PackM a (Loc a) (B a) where 
    -- Evaluate the observation on the bodies and push a 'CheckResult'
    -- (lifted observation, its arguments, its outcome) onto the history.
    -- The outcome is computed once and shared.
    liftO obs xs = do { modify (\(h,t) -> (CheckResult obs' xs r:h,t))
                      ; return r }
        where
          obs' = obs . map body
          r    = obs' xs
    -- When both bodies are equal and both values carry a source index, the
    -- two indices are equated in the witness and no 'CheckResult' is
    -- logged; every other case is an ordinary logged (==) observation.
    eqSync x y 
        | body x == body y, InSource i <- location x, InSource j <- location y = 
              do { modify (\(h,t) -> (h, EW.equate i j t))
                 ; return True }
        | otherwise = liftO2 (==) x y 
    -- Mirrors 'eqSync'.  The guard compares the bodies only: comparing the
    -- whole 'Loc' values would also compare their locations, so EQ could
    -- only arise for i == j and no useful equivalence would be recorded.
    compareSync x y 
        | EQ <- compare (body x) (body y)
        , InSource i <- location x, InSource j <- location y =
              do { modify (\(h,t) -> (h, EW.equate i j t))
                 ; return EQ }
        | otherwise = liftO2 compare x y 
          
                          
-- | Backward transformation for a polymorphic forward function @pget@.
--
-- The source is labelled with 'assignIDs' and @pget@ is run on it in the
-- 'B' monad, giving a located view, a history and an equivalence witness.
-- 'matchViews' turns the edited view into an update; 'checkHistory' is then
-- run with 'update' over the recorded history.  If it yields 'True', the
-- resulting update is applied to every source element; otherwise a
-- @"Violated Invariants"@ error is thrown.  Errors from 'matchViews' are
-- propagated as they are.
bwd :: (Eq (vf ()), Traversable vf, Traversable sf, Eq c,
        MonadError e n, Error e) =>
       (forall a m. (PackM c a m) => sf a -> m (vf a)) ->
           sf c -> vf c -> n (sf c)
bwd pget src view = do
    upd <- matchViews locatedView view witness
    case runState (checkHistory update history) upd of
      (True, upd') ->
          -- Each element is updated independently, starting from upd'.
          return $ fmap (\l -> body (evalState (update l) upd')) labelled
      (False, _) ->
          throwError $ strMsg "Violated Invariants"
  where
    labelled = assignIDs src
    (locatedView, history, witness) = runRecorded (pget labelled)

    -- 'unB' with its result type pinned to a located structure.
    runRecorded = unB :: B c (sf (Loc c))
                         -> (sf (Loc c),
                             [CheckResult (Loc c)],
                             EquivWitness Int)
-- | Forward transformation: wrap every source element in 'New', run @pget@
-- in the 'Identity' monad, and unwrap the elements of the resulting view.
fwd :: (Traversable vf, Traversable sf) =>
       (forall a m. (PackM c a m) => sf a -> m (vf a)) ->
           sf c -> vf c
fwd pget src = fmap runNew (runIdentity (pget (fmap New src)))