{-|
Module      : Idris.Core.Constraints
Description : Check universe constraints.
Copyright   :
License     : BSD3
Maintainer  : The Idris Community.
-}
module Idris.Core.Constraints ( ucheck ) where

import Idris.Core.TT (ConstraintFC(..), Err'(..), FC(..), TC(..),
                      UConstraint(..), UExp(..))

import Control.Applicative
import Control.Monad.State.Strict
import Data.List (partition)
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Debug.Trace

-- | Check that a list of universe constraints can be satisfied.
ucheck :: S.Set ConstraintFC -> TC ()
ucheck = void . solve 10 . S.filter (not . ignore) . dropUnused
    where
        -- TODO: remove the first ignore clause once Idris.Core.Binary:598 is dealt with
        ignore (ConstraintFC c _) | any (== Var [] (-1)) (varsIn c) = True
        ignore (ConstraintFC (ULE a b) _) = a == b
        ignore _ = False

dropUnused :: S.Set ConstraintFC -> S.Set ConstraintFC
dropUnused xs = let cs = S.toList xs
                    onlhs = countLHS M.empty cs in
                    addIfUsed S.empty onlhs cs
  where
    -- Count the number of times a variable occurs on the LHS of a constraint
    countLHS ms [] = ms
    countLHS ms (c : cs) = let lhvar = getLHS (uconstraint c)
                               num = case M.lookup lhvar ms of
                                          Nothing -> 1
                                          Just v -> v + 1 in
                               countLHS (M.insert lhvar num ms) cs

    -- Only keep a constraint if the variable on the RHS is used elsewhere
    -- on the LHS of a constraint
    addIfUsed cs' lhs [] = cs'
    addIfUsed cs' lhs (c : cs)
         = let rhvar = getRHS (uconstraint c) in
               case M.lookup rhvar lhs of
                    Nothing -> addIfUsed cs' lhs cs
                    Just v -> addIfUsed (S.insert c cs') lhs cs

    getLHS (ULT x _) = x
    getLHS (ULE x _) = x

    getRHS (ULT _ x) = x
    getRHS (ULE _ x) = x


data Var = Var String Int
    deriving (Eq, Ord, Show)

data Domain = Domain Int Int
    deriving (Eq, Ord, Show)

data SolverState =
    SolverState
        { queue       :: Queue
        , domainStore :: M.Map Var ( Domain
                                   , S.Set ConstraintFC        -- constraints that effected this variable
                                   )
        , cons_lhs    :: M.Map Var (S.Set ConstraintFC)
        , cons_rhs    :: M.Map Var (S.Set ConstraintFC)
        }

data Queue = Queue [ConstraintFC] (S.Set UConstraint)


solve :: Int -> S.Set ConstraintFC -> TC (M.Map Var Int)
solve maxUniverseLevel ucs =
    evalStateT (propagate >> extractSolution) initSolverState

    where
        inpConstraints = S.toAscList ucs

        -- | initial solver state.
        --   the queue contains all constraints, the domain store contains the initial domains.
        initSolverState :: SolverState
        initSolverState =
            let
                (initUnaryQueue, initQueue) = partition (\ c -> length (varsIn (uconstraint c)) == 1) inpConstraints
            in
                SolverState
                    { queue = Queue (initUnaryQueue ++ initQueue) (S.fromList (map uconstraint (initUnaryQueue ++ initQueue)))
                    , domainStore = M.fromList
                        [ (v, (Domain 0 maxUniverseLevel, S.empty))
                        | v <- ordNub [ v
                                      | ConstraintFC c _ <- inpConstraints
                                      , v <- varsIn c
                                      ]
                        ]
                    , cons_lhs = constraintsLHS
                    , cons_rhs = constraintsRHS
                    }

        lhs (ULT (UVar ns x) _) = Just (Var ns x)
        lhs (ULE (UVar ns x) _) = Just (Var ns x)
        lhs _ = Nothing

        rhs (ULT _ (UVar ns x)) = Just (Var ns x)
        rhs (ULE _ (UVar ns x)) = Just (Var ns x)
        rhs _ = Nothing

        -- | a map from variables to the list of constraints the variable occurs in. (in the LHS of a constraint)
        constraintsLHS :: M.Map Var (S.Set ConstraintFC)
        constraintsLHS = M.fromListWith S.union
            [ (v, S.singleton (ConstraintFC c fc))
            | (ConstraintFC c fc) <- inpConstraints
            , let vars = varsIn c
            , length vars > 1               -- do not register unary constraints
            , v <- vars
            , lhs c == Just v
            ]

        -- | a map from variables to the list of constraints the variable occurs in. (in the RHS of a constraint)
        constraintsRHS :: M.Map Var (S.Set ConstraintFC)
        constraintsRHS = M.fromListWith S.union
            [ (v, S.singleton (ConstraintFC c fc))
            | (ConstraintFC c fc) <- inpConstraints
            , let vars = varsIn c
            , length vars > 1               -- do not register unary constraints
            , v <- vars
            , rhs c == Just v
            ]

        -- | this is where the actual work is done.
        --   dequeue the first constraint,
        --   filter domains,
        --   update domains (possibly resulting in a domain wipe out),
        --   until the queue is empty.
        propagate :: StateT SolverState TC ()
        propagate = do
            mcons <- nextConstraint
            case mcons of
                Nothing -> return ()
                Just (ConstraintFC cons fc) -> do
                    case cons of
                        ULE a b -> do
                            Domain lowerA upperA <- domainOf a
                            Domain lowerB upperB <- domainOf b
                            when (upperB < upperA) $ updateUpperBoundOf (ConstraintFC cons fc) a upperB
                            when (lowerA > lowerB) $ updateLowerBoundOf (ConstraintFC cons fc) b lowerA
                        ULT a b -> do
                            Domain lowerA upperA <- domainOf a
                            Domain lowerB upperB <- domainOf b
                            let upperB_pred = pred upperB
                            let lowerA_succ = succ lowerA
                            when (upperB_pred < upperA) $ updateUpperBoundOf (ConstraintFC cons fc) a upperB_pred
                            when (lowerA_succ > lowerB) $ updateLowerBoundOf (ConstraintFC cons fc) b lowerA_succ
                    propagate

        -- | extract a solution from the state.
        extractSolution :: (MonadState SolverState m, Functor m) => m (M.Map Var Int)
        extractSolution = M.map (extractValue . fst) <$> gets domainStore

        extractValue :: Domain -> Int
        extractValue (Domain x _) = x

        -- | dequeue the first constraint.
        nextConstraint :: MonadState SolverState m => m (Maybe ConstraintFC)
        nextConstraint = do
            Queue list set <- gets queue
            case list of
                [] -> return Nothing
                (q:qs) -> do
                    modify $ \ st -> st { queue = Queue qs (S.delete (uconstraint q) set) }
                    return (Just q)

        -- | look up the domain of a variable from the state.
        --   for convenience, this function also accepts UVal's and returns a singleton domain for them.
        domainOf :: MonadState SolverState m => UExp -> m Domain
        domainOf (UVar ns var) = gets (fst . (M.! Var ns var) . domainStore)
        domainOf (UVal val) = return (Domain val val)

        asPair :: Domain -> (Int, Int)
        asPair (Domain x y) = (x, y)

        updateUpperBoundOf :: ConstraintFC -> UExp -> Int -> StateT SolverState TC ()
        updateUpperBoundOf suspect (UVar ns var) upper = do
            doms <- gets domainStore
            let (oldDom@(Domain lower _), suspects) = doms M.! Var ns var
            let newDom = Domain lower upper
            when (wipeOut newDom) $
              lift $ Error $
                UniverseError (ufc suspect) (UVar ns var)
                              (asPair oldDom) (asPair newDom)
                              (suspect : S.toList suspects)
            modify $ \ st -> st { domainStore = M.insert (Var ns var) (newDom, S.insert suspect suspects) doms }
            addToQueueRHS (uconstraint suspect) (Var ns var)
        updateUpperBoundOf _ UVal{} _ = return ()

        updateLowerBoundOf :: ConstraintFC -> UExp -> Int -> StateT SolverState TC ()
        updateLowerBoundOf suspect (UVar ns var) lower = do
            doms <- gets domainStore
            let (oldDom@(Domain _ upper), suspects) = doms M.! Var ns var
            let newDom = Domain lower upper
            when (wipeOut newDom) $
              lift $ Error $
                UniverseError (ufc suspect) (UVar ns var)
                              (asPair oldDom) (asPair newDom)
                              (suspect : S.toList suspects)
            modify $ \ st -> st { domainStore = M.insert (Var ns var) (newDom, S.insert suspect suspects) doms }
            addToQueueLHS (uconstraint suspect) (Var ns var)
        updateLowerBoundOf _ UVal{} _ = return ()

        -- | add all constraints (with the given var on the lhs) to the queue
        addToQueueLHS :: MonadState SolverState m => UConstraint -> Var -> m ()
        addToQueueLHS thisCons var = do
            clhs <- gets cons_lhs
            case M.lookup var clhs of
                Nothing -> return ()
                Just cs -> do
                    Queue list set <- gets queue
                    let set' = S.insert thisCons set
                    let newCons = [ c | c <- S.toList cs, uconstraint c `S.notMember` set' ]
                    if null newCons
                        then return ()
                        else modify $ \ st -> st { queue = Queue (list ++ newCons)
                                                                 (S.union set (S.fromList (map uconstraint newCons))) }

        -- | add all constraints (with the given var on the rhs) to the queue
        addToQueueRHS :: MonadState SolverState m => UConstraint -> Var -> m ()
        addToQueueRHS thisCons var = do
            crhs <- gets cons_rhs
            case M.lookup var crhs of
                Nothing -> return ()
                Just cs -> do
                    Queue list set <- gets queue
                    let set' = S.insert thisCons set
                    let newCons = [ c | c <- S.toList cs, uconstraint c `S.notMember` set' ]
                    if null newCons
                        then return ()
                        else modify $ \ st -> st { queue = Queue (list ++ newCons)
                                                                 (insertAll (map uconstraint newCons) set) }

        insertAll [] s = s
        insertAll (x : xs) s = insertAll xs (S.insert x s)

        -- | check if a domain is wiped out.
        wipeOut :: Domain -> Bool
        wipeOut (Domain l u) = l > u

ordNub :: Ord a => [a] -> [a]
ordNub = S.toList . S.fromList

-- | variables in a constraint
varsIn :: UConstraint -> [Var]
varsIn (ULT a b) = [ Var ns v | UVar ns v <- [a,b] ]
varsIn (ULE a b) = [ Var ns v | UVar ns v <- [a,b] ]