module Language.Haskell.FreeTheorems.Variations.PolySeq.ConstraintSolver where

import Language.Haskell.FreeTheorems.Variations.PolySeq.PrettyPrint(prettyConstraint) --only for debugging
import Language.Haskell.FreeTheorems.Variations.PolySeq.Syntax
import Language.Haskell.FreeTheorems.Variations.PolySeq.AlgCommon
    ( collectOne
    , substLabel
    , getUsedLabels
    , getUsedExtraLabels
    , removeTrue
    )
import qualified Data.List as List
import qualified Data.Map as Map
import Data.Generics(everywhere, mkQ, mkT)
import Language.Haskell.FreeTheorems.Variations.PolySeq.Debug
import Control.Monad(mplus)
import Data.Function

-- Debugging hooks: trace is used by the simplification and reduction steps,
-- trace1 by solveConstraint/slvConstr and trace2 by makeMinimalTypes.
-- All three are bound to trace_ignore from the Debug module.
-- NOTE(review): presumably trace_ignore drops the message and returns its second
-- argument unchanged - confirm in ...PolySeq.Debug.
trace2 = trace_ignore
trace = trace_ignore
trace1 = trace_ignore

-- | Simplifies the initial constraint and propagates the resulting label
--   substitutions into the term and the type.
--   The simplified constraint is cleaned with 'removeTrue'. It is replaced by
--   'Fls' when 'checkContradiction' finds an unsatisfiable part.
simplifyConstraint :: (Term,Typ,Constraint) -> (Term,Typ,Constraint)
simplifyConstraint input = (t, tau, checkContradiction . removeTrue $ c)
  where
    (t, tau, c) = simpConstraint input

-- | Subfunction of 'simplifyConstraint'.
--   Repeatedly removes trivial parts of the constraint and eliminates labels whose
--   value is forced. Every substitution is applied to the term, the type and the
--   constraint. The forced cases are:
--
--   * a top-level equation @l1 = l2@ (see 'findEq'): the non-value side is
--     replaced by the other side, so a concrete value is never substituted away;
--
--   * an inequation @epsilon <= l@ (see 'findGtEpsilon'): @l@ is replaced by
--     @epsilon@;
--
--   * an inequation @l <= nbr@ (see 'findLtNbr'): @l@ is replaced by @nbr@.
--
--   An equation between two different values (e.g. @nbr = epsilon@) is a
--   contradiction. Simplification stops and leaves that equation in the returned
--   constraint, so that 'checkContradiction' can report it.
--
--   NOTE(review): the inequation cases substitute @l@ even if it is a value
--   (e.g. @epsilon <= nbr@) - confirm this cannot happen here.
simpConstraint :: (Term,Typ,Constraint) -> (Term,Typ,Constraint)
simpConstraint (t,tau,c') =
    case findEq c of
      Just (Eq l1 l2) -> trace "found Eq by findEq" $
        case (l1, l2) of
          -- equal values were already turned into Tru by removeUseless, so this
          -- is a contradiction; keep it visible instead of rewriting values
          (LVal _, LVal _) -> (t,tau,c)
          -- concrete value on the left: replace the other side by it
          (LVal _, _)      -> substAll l1 l2
          -- otherwise l1 is the side to eliminate
          _                -> substAll l2 l1
      Nothing ->
        case findGtEpsilon c of
          Just (Leq _ l) -> trace "found GtEpsilon" (substAll (LVal Epsilon) l)
          Nothing ->
            case findLtNbr c of
              Just (Leq l _) -> trace "found LtNbr" (substAll (LVal Nbr) l)
              Nothing        -> (t,tau,c)
  where
    c = removeUseless c'
    -- replaces the label old by new in term, type and constraint and continues
    substAll new old =
      simpConstraint (substLabel new old t, substLabel new old tau, substLabel new old c)

-- | Removes equations and inequations that are always fulfilled, as decided by
--   'rmUseLess'.
--   The rewrite is applied with 'everywhere', i.e. bottom-up: the parts of a
--   constraint are simplified before the constraint that contains them.
removeUseless :: Constraint -> Constraint
removeUseless c = everywhere (mkT rmUseLess) c

-- | Rewrites a single constraint node, used generically by 'removeUseless'.
--
--   * A conjunction with a trivial conjunct (see 'useLess') is replaced by the
--     other conjunct.
--
--   * An implication with premise @epsilon = epsilon@ becomes its conclusion.
--
--   * An implication with premise @nbr = epsilon@ becomes Tru.
--
--   * Any other trivial (in)equation becomes Tru; everything else is unchanged.
--
--   NOTE(review): under the bottom-up traversal of 'removeUseless' the premise
--   @epsilon = epsilon@ is itself rewritten to Tru first, so that implication
--   rule only fires when this function is called directly.
rmUseLess :: Constraint -> Constraint
rmUseLess (Conj c1 c2)
  | useLess c1 = c2
  | useLess c2 = c1
  | otherwise  = Conj c1 c2
rmUseLess (Impl (Eq (LVal Epsilon) (LVal Epsilon)) c2) = c2
rmUseLess (Impl (Eq (LVal Nbr) (LVal Epsilon)) _)      = Tru
rmUseLess c
  | useLess c = Tru
  | otherwise = c

-- | Checks whether a constraint is trivially fulfilled. This holds for:
--
--   * an equation between identical labels;
--
--   * @l <= epsilon@ and @nbr <= l@ for any label @l@;
--
--   * an inequation between the same label variable.
--
--   Every other constraint is reported as non-trivial.
useLess :: Constraint -> Bool
useLess c =
  case c of
    Eq  l1 l2                               -> l1 == l2
    Leq _ (LVal Epsilon)                    -> True
    Leq (LVal Nbr) _                        -> True
    Leq (LVar (LabVar i)) (LVar (LabVar j)) -> i == j
    _                                       -> False

-- | Returns the first equation of a constraint, searching its conjunction tree
--   from left to right.
--   Equations nested in any other constructor, in particular in the body of an
--   implication, are not considered.
findEq :: Constraint -> Maybe Constraint
findEq (Conj c1 c2) = findEq c1 `mplus` findEq c2
findEq c@(Eq _ _)   = Just c
findEq _            = Nothing


-- | Returns the first @epsilon <= x@ inequation of a constraint, if there is one.
--   The search is done by 'collectOne' with 'fndGtEpsilon' as the query.
--   NOTE(review): the original note says inequations inside the body of an
--   implication are found as well. That depends on collectOne's traversal -
--   confirm.
findGtEpsilon :: Constraint -> Maybe Constraint
findGtEpsilon c = collectOne (mkQ Nothing fndGtEpsilon) c


-- | Query used by 'findGtEpsilon': matches an inequation whose left-hand side is
--   the value epsilon, and rejects everything else.
fndGtEpsilon :: Constraint -> Maybe Constraint
fndGtEpsilon c@(Leq (LVal Epsilon) _) = Just c
fndGtEpsilon _                        = Nothing


-- | Returns the first @x <= nbr@ inequation of a constraint, if there is one.
--   The search is done by 'collectOne' with 'fndLtNbr' as the query.
--   NOTE(review): the original note says inequations inside the body of an
--   implication are found as well. That depends on collectOne's traversal -
--   confirm.
findLtNbr :: Constraint -> Maybe Constraint
findLtNbr c = collectOne (mkQ Nothing fndLtNbr) c


-- | Query used by 'findLtNbr': matches an inequation whose right-hand side is
--   the value nbr, and rejects everything else.
fndLtNbr :: Constraint -> Maybe Constraint
fndLtNbr c@(Leq _ (LVal Nbr)) = Just c
fndLtNbr _                    = Nothing


-- | Returns Fls if 'fndContr' finds a part of the constraint that can never be
--   fulfilled. Otherwise the constraint is returned unchanged.
checkContradiction :: Constraint -> Constraint
checkContradiction c = maybe c (const Fls) (fndContr c)


-- | Auxiliary function used by 'checkContradiction'.
--   Searches the conjunction tree of a constraint for a part that can never be
--   fulfilled and returns the first one found. Such parts are:
--
--   * the equations @nbr = epsilon@ and @epsilon = nbr@;
--
--   * the inequation @epsilon <= nbr@;
--
--   * a conjunct that is Fls itself.
--
--   Parts nested inside other constructors (e.g. implications) are not inspected.
fndContr :: Constraint -> Maybe Constraint
fndContr c =
    case c of
      Eq  (LVal Nbr) (LVal Epsilon) -> Just c
      Eq  (LVal Epsilon) (LVal Nbr) -> Just c
      Leq (LVal Epsilon) (LVal Nbr) -> Just c
      -- a conjunction containing Fls is unsatisfiable as a whole
      Fls                           -> Just c
      Conj c1 c2                    -> fndContr c1 `mplus` fndContr c2
      _                             -> Nothing
				       

-- | Enumerates all solutions of a constraint.
--   The first component of the result holds the numbers of the label variables
--   used in the constraint, sorted and without duplicates. The second component
--   holds one list of values per solution. Within each solution the values
--   appear in the same order as the variables in the first component.
solveConstraint :: Constraint -> ([Int],[[LabVal]])
solveConstraint c =
    let labConstr = getUsedLabels c
        -- one entry per variable, so positions line up with the (unique) keys
        -- of every solution map
        vars      = List.sort (List.nub labConstr)
        -- Map.elems yields the values in ascending key order, matching vars
        vals      = map Map.elems (slvConstr labConstr Map.empty c)
    in
    trace1 ("solveConstraint produces\n" ++ show vars ++ "with values\n" ++ show vals ++ "\n") (vars,vals)


-- | Used by 'solveConstraint'. Enumerates all complete instantiations of the
--   given variables by depth-first search:
--
--   * each free variable is tried with Nbr first and Epsilon second;
--
--   * 'reduceConstraint' propagates the consequences of that choice;
--
--   * a branch is abandoned as soon as the partial instantiation leads to a
--     contradiction.
--
--   The first argument lists the variables that are still free. The map records
--   all variables instantiated so far. Each result is a complete instantiation
--   (no free variables left), mapping variable numbers to Epsilon or Nbr.
slvConstr :: [Int] -> Map.Map Int LabVal -> Constraint -> [Map.Map Int LabVal]
slvConstr []         assigned _ = [assigned]
slvConstr (i : rest) assigned c = concatMap tryValue [Nbr, Epsilon]
  where
    -- instantiates variable i with val and continues with the remaining variables
    tryValue val =
      let c' = substLabel (LVal val) (LVar (LabVar i)) c
      in trace1 ("starting solve with " ++ show (prettyConstraint c')) $
           case reduceConstraint [(i, val)] c' of
             Nothing            -> []
             Just (forced, c'') ->
               let assigned' = Map.union assigned (Map.fromList forced)
                   fixedVars = map fst forced
                   rest'     = filter (`notElem` fixedVars) rest
               in slvConstr rest' assigned' c''


-- | Takes the variable instantiations already applied to a constraint, together
--   with the constraint itself. After removing trivial parts, it repeatedly
--   derives further forced instantiations, substitutes them and adds them to the
--   front of the instantiation list. The forced instantiations come from:
--
--   * a top-level equation between a variable and a value;
--
--   * an @epsilon <= x@ inequation (x becomes epsilon);
--
--   * an @x <= nbr@ inequation (x becomes nbr).
--
--   Returns Nothing as soon as 'checkContradiction' reports Fls. Otherwise it
--   returns the extended instantiation list and the reduced constraint.
--   Calls 'error' ("strange Constraint") when a found part has an unexpected
--   shape.
--   NOTE(review): equations between two variables are such an unexpected shape -
--   confirm they cannot reach this point.
reduceConstraint :: [(Int,LabVal)] -> Constraint -> Maybe ([(Int,LabVal)],Constraint)
reduceConstraint as c'
    | checkContradiction c == Fls = Nothing
    | otherwise =
        case findEq c of
          Just (Eq l1 l2) -> trace "found Eq by findEq" $
            case (l1, l2) of
              (LVar (LabVar i), LVal w) -> reduceConstraint ((i,w):as) (substLabel l2 l1 c)
              (LVal w, LVar (LabVar i)) -> reduceConstraint ((i,w):as) (substLabel l1 l2 c)
              _                         -> strange (Eq l1 l2)
          Nothing ->
            case findGtEpsilon c of
              Just (Leq _ (LVar (LabVar i))) -> trace "found GtEpsilon" $
                reduceConstraint ((i,Epsilon):as) (substLabel (LVal Epsilon) (LVar (LabVar i)) c)
              Just other -> strange other
              Nothing ->
                case findLtNbr c of
                  Just (Leq (LVar (LabVar i)) _) -> trace "found LtNbr" $
                    reduceConstraint ((i,Nbr):as) (substLabel (LVal Nbr) (LVar (LabVar i)) c)
                  Just other -> strange other
                  -- c was checked for contradictions above and is unchanged since
                  Nothing    -> Just (as,c)
  where
    c = removeTrue . removeUseless $ c'
    -- reports a found constraint part that the reduction cannot handle
    strange part = error ("strange Constraint: " ++ show part ++ "\n")


-- | Restricts the instantiations to the label variables that appear in the given
--   term and type. Duplicate instantiations are removed afterwards.
filterTermAndTyp :: Term -> Typ -> ([Int],[[LabVal]]) -> ([Int],[[LabVal]])
filterTermAndTyp t tau res = (keptVars, List.nub keptVals)
  where
    labTermTyp           = getUsedExtraLabels (getUsedLabels t) tau
    (keptVars, keptVals) = filterLabVars labTermTyp res


-- | Restricts the instantiations to the label variables that appear in the given
--   type. Duplicate instantiations are removed afterwards.
filterTyp :: Typ -> ([Int],[[LabVal]]) -> ([Int],[[LabVal]])
filterTyp tau res = (keptVars, List.nub keptVals)
  where
    (keptVars, keptVals) = filterLabVars (getUsedLabels tau) res


-- | Auxiliary function used by 'filterTermAndTyp' and 'filterTyp'.
--   Keeps only those variables of @vars@ that are members of @varList@. In every
--   instantiation of @vals@ it keeps the values at the same positions.
--   Each instantiation is expected to hold one value per variable, in the order
--   of @vars@. This is done without partial list functions.
filterLabVars :: Eq a => [a] -> ([a],[[b]]) -> ([a],[[b]])
filterLabVars varList (vars,vals) =
    (filter keep vars, map keepPositions vals)
  where
    keep x = x `elem` varList
    -- pairs every value with its variable and keeps the selected ones
    keepPositions row = [ v | (x, v) <- zip vars row, keep x ]


-- | Takes a type and a set of instantiations for some of its label variables.
--   The first component of the tuple holds the variable numbers, the second one
--   list of values per instantiation, in the order of the variables.
--   Returns one type per instantiation, in which those variables are replaced by
--   the given values. The type is only partly instantiated when other label
--   variables remain.
makeTypes :: Typ -> ([Int],[[LabVal]]) -> [Typ]
makeTypes tau (vars, res) = [ instantiate (zip vars vals) | vals <- res ]
  where
    -- applies the substitutions one after another, in list order
    instantiate = List.foldl' step tau
    step typ (i, val) = substLabel (LVal val) (LVar (LabVar i)) typ


-- | Turns a type and a set of instantiations for some of its label variables into
--   completely instantiated types with minimal logical relations.
--   The first component of the tuple lists the instantiated variable numbers. The
--   second component holds one list of values per instantiation, in the order of
--   the variables.
--   The result is computed in three steps:
--
--   1. 'getMinimal' keeps only the instantiations that are not dominated by
--      another given instantiation. They are compared position by position
--      against the optimal values that 'getOptimal' reports for the instantiated
--      variables (smaller logical relation is better).
--
--   2. Each remaining instantiation is applied to the type with 'makeTypes'.
--
--   3. 'getOptTypes' sets every label variable not covered by the instantiations
--      to its optimal value. A variable whose optimum is 'Non' yields one type
--      with epsilon and one with nbr.
makeMinimalTypes :: Typ -> ([Int],[[LabVal]]) -> [Typ]
makeMinimalTypes tau (vars,valss) =
    -- NOTE(review): the positions of setVarOpt are assumed to line up with the
    -- values in valss. That needs both to be sorted by variable number, and every
    -- variable in vars to have an entry in getOptimal tau - confirm.
    let (setVarOpt,unsetVarOpt) = List.partition (\x->(fst x) `elem` vars) (getOptimal tau) --should still be sorted
        types = makeTypes tau (vars, 
			       getMinimal (snd.unzip$
					   (trace2 ("the setOptVars are: " ++ show setVarOpt ++ "\n\n") 
					    setVarOpt)) 
			                  valss)
    in
    concat (map (getOptTypes unsetVarOpt) types)
  where
    -- lessEqual: True iff, at every position whose optimum o is not Non, vals2
    -- has the optimal value wherever vals1 has it, i.e. vals2 is at least as
    -- good as vals1.
    -- NOTE(review): the former comment called this "vals1 is a subset of vals2",
    -- which appears to be the reverse of what the code checks - confirm.
    leq opts vals1 vals2 = --smaller is better -> opts has the smallest logical relation
      case opts of
       [] -> True
       o:os -> let v1:vs1 = vals1
                   v2:vs2 = vals2
               in
               if o == Non then leq os vs1 vs2
               else if LVal v1 == o && LVal v2 /= o then False 
                    else leq os vs1 vs2
    -- notGreaterEqual: True iff, at some position whose optimum o is not Non,
    -- vals2 has the optimal value and vals1 does not.
    -- It returns False exactly when vals1 is at least as good as vals2 at every
    -- position, i.e. when the logical relation of vals2 is a superset of that of
    -- vals1.
    -- Note: since this is only a partial order, leq /= notgeq.
    notgeq opts vals1 vals2 = --smaller is better -> opts has the smallest logical relation
      case opts of
       [] -> False
       o:os -> let v1:vs1 = vals1
                   v2:vs2 = vals2
               in
               if o == Non then notgeq os vs1 vs2
               else if LVal v2 == o && LVal v1 /= o then True
                    else notgeq os vs1 vs2
    -- Keeps the optimal (minimal logical relation) instantiations among those
    -- produced by the constraint. An instantiation x is dropped if a later one
    -- is at least as good (leq opts x y). Otherwise x is kept, and every later y
    -- that x is at least as good as (notgeq opts x y is False) is discarded.
    -- Usually this covers only a part of the label variables of the type.
    getMinimal opts valss =
      case valss of
        [] -> []
        x:xs -> case List.find (leq opts x) xs of
                Nothing -> x : getMinimal opts (filter (notgeq opts x) xs)
                Just _  -> getMinimal opts xs
    -- Instantiates the label variables that the constraint solution left free
    -- with their optimal value. Non yields both the epsilon and the nbr variant
    -- (epsilon first). Entries with any other label (e.g. LVar) are not handled.
    getOptTypes opts typ =
      case opts of
        [] -> [typ]
        (i,opt):xs -> case opt of
                        Non          -> (getOptTypes xs (substLabel (LVal Epsilon) (LVar (LabVar i)) typ))
					++ (getOptTypes xs (substLabel (LVal Nbr) (LVar (LabVar i)) typ))
                        LVal Nbr     -> getOptTypes xs (substLabel (LVal Nbr)     (LVar (LabVar i)) typ)
                        LVal Epsilon -> getOptTypes xs (substLabel (LVal Epsilon) (LVar (LabVar i)) typ)
                     

-- | Returns the optimal (minimal logical relation) instantiation of every label
--   variable of a type, sorted by variable number. Each entry is LVal Epsilon,
--   LVal Nbr or Non. Non means that both values would lead to a minimal type.
--   'getOpt' may report the same variable several times. Reports for one
--   variable that disagree are merged into Non; this is the only place where Non
--   is produced.
getOptimal :: Typ -> [(Int,Label)]
getOptimal tau = concatMap merge (List.groupBy ((==) `on` fst) sorted)
  where
    sorted = List.sortBy (compare `on` fst) (getOpt True tau)
    -- all reports for one variable agree: keep that value, otherwise Non
    merge grp = case grp of
      []         -> []
      (i, l) : _ -> [(i, if all ((== l) . snd) grp then l else Non)]


-- | Collects, for every labelled arrow and forall in a type, the optimal value
--   (always LVal Nbr or LVal Epsilon) of its label variable.
--   The Bool parameter says whether the logical relation has to be minimised
--   (True) or maximised (False). It is flipped for the argument of an arrow.
--   Arrows get the optimal value and foralls the opposite one.
getOpt :: Bool -> Typ -> [(Int,Label)]
getOpt minimise tau =
    case tau of
      TVar _               -> []
      TArrow lab tau1 tau2 -> labelled lab best ++ getOpt (not minimise) tau1 ++ getOpt minimise tau2
      TAll lab _ tau1      -> labelled lab worst ++ getOpt minimise tau1
      TList tau1           -> getOpt minimise tau1
      TInt                 -> []
      TBool                -> []
  where
    best  = if minimise then LVal Epsilon else LVal Nbr
    worst = if minimise then LVal Nbr     else LVal Epsilon
    -- only label variables are reported; concrete labels contribute nothing
    labelled (LVar (LabVar i)) val = [(i, val)]
    labelled _                 _   = []