-- | Equation stores: a free substitution together with a conjunction of
-- disjunctions of fresh substitutions. Every disjunction is labelled with a
-- 'SplitId' so that case splits can be performed on it.
module Theory.Tools.EquationStore (
  -- * Equation store
    SplitId(..)
  , EqStore(..)
  , emptyEqStore
  , eqsSubst
  , eqsConj

  -- * The false conjunction
  , falseEqConstrConj

  -- * Queries
  , eqsIsFalse

  -- * Adding equations and disjunctions
  , addEqs
  , addRuleVariants
  , addDisj

  -- * Case splits and name hints
  , performSplit
  , dropNameHintsBound
  , splits
  , splitSize
  , splitExists

  -- * Simplification
  , simp
  , simpDisjunction

  -- * Pretty printing
  , prettyEqStore
) where
import           Logic.Connectives
import           Term.Unification
import           Theory.Text.Pretty
import           Control.Monad.Fresh
import           Control.Monad.Bind
import           Control.Monad.Reader
import           Extension.Prelude
import           Utils.Misc
import           Debug.Trace.Ignore
import           Control.Basics
import           Control.DeepSeq
import           Control.Monad.State   hiding (get, modify, put)
import qualified Control.Monad.State   as MS
import           Data.Binary
import           Data.DeriveTH
import qualified Data.Foldable         as F
import           Data.List
import           Data.Maybe
import qualified Data.Set              as S
import           Extension.Data.Label  hiding (for, get)
import qualified Extension.Data.Label  as L
import           Extension.Data.Monoid
-- | Identifier of a disjunction (a potential case split) in an 'EqStore'.
newtype SplitId = SplitId
    { unSplitId :: Integer  -- ^ the raw numeric identifier
    }
    deriving ( Eq, Ord, Show, Enum, Binary, NFData, HasFrees )
-- | An equation store.
--
-- It consists of a free substitution and a conjunction of disjunctions of
-- fresh substitutions. Each disjunction is tagged with a 'SplitId' under
-- which 'performSplit' can split it.
data EqStore = EqStore {
      _eqsSubst       :: LNSubst
      -- ^ Free substitution; 'applyEqStore' composes new substitutions
      -- into it.
    , _eqsConj        :: Conj (SplitId, S.Set LNSubstVFresh)
      -- ^ Conjunction of tagged disjunctions. An empty disjunction makes
      -- the whole store false (see 'eqsIsFalse').
    , _eqsNextSplitId :: SplitId
      -- ^ Identifier that the next call to 'addDisj' hands out.
    }
  deriving( Eq, Ord )
-- Generate the labels eqsSubst, eqsConj and eqsNextSplitId for the record
-- fields above (used with L.get, set, modify, ... in this module).
$(mkLabels [''EqStore])
-- | The empty equation store: empty free substitution, no disjunctions,
-- and split identifiers starting at zero.
emptyEqStore :: EqStore
emptyEqStore = EqStore
    { _eqsSubst       = emptySubst
    , _eqsConj        = Conj []
    , _eqsNextSplitId = SplitId 0
    }
-- | 'True' iff one of the disjunctions of the store is empty. An empty
-- disjunction is false, so the whole conjunction is false.
eqsIsFalse :: EqStore -> Bool
eqsIsFalse eqStore = any (S.null . snd) (getConj (L.get eqsConj eqStore))
-- | A conjunction that is false: it consists of a single empty disjunction.
--
-- The disjunction is tagged with the reserved identifier @SplitId (-1)@.
-- The identifiers handed out by 'addDisj' start at 0 (see 'emptyEqStore')
-- and are advanced with 'succ', so a negative identifier can never collide
-- with a real split. A positive identifier such as 1 could collide, which
-- would produce duplicate split identifiers in the store.
falseEqConstrConj :: Conj (SplitId, S.Set LNSubstVFresh)
falseEqConstrConj = Conj [ (SplitId (-1), S.empty) ]
-- | Drop the name hints in every disjunct of every disjunction of the
-- store, using 'dropNameHintsLNSubstVFresh'. The free substitution is left
-- untouched.
dropNameHintsBound :: EqStore -> EqStore
dropNameHintsBound = modify eqsConj dropInConj
  where
    dropInConj (Conj disjs) =
        Conj [ (sid, S.map dropNameHintsLNSubstVFresh disj) | (sid, disj) <- disjs ]
-- | Rename the images of a fresh substitution with 'renameDropNamehint'
-- (starting from no bindings and no used fresh names). The domain is kept
-- as is.
dropNameHintsLNSubstVFresh :: LNSubstVFresh -> LNSubstVFresh
dropNameHintsLNSubstVFresh subst =
    let (domVars, images) = unzip (substToListVFresh subst)
        renamed = evalFresh (evalBindT (renameDropNamehint images) noBindings) nothingUsed
    in  substFromListVFresh (zip domVars renamed)
-- | Split identifiers contain no terms, so applying a substitution leaves
-- them unchanged.
instance Apply SplitId where
    apply _subst sid = sid
-- | Folding and mapping over the free variables of an equation store visits
-- its three fields in order: the free substitution, the conjunction of
-- disjunctions, and the next split identifier.
instance HasFrees EqStore where
    foldFrees f (EqStore subst substs nextSplitId) =
        foldFrees f subst <> foldFrees f substs <> foldFrees f nextSplitId
    -- NOTE(review): occurrences are not reported for equation stores; this
    -- always yields mempty.
    foldFreesOcc  _ _ = const mempty
    -- The applicative effects run in field order.
    mapFrees f (EqStore subst substs nextSplitId) =
        EqStore <$> mapFrees f subst
                <*> mapFrees f substs
                <*> mapFrees f nextSplitId
-- | The empty disjunction, which is false.
falseDisj :: S.Set LNSubstVFresh
falseDisj = S.fromList []
-- | The identifiers of the splits in the store, ordered by increasing size
-- of their disjunction (ties keep store order; duplicate
-- identifier/size pairs are removed).
splits :: EqStore -> [SplitId]
splits eqs = map fst (nub (sortOn snd sized))
  where
    sized = [ (idx, S.size disj) | (idx, disj) <- getConj (L.get eqsConj eqs) ]
-- | Whether the store contains a split with the given identifier.
splitExists :: EqStore -> SplitId -> Bool
splitExists eqs sid = isJust (splitSize eqs sid)
-- | Number of disjuncts of the split with the given identifier, or
-- 'Nothing' if there is no such split. If the identifier occurs more than
-- once, its first occurrence is used.
splitSize :: EqStore -> SplitId -> Maybe Int
splitSize eqs sid = S.size <$> lookup sid (getConj (L.get eqsConj eqs))
-- | Add a disjunction to the store under the next free split identifier.
-- Returns the updated store (with the identifier counter advanced) and the
-- identifier that was used.
addDisj :: EqStore -> (S.Set LNSubstVFresh) -> (EqStore, SplitId)
addDisj eqStore disj = (eqStore', sid)
  where
    sid      = L.get eqsNextSplitId eqStore
    -- The new disjunction is prepended to the existing conjunction.
    eqStore' = modify eqsConj (Conj [(sid, disj)] `mappend`)
             . modify eqsNextSplitId succ
             $ eqStore
-- | Split the store on the disjunction with the given identifier.
--
-- Returns 'Nothing' if the store has no such disjunction. Otherwise it
-- returns one store per disjunct. In each, the split disjunction is removed
-- and a singleton disjunction holding just that disjunct is added via
-- 'addDisj', so it gets a new identifier.
performSplit :: EqStore -> SplitId -> Maybe [EqStore]
performSplit eqStore idx
    | (before, (_, disj) : after) <- break ((idx ==) . fst) conjuncts
    = let remaining = set eqsConj (Conj (before ++ after)) eqStore
      in  Just [ fst (addDisj remaining (S.singleton subst)) | subst <- S.toList disj ]
    | otherwise
    = Nothing
  where
    conjuncts = getConj (L.get eqsConj eqStore)
-- | Add a list of equations to the store.
--
-- The equations are first instantiated with the current free substitution.
-- They are then unified with 'unifyLNTermFactored' using the Maude handle,
-- which yields a free substitution and a list of fresh substitutions:
--
--   * no fresh substitution: the equations are unsatisfiable, so the
--     conjunction of the store is replaced by 'falseEqConstrConj';
--
--   * exactly the empty fresh substitution: only the free substitution is
--     applied to the store ('applyEqStore');
--
--   * otherwise: the free substitution is applied and the fresh
--     substitutions are added as a new disjunction. Its 'SplitId' is
--     returned.
addEqs :: MonadFresh m
       => MaudeHandle -> [Equal LNTerm] -> EqStore -> m (EqStore, Maybe SplitId)
addEqs hnd eqs0 eqStore =
    case unifyLNTermFactored eqs `runReader` hnd of
        -- no unifier at all
        (_, []) ->
            return (set eqsConj falseEqConstrConj eqStore, Nothing)
        -- a single trivial fresh part: no new disjunction needed
        (subst, [substFresh]) | substFresh == emptySubstVFresh ->
            return (applyEqStore hnd subst eqStore, Nothing)
        (subst, substs) -> do
            let (eqStore', sid) = addDisj (applyEqStore hnd subst eqStore)
                                          (S.fromList substs)
            return (eqStore', Just sid)

  where
    -- NOTE(review): trace comes from Debug.Trace.Ignore, presumably a no-op.
    eqs = apply (L.get eqsSubst eqStore) $ trace (unlines ["addEqs: ", show eqs0]) $ eqs0
    
-- | Apply a free substitution to the equation store.
--
-- The given substitution is composed with the free substitution of the
-- store, giving @newsubst@. Then every disjunct @s@ of every disjunction is
-- re-normalized as follows:
--
--   1. the images of @s@ are renamed away from the domain of @s@ and the
--      range of @newsubst@ ('renameAvoiding');
--   2. each mapping @lv |-> t@ becomes the equation @newsubst(lv) = t@;
--   3. the equations are unified with 'unifyLNTerm';
--   4. every unifier is restricted to the range of @newsubst@ plus the
--      domain of @s@.
--
-- All resulting unifiers replace @s@ in its disjunction, so a disjunct may
-- turn into zero or more disjuncts.
--
-- Calls 'error' if the domain and the range of the given substitution are
-- not disjoint. The 'trace' in that guard evaluates to 'False'.
-- NOTE(review): that trace is from Debug.Trace.Ignore, presumably a no-op.
applyEqStore :: MaudeHandle -> LNSubst -> EqStore -> EqStore
applyEqStore hnd asubst eqStore
    | dom asubst `intersect` varsRange asubst /= [] || trace (show ("applyEqStore", asubst, eqStore)) False
    = error $ "applyEqStore: dom and vrange not disjoint for `"++show asubst++"'"
    | otherwise
    = modify eqsConj (fmap (second (S.fromList . concatMap applyBound  . S.toList))) $
          set eqsSubst newsubst eqStore
  where
    newsubst = asubst `compose` L.get eqsSubst eqStore
    -- Re-normalize one disjunct; yields the list of its replacements.
    applyBound s = map (restrictVFresh (varsRange newsubst ++ domVFresh s)) $
        (`runReader` hnd) $ unifyLNTerm
          [ Equal (apply newsubst (varTerm lv)) t
          | let slist = substToListVFresh s,
            -- Rename the images apart from the domain of s and the range
            -- of newsubst before turning the mappings into equations.
            let ran = renameAvoiding (map snd slist)
                                     (domVFresh s ++ varsRange newsubst),
            (lv,t) <- zip (map fst slist) ran
          ]
-- | Add the given rule variants to the store as a new disjunction and
-- return the store together with the new split identifier.
--
-- Calls 'error' if the domain of the free substitution shares a variable
-- with the domain of one of the variants.
addRuleVariants :: Disj LNSubstVFresh -> EqStore -> (EqStore, SplitId)
addRuleVariants (Disj substs) eqStore
    | domainsOverlap = error $ "addRuleVariants: Nonempty intersection between domain of variants and free substitution. "
                               ++"This case has not been implemented, add rule variants earlier."
    | otherwise      = addDisj eqStore (S.fromList substs)
  where
    variantDom     = concatMap domVFresh substs
    domainsOverlap = not (null (dom (L.get eqsSubst eqStore) `intersect` variantDom))
-- | Simplify a single disjunction of fresh substitutions.
--
-- The disjunction is placed into an otherwise empty store and simplified
-- with 'simp', which receives @isContr@ unchanged. The result is the free
-- substitution of the simplified store, paired with:
--
--   * 'Nothing' if the disjunction was eliminated entirely;
--
--   * @Just ds@ with the remaining disjuncts otherwise. An empty @ds@
--     means the disjunction became false.
--
-- Calls 'error' if the simplified store holds more than one disjunction.
simpDisjunction :: MonadFresh m
                => MaudeHandle
                -> (LNSubst -> LNSubstVFresh -> Bool)
                -> Disj LNSubstVFresh
                -> m (LNSubst, Maybe [LNSubstVFresh])
simpDisjunction hnd isContr disj0 = do
    eqStore' <- simp hnd isContr eqStore
    return (L.get eqsSubst eqStore', wrap $ L.get eqsConj eqStore')
  where
    eqStore = fst $ addDisj emptyEqStore (S.fromList $ getDisj $ disj0)
    -- Starting from a single disjunction, simplification is expected to
    -- remove or replace it, never to add further conjuncts.
    wrap (Conj [])          = Nothing
    wrap (Conj [(_, disj)]) = Just $ S.toList disj
    wrap conj               =
        error ("EquationStore.simpDisjunction: impossible, unexpected conjunction `"
               ++ show conj ++ "'")
-- | Simplify an equation store by running 'simp1' repeatedly (via
-- 'whileTrue') and return the final store.
simp :: MonadFresh m => MaudeHandle -> (LNSubst -> LNSubstVFresh -> Bool) -> EqStore -> m EqStore
simp hnd isContr eqStore = execStateT fixpoint initial
  where
    fixpoint = whileTrue (simp1 hnd isContr)
    initial  = trace (show ("eqStore", eqStore)) eqStore
-- | One round of simplification.
--
-- Returns 'False' right away if the store is already false. Otherwise it
-- runs every simplification step once, in a fixed order, and returns
-- whether at least one of them changed the store.
simp1 :: MonadFresh m => MaudeHandle -> (LNSubst -> LNSubstVFresh -> Bool) -> StateT EqStore m Bool
simp1 hnd isContr = do
    eqs <- MS.get
    if eqsIsFalse eqs then return False else runSteps eqs
  where
    -- The steps run left to right; each one sees the state produced by the
    -- previous ones. The contradiction predicate is fixed to the free
    -- substitution at the start of the round.
    runSteps eqs = do
        changed <- sequence
            [ simpMinimize (isContr (L.get eqsSubst eqs))
            , simpRemoveRenamings
            , simpEmptyDisj
            , foreachDisj hnd simpSingleton
            , foreachDisj hnd simpAbstractSortedVar
            , foreachDisj hnd simpIdentify
            , foreachDisj hnd simpAbstractFun
            , foreachDisj hnd simpAbstractName
            ]
        trace (show ("simp:", changed)) $ return (or changed)
-- | If 'removeRenamings' would shrink the domain of at least one disjunct,
-- apply it to every disjunct of every disjunction. Returns whether the
-- store was changed.
simpRemoveRenamings :: MonadFresh m => StateT EqStore m Bool
simpRemoveRenamings = do
    conj <- gets (L.get eqsConj)
    let shrinks subst = domVFresh subst /= domVFresh (removeRenamings subst)
        applicable    = F.any (F.any shrinks . snd) conj
    if not applicable
      then return False
      else modM eqsConj (fmap (second $ S.map removeRenamings)) >> return True
-- | If some disjunction is empty, replace the whole conjunction with
-- 'falseEqConstrConj', unless it already is exactly that. Returns whether
-- the store was changed.
simpEmptyDisj :: MonadFresh m => StateT EqStore m Bool
simpEmptyDisj = do
    conj <- getM eqsConj
    let hasEmpty  = F.any ((== falseDisj) . snd) conj
        notMarked = conj /= falseEqConstrConj
    if hasEmpty && notMarked
      then eqsConj =: falseEqConstrConj >> return True
      else return False
-- | A disjunction with exactly one disjunct is dropped. That disjunct,
-- converted with 'freshToFree', is returned as the substitution to apply.
-- Any other disjunction is left alone ('Nothing').
simpSingleton :: MonadFresh m
              => [LNSubstVFresh]
              -> m (Maybe (Maybe LNSubst, [S.Set LNSubstVFresh]))
simpSingleton [subst0] = (\subst -> Just (Just subst, [])) <$> freshToFree subst0
simpSingleton _        = return Nothing
-- | Abstract a common function application.
--
-- If every disjunct maps some variable @v@ to an application of the same
-- function symbol @o@, then @v@ is mapped to @o@ applied to fresh
-- msg-sorted variables in the returned free substitution. Each disjunct
-- then maps those fresh variables to its own arguments instead of mapping
-- @v@. Only the first such candidate @v@ is considered.
--
-- If the argument lists differ in length and @o@ is an AC symbol, only two
-- fresh variables are used: one for the first argument and one for the
-- rest. Any other arity mismatch is reported with 'error'.
simpAbstractFun :: MonadFresh m
                => [LNSubstVFresh]
                -> m (Maybe (Maybe LNSubst, [S.Set LNSubstVFresh]))
simpAbstractFun []             = return Nothing
simpAbstractFun (subst:others) = case commonOperators of
    [] -> return Nothing
    -- All argument lists have the same length: abstract every argument.
    (v, o, argss@(args:_)):_ | all ((==length args) . length) argss -> do
        fvars <- mapM (\_ -> freshLVar "x" LSortMsg) args
        let substs' = zipWith (abstractAll v fvars) (subst:others) argss
            fsubst  = substFromList [(v, fApp o (map varTerm fvars))]
        return $ Just (Just fsubst, [S.fromList substs'])
    -- Lengths differ, but o is AC: abstract into exactly two arguments.
    (v, o@(AC _), argss):_ -> do
        fv1 <- freshLVar "x" LSortMsg
        fv2 <- freshLVar "x" LSortMsg
        let substs' = zipWith (abstractTwo o v fv1 fv2) (subst:others) argss
            fsubst  = substFromList [(v, fApp o (map varTerm [fv1,fv2]))]
        return $ Just (Just fsubst, [S.fromList substs'])
    (_, _ ,_):_ ->
        error "simpAbstract: impossible, invalid arities or List operator encountered."
  where
    -- Candidates (v, o, argss): v is mapped to an application of o in the
    -- first disjunct and in every other disjunct. argss holds the argument
    -- lists, the first disjunct's first.
    commonOperators = do
        (v, viewTerm -> FApp o args) <- substToListVFresh subst
        let images = map (\s -> imageOfVFresh s v) others
            argss  = [ args' | Just (viewTerm -> FApp o' args') <- images, o' == o ]
        guard (length argss == length others)
        return (v, o, args:argss)
    -- Replace the mapping of v by mappings of the fresh variables.
    abstractAll v freshVars s args = substFromListVFresh $
        filter ((/= v) . fst) (substToListVFresh s) ++ zip freshVars args
    abstractTwo o v fv1 fv2 s args = substFromListVFresh $
        filter ((/= v) . fst) (substToListVFresh s) ++ newMappings args
      where
        newMappings []      =
            error "simpAbstract: impossible, AC symbols must have arity >= 2."
        newMappings [a1,a2] = [(fv1, a1), (fv2, a2)]
        -- Always split into the first argument and the remaining ones; the
        -- AC property of o is not used to look for a better split.
        -- NOTE(review): a one-element list also reaches this equation and
        -- yields fApp o [] - presumably unreachable for AC applications.
        newMappings (a:as)  = [(fv1, a),  (fv2, fApp o as)]
-- | Abstract a common constant.
--
-- If every disjunct maps some variable @v@ to the same constant @c@, then
-- the mapping @v |-> c@ becomes the returned free substitution, and @v@ is
-- removed from the domain of every disjunct. Only the first such @v@ is
-- considered.
simpAbstractName :: MonadFresh m
                 => [LNSubstVFresh]
                 -> m (Maybe (Maybe LNSubst, [S.Set LNSubstVFresh]))
simpAbstractName []             = return Nothing
simpAbstractName (subst:others) =
    case commonNames of
      []       -> return Nothing
      (v, c):_ ->
          let dropV s = restrictVFresh (delete v (domVFresh s)) s
          in  return $ Just ( Just (substFromList [(v, c)])
                            , [S.fromList (map dropV (subst:others))] )
  where
    commonNames =
        [ (v, c)
        | (v, c@(viewTerm -> Lit (Con _))) <- substToListVFresh subst
        , all (\s -> imageOfVFresh s v == Just c) others
        ]
-- | Abstract a common, more specifically sorted variable.
--
-- The rewrite applies when the first disjunct maps @v@ to a variable @lx@
-- whose sort is strictly smaller than that of @v@ (per 'sortCompare'), and
-- every other disjunct maps @v@ to a variable of exactly the sort of @lx@.
-- A fresh variable @fv@ of that sort is then created (reusing the name of
-- @v@). The mapping @v |-> fv@ becomes the returned free substitution, and
-- each disjunct maps @fv@ to its own variable instead of mapping @v@.
-- Only the first such @v@ is considered.
simpAbstractSortedVar :: MonadFresh m
                      => [LNSubstVFresh]
                      -> m (Maybe (Maybe LNSubst, [S.Set LNSubstVFresh]))
simpAbstractSortedVar []             = return Nothing
simpAbstractSortedVar (subst:others) = case commonSortedVar of
    []            -> return Nothing
    (v, s, lvs):_ -> do
        fv <- freshLVar (lvarName v) s
        return $ Just (Just $ substFromList [(v, varTerm fv)]
                      , [S.fromList (zipWith (replaceMapping v fv) lvs (subst:others))])
  where
    commonSortedVar = do
        (v, (viewTerm -> Lit (Var lx))) <- substToListVFresh subst
        guard (sortCompare (lvarSort v)  (lvarSort lx) == Just GT)
        let images = map (\s -> imageOfVFresh s v) others
            -- NOTE(review): only variables of exactly the sort of lx are
            -- accepted; a common smaller sort or non-variable terms of that
            -- sort are not considered.
            goodImages = [ ly | Just (viewTerm -> Lit (Var ly)) <- images, lvarSort lx == lvarSort ly]
        guard (length images == length goodImages)
        return (v, lvarSort lx, (lx:goodImages))
    -- Drop the mapping of v and map fv to the disjunct's own variable lv.
    replaceMapping v fv lv sigma =
        substFromListVFresh $ (filter ((/=v) . fst) $ substToListVFresh sigma) ++ [(fv, varTerm lv)]
-- | Identify two variables that always have the same image.
--
-- The rewrite applies to variables @v < v'@ that have equal images in the
-- first disjunct and agree, with a defined image, in every other disjunct.
-- The variable with the larger sort (per 'sortCompare'; @v'@ when the sorts
-- are equal or @v@'s is smaller) becomes @vremove@ and is mapped to the
-- other, @vkeep@, in the returned free substitution. The mapping of
-- @vkeep@ is removed from every disjunct.
-- NOTE(review): the mapping of vremove stays in the disjuncts; it is
-- presumably turned into a constraint on vkeep when 'applyEqStore'
-- instantiates it with the new free substitution.
--
-- Calls 'error' if the two sorts are incomparable.
simpIdentify :: MonadFresh m
             => [LNSubstVFresh]
             -> m (Maybe (Maybe LNSubst, [S.Set LNSubstVFresh]))
simpIdentify []             = return Nothing
simpIdentify (subst:others) = case equalImgPairs of
    []         -> return Nothing
    ((v,v'):_) -> do
        let (vkeep, vremove) = case sortCompare (lvarSort v) (lvarSort v') of
                                 Just GT -> (v', v)
                                 Just _  -> (v, v')
                                 Nothing -> error $ "EquationStore.simpIdentify: impossible, variables with incomparable sorts: "
                                                    ++ show v ++" and "++ show v'
        return $ Just (Just  (substFromList [(vremove, varTerm vkeep)]),
                       [S.fromList (map (removeMappings [vkeep]) (subst:others))])
  where
    -- Ordered pairs (v < v') with equal images in the first disjunct that
    -- agree in all other disjuncts.
    equalImgPairs = do
        (v,t)    <- substToListVFresh subst
        (v', t') <- substToListVFresh subst
        guard (t == t' && v < v' && all (agrees_on v v') others)
        return (v,v')
    agrees_on v v' s =
        imageOfVFresh s v == imageOfVFresh s v' && isJust (imageOfVFresh s v)
    removeMappings vs s = restrictVFresh (domVFresh s \\ vs) s
-- | Minimize every disjunction.
--
-- A disjunction that contains the empty substitution collapses to just
-- that substitution. Otherwise, disjuncts satisfying @isContr@ are dropped.
-- The store is only rewritten, and 'True' returned, if some disjunct is
-- empty or satisfies @isContr@.
simpMinimize :: MonadFresh m => (LNSubstVFresh -> Bool) -> StateT EqStore m Bool
simpMinimize isContr = do
    conj <- MS.gets (L.get eqsConj)
    let applicable = F.any (F.any check . snd) conj
    if not applicable
      then return False
      else do
        MS.modify (set eqsConj (fmap (second minimize) conj))
        return True
  where
    minimize substs
      | S.member emptySubstVFresh substs = S.singleton emptySubstVFresh
      | otherwise                        = S.filter (not . isContr) substs
    check subst = subst == emptySubstVFresh || isContr subst
-- | Try a disjunction rewrite on each disjunction of the store, in order,
-- until it applies to one of them.
--
-- The rewrite @f@ gets the disjuncts as a list. On @Just (msubst, disjs)@
-- the disjunction is replaced in place by @disjs@, each tagged with the
-- original split identifier. If @msubst@ is @Just s@, @s@ is then applied
-- with 'applyEqStore'. Returns whether some disjunction was rewritten;
-- only the first applicable disjunction is rewritten.
-- NOTE(review): if @disjs@ has several elements, they share the same
-- 'SplitId'.
foreachDisj :: MonadFresh m
            => MaudeHandle
            -> ([LNSubstVFresh] -> m (Maybe (Maybe LNSubst, [S.Set LNSubstVFresh])))
            -> StateT EqStore m Bool
foreachDisj hnd f =
    go [] =<< gets (getConj . L.get eqsConj)
  where
    -- lefts holds the already visited disjunctions in reverse order.
    go _     []               = return False
    go lefts ((idx,d):rights) = do
        b <- lift $ f (S.toList d)
        case b of
          Nothing              -> go ((idx,d):lefts) rights
          Just (msubst, disjs) -> do
              eqsConj =: Conj (reverse lefts ++ ((,) idx <$> disjs) ++ rights)
              maybe (return ()) (\s -> MS.modify (applyEqStore hnd s)) msubst
              return True
-- | Pretty-print an equation store.
--
-- The output starts with @CONTRADICTORY@ if 'eqsIsFalse' holds (an empty
-- document otherwise). It then shows the free substitution under the
-- keyword @subst@ and the disjunctions under the keyword @conj@. Each
-- disjunction is prefixed with its split identifier and lists its disjuncts
-- numbered. A disjunct is printed as an existential quantification over
-- its range variables of a conjunction of equations.
prettyEqStore :: HighlightDocument d => EqStore -> d
prettyEqStore eqs@(EqStore substFree (Conj disjs) _nextSplitId) = vcat $
  [if eqsIsFalse eqs then text "CONTRADICTORY" else emptyDoc] ++
  map combine
    [ ("subst", vcat $ prettySubst (text . show) (text . show) substFree)
    , ("conj",  vcat $ map ppDisj disjs)
    ]
  where
    combine (header, d) = fsep [keyword_ header <> colon, nest 2 d]
    ppDisj (idx, substs) =
        text (show (unSplitId idx) ++ ".") <-> numbered' conjs
      where
        conjs  = map ppSubst $ S.toList substs
    -- One mapping shown as the equation "a = b".
    ppEq (a,b) =
      prettyNTerm (lit (Var a)) $$ nest (6::Int) (opEqual <-> prettyNTerm b)
    ppSubst subst = sep
      [ hsep (opExists : map prettyLVar (varsRangeVFresh subst)) <> opDot
      , nest 2 $ fsep $ intersperse opLAnd $ map ppEq $ substToListVFresh subst
      ]
-- | Show an equation store by rendering its pretty-printed form.
instance Show EqStore where
    show eqStore = render (prettyEqStore eqStore)
-- Template-Haskell-derived Binary and NFData instances for EqStore.
$( derive makeBinary ''EqStore)
$( derive makeNFData ''EqStore)