{-# LANGUAGE ScopedTypeVariables #-}

module Idris.TypeSearch (
  searchByType, searchPred, defaultScoreFunction
) where

import Control.Applicative (Applicative (..), (<$>), (<*>), (<|>))
import Control.Arrow (first, second, (&&&), (***))
import Control.Monad (guard)

import Data.List (find, partition, (\\))
import Data.Map (Map)
import qualified Data.Map as M
import Data.Maybe (catMaybes, fromMaybe, isJust, maybeToList, mapMaybe)
import Data.Monoid (Monoid (mempty, mappend))
import Data.Ord (comparing)
import qualified Data.PriorityQueue.FingerTree as Q
import Data.Set (Set)
import qualified Data.Set as S
import qualified Data.Text as T (pack, isPrefixOf)
import Data.Traversable (traverse)

import Idris.AbsSyntax (addUsingConstraints, addImpl, getIState, putIState, implicit)
import Idris.AbsSyntaxTree (class_instances, ClassInfo, defaultSyntax, eqTy, Idris,
  IState (idris_classes, idris_docstrings, tt_ctxt, idris_outputmode),
  implicitAllowed, OutputMode(..), prettyDocumentedIst, PTerm, toplevel)
import Idris.Core.Evaluate (Context (definitions), Def (Function, TyDecl, CaseOp), normaliseC)
import Idris.Core.TT hiding (score)
import Idris.Core.Unify (match_unify)
import Idris.Delaborate (delabTy)
import Idris.Docstrings (noDocs, overview)
import Idris.Elab.Type (elabType)
import Idris.Output (iRenderOutput, iPrintResult, iRenderResult)

import Prelude hiding (pred)

import Util.Pretty (text, char, vsep, (<>), Doc)

-- | Elaborate the given term as a type, search every definition in the
-- context for types matching it (using 'searchPred'), and output at most
-- @numLimit@ matches, each prefixed by a symbol from 'displayScore'.
-- The interpreter state captured before elaboration is restored afterwards.
searchByType :: PTerm -> Idris ()
searchByType pterm = do
  pterm' <- addUsingConstraints syn emptyFC pterm
  pterm'' <- implicit toplevel syn name pterm'
  i <- getIState
  -- NOTE(review): pterm''' is bound but never used below; 'elabType' is
  -- handed pterm' instead -- confirm this is intended.
  let pterm'''  = addImpl i pterm''
  ty <- elabType toplevel syn (fst noDocs) (snd noDocs) emptyFC [] name pterm'
  putIState i -- don't actually make any changes
  let names = searchUsing searchPred i ty
  let names' = take numLimit names
  -- one rendered line per match: score symbol, a space, then the
  -- documented name (name, delaborated type, docstring overview if any)
  let docs =
       [ let docInfo = (n, delabTy i n, fmap (overview . fst) (lookupCtxtExact n (idris_docstrings i))) in
         displayScore theScore <> char ' ' <> prettyDocumentedIst i docInfo
                | (n, theScore) <- names']
  -- raw mode renders each result separately; IDE mode sends one stacked doc
  case idris_outputmode i of
    RawOutput _  -> do mapM_ iRenderOutput docs
                       iPrintResult ""
    IdeSlave _ _ -> iRenderResult (vsep docs)
  where
    numLimit = 50 -- maximum number of results shown
    syn = defaultSyntax { implicitAllowed = True } -- syntax
    name = sMN 0 "searchType" -- name

-- | Run a type-directed search with the supplied match predicate.
-- The predicate receives the normalised query type together with the
-- normalised types of every non-special definition in the context.
searchUsing :: (IState -> Type -> [(Name, Type)] -> [(Name, a)])
  -> IState -> Type -> [(Name, a)]
searchUsing pred istate ty = pred istate (normaliseC ctxt [] ty) candidates
  where
  ctxt = tt_ctxt istate
  -- every searchable definition, grouped by the outer context key
  candidates = concat
    [ M.toAscList (M.mapMaybe (candidateType root) inner)
    | (root, inner) <- M.toAscList (definitions ctxt) ]
  -- the normalised type of a definition, unless it is filed under a special key
  candidateType root def
    | special root = Nothing
    | otherwise = normaliseC ctxt [] <$> typeFromDef def
  -- names that are excluded from search results
  special :: Name -> Bool
  special nm = case nm of
    NS inner _ -> special inner
    SN _ -> True
    UN txt -> T.pack "default#" `T.isPrefixOf` txt
              || txt `elem` map T.pack ["believe_me", "really_believe_me"]
    _ -> False

-- | The default search predicate: bulk-match the query type against the
-- candidates, giving up once the score exceeds a fixed limit.
searchPred :: IState -> Type -> [(Name, Type)] -> [(Name, Score)]
searchPred istate ty1 = matchTypesBulk istate scoreLimit ty1
  where
  scoreLimit = 100


-- | Extract the declared type from a context entry, if the definition
-- is of a kind that carries one.
typeFromDef :: (Def, b, c, d) -> Maybe Type
typeFromDef (def, _, _, _) = case def of
  Function ty _ -> Just ty
  TyDecl _ ty -> Just ty
  -- Operator ty _ _ -> Just ty
  CaseOp _ ty _ _ _ _ -> Just ty
  _ -> Nothing


-- | Reverse the edges of a directed acyclic graph given as adjacency sets:
-- in the result each node is paired with the set of keys of those nodes
-- whose edge set contained it in the input. Node order is preserved.
reverseDag :: Ord k => [((k, a), Set k)] -> [((k, a), Set k)]
reverseDag graph = [ (node, pointingAt key) | (node@(key, _), _) <- graph ]
  where
  -- keys of all nodes that have an edge to the given key
  pointingAt key =
    S.fromList [ src | ((src, _), targets) <- graph, key `S.member` targets ]

-- run vToP first!
-- | Split a function type into its arguments and return type, building a
-- dependency graph over the arguments.
-- Returns, in binding order:
--   * each kept argument (name and type) paired with the set of bound
--     names used in its type,
--   * the arguments whose type satisfied the removal predicate,
--   * the type that remains after all leading Pi binders.
computeDagP :: Ord n
  => (TT n -> Bool) -- ^ filter to remove some arguments
  -> TT n
  -> ([((n, TT n), Set n)], [(n, TT n)], TT n)
computeDagP removePred t = (dag, removed, resultTy)
  where
  (keptRev, removedRev, resultTy) = collect [] [] t

  dag = [ (arg, M.keysSet (usedVars argTy)) | arg@(_, argTy) <- reverse keptRev ]
  removed = reverse removedRev

  -- walk down the Pi binders; both accumulators are built in reverse order
  collect kept dropped (Bind n (Pi ty _) sc)
    | removePred ty = collect kept ((n, ty) : dropped) sc
    | otherwise = collect ((n, ty) : kept) dropped sc
  collect kept dropped rest = (kept, dropped, rest)

-- | Gather the bound variables occurring in a term, with their types.
-- The flag attached to each variable is True while every application
-- head on the path down to it is injective.
-- (The original author notes that the correctness of this injectivity
-- tracking has not been thought through carefully.)
-- Calls 'error' on de Bruijn variables: run vToP first.
usedVars :: Ord n => TT n -> Map n (TT n, Bool)
usedVars = go True
  where
  go inj term = case term of
    P Bound n t -> M.union (M.singleton n (t, inj)) (go inj t)
    -- the binder's own name is not free in the scope
    Bind n binder body -> M.union (M.delete n (go inj body)) (inBinder inj binder)
    App fn arg -> M.union (go inj fn) (go (inj && isInjective fn) arg)
    Proj t _ -> go inj t
    V _ -> error "unexpected! run vToP first"
    _ -> M.empty

  inBinder inj binder = case binder of
    Let t v -> M.union (go inj t) (go inj v)
    Guess t v -> M.union (go inj t) (go inj v)
    other -> go inj (binderTy other)

-- | Remove a node from a directed acyclic graph: drop every entry with
-- the given name and erase the name from the edge sets of the others.
deleteFromDag :: Ord n => n -> [((n, TT n), (a, Set n))] -> [((n, TT n), (a, Set n))]
deleteFromDag name dag =
  [ (node, (ix, S.delete name deps))
  | (node@(nodeName, _), (ix, deps)) <- dag
  , nodeName /= name ]

-- | Drop every entry with the given name from an argument list.
deleteFromArgList :: Ord n => n -> [(n, TT n)] -> [(n, TT n)]
deleteFromArgList n args = [ arg | arg@(argName, _) <- args, argName /= n ]

-- | Asymmetric modifications to keep track of.
-- One of these is kept per side of a match (see 'Sided'); all counters
-- start at zero and are summed by the 'Monoid' instance.
data AsymMods = Mods
  { argApp         :: !Int
    -- ^ incremented in @resolveUnis@ when a hole is resolved to a term
    -- that is not a bound variable
  , typeClassApp   :: !Int
    -- ^ incremented when a typeclass constraint is replaced by an instance
  , typeClassIntro :: !Int
    -- ^ number of typeclass constraints dropped in @takeSomeClasses@
  } deriving (Eq, Show)


-- | Homogeneous pairs: one value for the left-hand side of a match and one
-- for the right-hand side. Both fields are strict.
data Sided a = Sided
  { left  :: !a
  , right :: !a
  } deriving (Eq, Show)

-- | Collapse a 'Sided' value by feeding its left and right components
-- to a binary function.
sided :: (a -> a -> b) -> Sided a -> b
sided combine pair = case pair of
  Sided l r -> combine l r

-- | Apply a function to both components of a 'Sided' value.
-- This could be a Functor instance, but a plain function avoids
-- overloading names.
both :: (a -> b) -> Sided a -> Sided b
both f pair = case pair of
  Sided l r -> Sided (f l) (f r)

-- | The cost of a (partial) match between two types. Scores are combined
-- field-wise by the 'Monoid' instance and ordered via 'defaultScoreFunction'.
data Score = Score
  { transposition :: !Int
    -- ^ accumulated distance between the positions of matched arguments
  , equalityFlips :: !Int
    -- ^ number of equalities that were flipped (see 'flipEqualities')
  , asymMods      :: !(Sided AsymMods)
    -- ^ per-side modification counters
  } deriving (Eq, Show)

-- | Render a one-character summary of how the found type relates to the
-- searched type, based on which sides needed asymmetric modifications.
displayScore :: Score -> Doc a
displayScore score = text (relation (both unmodified (asymMods score)))
  where
  relation (Sided leftClean rightClean)
    | leftClean && rightClean = "=" -- types are isomorphic
    | leftClean               = "<" -- found type is more general than searched type
    | rightClean              = ">" -- searched type is more general than found type
    | otherwise               = "_"
  -- a side is unmodified when all of its counters sum to zero
  unmodified (Mods app tcApp tcIntro) = app + tcApp + tcIntro == 0

-- | Decide whether a score is still worth pursuing. A score is rejected
-- when arguments were applied on both sides, when more than four
-- arguments were applied in total, or when either side has more than
-- three typeclass applications or introductions.
scoreCriterion :: Score -> Bool
scoreCriterion (Score _ _ (Sided l r)) =
  not appliedOnBothSides && totalArgApps <= 4 && all classModsOk [l, r]
  where
  appliedOnBothSides = argApp l > 0 && argApp r > 0
  totalArgApps = argApp l + argApp r
  classModsOk m = typeClassApp m <= 3 && typeClassIntro m <= 3

-- | Collapse a 'Score' into a single number; lower is better.
defaultScoreFunction :: Score -> Int
defaultScoreFunction (Score trans eqFlip (Sided l r)) =
  trans + eqFlip + linearPenalty + upAndDowncastPenalty
  where
  -- prefer finding a more general type to a less general type:
  -- modifications on the left side weigh three times as much
  linearPenalty = 3 * linearWeight l + linearWeight r
  linearWeight (Mods app tcApp tcIntro) = 3 * app + 4 * tcApp + 2 * tcIntro
  -- it's very bad to have *both* upcasting and downcasting, so the
  -- product of the two sides is penalised heavily
  upAndDowncastPenalty = 100 * (castWeight l * castWeight r)
  castWeight (Mods app tcApp tcIntro) = 2 * app + tcApp + tcIntro

-- | Sided values combine component-wise: left with left, right with right.
instance Monoid a => Monoid (Sided a) where
  mempty = Sided mempty mempty
  mappend (Sided la ra) (Sided lb rb) = Sided (mappend la lb) (mappend ra rb)

-- | Modification counters combine by adding each counter separately.
instance Monoid AsymMods where
  mempty = Mods 0 0 0
  mappend (Mods app1 tcApp1 tcIntro1) (Mods app2 tcApp2 tcIntro2) =
    Mods (app1 + app2) (tcApp1 + tcApp2) (tcIntro1 + tcIntro2)

-- | Scores combine by adding the integer components and merging the
-- per-side modification counters.
instance Monoid Score where
  mempty = Score 0 0 mempty
  mappend (Score trans1 flips1 mods1) (Score trans2 flips2 mods2) =
    Score (trans1 + trans2) (flips1 + flips2) (mappend mods1 mods2)

-- | A directed acyclic graph representing the arguments to a function.
-- Each entry holds an argument's name and type, together with its position
-- (the 'Int': 1st argument, 2nd, etc.) and a set of related argument names
-- (the edge set of the graph).
type ArgsDAG = [((Name, Type), (Int, Set Name))]

-- | A list of typeclass constraints, each given as the name of the
-- constraint argument and its type.
type Classes = [(Name, Type)]

-- | The state corresponding to an attempted match of two types.
-- All fields are strict. A state with no holes, no arguments and no
-- typeclass constraints left on either side is a completed match.
data State = State
  { holes          :: ![(Name, Type)] -- ^ names which have yet to be resolved
  , argsAndClasses :: !(Sided (ArgsDAG, Classes))
     -- ^ arguments and typeclass constraints for each type which have yet to be resolved
  , score     :: !Score -- ^ the score so far
  , usedNames :: ![Name] -- ^ all names that have been used
  } deriving Show

-- | Apply a type transformation to the type of every argument in the DAG
-- and to the type of every typeclass constraint.
modifyTypes :: (Type -> Type) -> (ArgsDAG, Classes) -> (ArgsDAG, Classes)
modifyTypes f unresolved = (onDag (fst unresolved), onClasses (snd unresolved))
  where
  onDag dag = map (first (second f)) dag
  onClasses classes = map (second f) classes

-- | Look up an argument by name in the DAG, returning its type and its
-- position (always a 'Just' here).
findNameInArgsDAG :: Name -> ArgsDAG -> Maybe (Type, Maybe Int)
findNameInArgsDAG name dag = typeAndPosition <$> find hasName dag
  where
  hasName entry = name == fst (fst entry)
  typeAndPosition entry = (snd (fst entry), Just (fst (snd entry)))

-- | Look up a name among the unresolved arguments and typeclass
-- constraints of one side of a match.
--
-- Arguments are searched first and yield their type together with their
-- position. Typeclass constraints have no position, so a hit there yields
-- the type paired with 'Nothing'.
--
-- The fallback used to be written @(,) \<$\> lookup n classes \<*\> Nothing@,
-- which is always 'Nothing' in the 'Maybe' applicative, so typeclass
-- constraints could never be found.
findName :: Name -> (ArgsDAG, Classes) -> Maybe (Type, Maybe Int)
findName n (args, classes) =
  findNameInArgsDAG n args <|> (withoutPosition <$> lookup n classes)
  where
  -- typeclass constraints carry no argument position
  withoutPosition ty = (ty, Nothing)

-- | Remove a name from both the argument DAG and the list of typeclass
-- constraints.
deleteName :: Name -> (ArgsDAG, Classes) -> (ArgsDAG, Classes)
deleteName n (args, classes) = (prunedArgs, prunedClasses)
  where
  prunedArgs = deleteFromDag n args
  prunedClasses = deleteFromArgList n classes

-- | Convert a typechecker result to 'Maybe', discarding the error.
tcToMaybe :: TC' e a -> Maybe a
tcToMaybe result = case result of
  OK x -> Just x
  Error _ -> Nothing

-- | Apply a type transformation to the type of every argument in the DAG.
inArgTys :: (Type -> Type) -> ArgsDAG -> ArgsDAG
inArgTys f dag = map (first (second f)) dag


-- | Try to satisfy a typeclass constraint with a candidate instance.
-- The constraint type is unified against the return type of the instance,
-- treating the instance's non-typeclass arguments as holes. Succeeds only
-- if unification works and every hole receives a solution; the result is
-- the instance's own typeclass arguments with the solution substituted in.
typeclassUnify :: Ctxt ClassInfo -> Context -> Type -> Type -> Maybe [(Name, Type)]
typeclassUnify classInfo ctxt ty tyTry = do
  solution <- tcToMaybe (match_unify ctxt [] ty instanceRetTy [] holeNames [])
  guard (null (holeNames \\ map fst solution))
  let applySolution = foldr (.) id [ subst n t | (n, t) <- solution ]
  return (map (second applySolution) constraintArgs)
  where
  instanceTy = vToP tyTry
  instanceRetTy = getRetTy instanceTy
  (constraintArgs, plainArgs) =
    partition (isTypeClassArg classInfo . snd) (getArgTys instanceTy)
  holeNames = map fst plainArgs

-- | Does this type look like a typeclass constraint? True when the head of
-- the application is a type constructor whose name has at least one entry
-- in the class context.
isTypeClassArg :: Ctxt ClassInfo -> Type -> Bool
isTypeClassArg classInfo ty = case fst (unApply ty) of
  P (TCon _ _) className _ -> not (null (lookupCtxt className classInfo))
  _ -> False


-- | Scores are ordered by their 'defaultScoreFunction' value.
instance Ord Score where
  compare s1 s2 = compare (defaultScoreFunction s1) (defaultScoreFunction s2)

-- | Compute the power set of a list. For each element, the subsets that
-- contain it are listed before the subsets that do not.
subsets :: [a] -> [[a]]
subsets = foldr withAndWithout [[]]
  where
  withAndWithout x rest = map (x :) rest ++ rest


-- | Enumerate the variants of a type obtained by optionally swapping the
-- two sides of the equalities occurring in it. Each variant is paired with
-- the number of flips that were made to produce it.
--
-- Not recursive at the moment (i.e., doesn't flip iterated identities):
-- a fully applied equality is matched first and its arguments are not
-- searched for further equalities.
flipEqualities :: Type -> [(Int, Type)]
flipEqualities t = case t of
  -- an equality applied to its four arguments: keep it (0 flips) or swap
  -- both the type arguments and the value arguments (1 flip)
  eq1@(App (App (App (App eq@(P _ eqty _) tyL) tyR) valL) valR) | eqty == eqTy ->
    [(0, eq1), (1, App (App (App (App eq tyR) tyL) valR) valL)]
  -- every combination of binder variants with scope variants; only the flip
  -- count of the binder's type component (via binderTy) is added to the
  -- scope's count
  Bind n binder sc -> (\bind' (j, sc') -> (fst (binderTy bind') + j, Bind n (fmap snd bind') sc'))
    <$> traverse flipEqualities binder <*> flipEqualities sc
  -- every combination of function variants with argument variants
  App f x -> (\(i, f') (j, x') -> (i + j, App f' x')) 
    <$> flipEqualities f <*> flipEqualities x
  -- anything else has exactly one variant, with no flips
  t' -> [(0, t')]

--DONT run vToP first!
-- | Try to match two types together in a unification-like procedure.
-- Returns a list of types and their minimum scores, sorted in order
-- of increasing score.
--
-- Arguments: the interpreter state, the maximum acceptable value of
-- 'defaultScoreFunction', the type being searched for, and the candidate
-- types, each tagged with caller-chosen information that is returned
-- alongside the score of a successful match.
matchTypesBulk :: forall info. IState -> Int -> Type -> [(info, Type)] -> [(info, Score)]
matchTypesBulk istate maxScore type1 types = getAllResults startQueueOfQueues where
  -- Build the initial queue of search states for one candidate, keyed by the
  -- best starting score. Nothing if no equality-flipped variant of the
  -- candidate survives unification of the return types.
  getStartQueues :: (info, Type) -> Maybe (Score, (info, Q.PQueue Score State))
  getStartQueues nty@(info, type2) = case mapMaybe startStates ty2s of
    [] -> Nothing
    xs -> Just (minimum (map fst xs), (info, Q.fromList xs))
    where
    -- every combination of equality flips in the arguments and return type
    ty2s = (\(i, dag) (j, retTy) -> (i + j, dag, retTy))
      <$> flipEqualitiesDag dag2 <*> flipEqualities retTy2
    flipEqualitiesDag dag = case dag of 
      [] -> [(0, [])]
      ((n, ty), (pos, deps)) : xs -> 
         (\(i, ty') (j, xs') -> (i + j , ((n, ty'), (pos, deps)) : xs'))
           <$> flipEqualities ty <*> flipEqualitiesDag xs
    -- the starting state is obtained by unifying the two return types
    startStates (numEqFlips, sndDag, sndRetTy) = do
      state <- unifyQueue (State startingHoles 
                (Sided (dag1, typeClassArgs1) (sndDag, typeClassArgs2))
                (mempty { equalityFlips = numEqFlips }) usedns) [(retTy1, sndRetTy)]
      return (score state, state)


    -- binders of the candidate are renamed away from the searched type's
    -- argument names before building its DAG
    (dag2, typeClassArgs2, retTy2) = makeDag (uniqueBinders (map fst argNames1) type2)
    argNames2 = map fst dag2
    usedns = map fst startingHoles
    startingHoles = argNames1 ++ argNames2

    -- NOTE(review): startingTypes is not referenced anywhere in this definition
    startingTypes = [(retTy1, retTy2)] 


  startQueueOfQueues :: Q.PQueue Score (info, Q.PQueue Score State)
  startQueueOfQueues = Q.fromList $ mapMaybe getStartQueues types

  -- Repeatedly take the candidate with the smallest key and advance its
  -- search by one step. A finished candidate is emitted, an exhausted one is
  -- dropped, and an unfinished one is re-queued under its new best score.
  -- The whole search stops at the first key whose value exceeds maxScore.
  getAllResults :: Q.PQueue Score (info, Q.PQueue Score State) -> [(info, Score)]
  getAllResults q = case Q.minViewWithKey q of
    Nothing -> []
    Just ((nextScore, (info, stateQ)), q') ->
      if defaultScoreFunction nextScore <= maxScore
        then case nextStepsQueue stateQ of
          Nothing -> getAllResults q'
          Just (Left stateQ') -> case Q.minViewWithKey stateQ' of
             Nothing -> getAllResults q'
             Just ((newQscore,_), _) -> getAllResults (Q.add newQscore (info, stateQ') q')
          Just (Right theScore) -> (info, theScore) : getAllResults q'
        else []


  ctxt = tt_ctxt istate
  classInfo = idris_classes istate

  (dag1, typeClassArgs1, retTy1) = makeDag type1
  argNames1 = map fst dag1
  -- split a type into a numbered argument DAG, its typeclass arguments and
  -- its return type
  makeDag :: Type -> (ArgsDAG, Classes, Type)
  makeDag = first3 (zipWith (\i (ty, deps) -> (ty, (i, deps))) [0..] . reverseDag) . 
    computeDagP (isTypeClassArg classInfo) . vToP
  first3 f (a,b,c) = (f a, b, c)
  
  -- update our state with the unification resolutions
  -- Returns the new state together with further pairs of types that must
  -- be unified, or Nothing if a resolution cannot be accounted for.
  resolveUnis :: [(Name, Type)] -> State -> Maybe (State, [(Type, Type)])
  resolveUnis [] state = Just (state, [])
  -- a name resolved to another bound variable, where one is found on each
  -- side: both are removed, their types are queued for unification, and the
  -- distance between their positions is added to the transposition score
  resolveUnis ((name, term@(P Bound name2 _)) : xs) 
    state | isJust findArgs = do
    ((ty1, ix1), (ty2, ix2)) <- findArgs

    (state'', queue) <- resolveUnis xs state'
    let transScore = fromMaybe 0 (abs <$> ((-) <$> ix1 <*> ix2))
    return (inScore (\s -> s { transposition = transposition s + transScore }) state'', (ty1, ty2) : queue)
    where
    unresolved = argsAndClasses state
    inScore f stat = stat { score = f (score stat) }
    findArgs = ((,) <$> findName name  (left unresolved) <*> findName name2 (right unresolved)) <|>
               ((,) <$> findName name2 (left unresolved) <*> findName name  (right unresolved))
    matchnames = [name, name2]
    deleteArgs = deleteName name . deleteName name2
    state' = state { holes = filter (not . (`elem` matchnames) . fst) (holes state)
                   , argsAndClasses = both (modifyTypes (subst name term) . deleteArgs) unresolved}

  -- any other resolution: the name must be found on exactly one side
  resolveUnis ((name, term) : xs)
    state@(State hs unresolved _ _) = case both (findName name) unresolved of
      Sided Nothing  Nothing  -> Nothing
      Sided (Just _) (Just _) -> error "Idris internal error: TypeSearch.resolveUnis"
      oneOfEach -> first (addScore (both scoreFor oneOfEach)) <$> nextStep
    where
    -- the side holding the name is charged one argument application; the
    -- other side is charged one per hole used non-injectively in the term
    scoreFor (Just _) = mempty { argApp = 1 }
    scoreFor Nothing  = mempty { argApp = otherApplied }
    -- find variables which are determined uniquely by the type
    -- due to injectivity
    matchedVarMap = usedVars term
    bothT f (x, y) = (f x, f y)
    (injUsedVars, notInjUsedVars) = bothT M.keys . M.partition id . M.filterWithKey (\k _-> k `elem` map fst hs) $ M.map snd matchedVarMap 
    varsInTy = injUsedVars ++ notInjUsedVars
    toDelete = name : varsInTy
    deleteMany = foldr (.) id (map deleteName toDelete)

    otherApplied = length notInjUsedVars

    addScore additions theState = theState { score = let s = score theState in
      s { asymMods = asymMods s `mappend` additions } }
    state' = state { holes = filter (not . (`elem` toDelete) . fst) hs  
                   , argsAndClasses = both (modifyTypes (subst name term) . deleteMany) (argsAndClasses state) }
    nextStep = resolveUnis xs state'


  -- | resolve a queue of unification constraints
  -- Fails (Nothing) as soon as a pair does not unify, its resolutions
  -- cannot be applied, or the resulting score violates 'scoreCriterion'.
  unifyQueue :: State -> [(Type, Type)] -> Maybe State
  unifyQueue state [] = return state
  unifyQueue state ((ty1, ty2) : queue) = do
    --trace ("go: \n" ++ show state) True `seq` return ()
    res <- tcToMaybe $ match_unify ctxt [ (n, Pi ty (TType (UVar 0))) | (n, ty) <- holes state] ty1 ty2 [] (map fst $ holes state) []
    (state', queueAdditions) <- resolveUnis res state
    guard $ scoreCriterion (score state')
    unifyQueue state' (queue ++ queueAdditions)

  -- the normalised types of all instances registered for the class at the
  -- head of the given constraint, with binders renamed away from the used names
  possClassInstances :: [Name] -> Type -> [Type]
  possClassInstances usedns ty = do
    className <- getClassName clss
    classDef <- lookupCtxt className classInfo
    n <- class_instances classDef
    def <- lookupCtxt n (definitions ctxt)
    nty <- normaliseC ctxt [] <$> (case typeFromDef def of Just x -> [x]; Nothing -> [])
    let ty' = vToP (uniqueBinders usedns nty)
    return ty'
    where
    (clss, _) = unApply ty
    getClassName (P (TCon _ _) className _) = [className]
    getClassName _ = []

  -- Just if the computation hasn't totally failed yet, Nothing if it has
  -- Left if we haven't found a terminal state, Right if we have
  nextStepsQueue :: Q.PQueue Score State -> Maybe (Either (Q.PQueue Score State) Score)
  nextStepsQueue queue = do
    ((nextScore, next), rest) <- Q.minViewWithKey queue
    if isFinal next 
      then Just $ Right nextScore
      -- expand the best state, unless its score already fails the criterion
      else let additions = if scoreCriterion nextScore
                 then Q.fromList [ (score state, state) | state <- nextSteps next ]
                 else Q.empty in
           Just $ Left (Q.union rest additions)
    where
    -- a state is final when nothing at all is left to resolve
    isFinal (State [] (Sided ([], []) ([], [])) _ _) = True
    isFinal _ = False

  -- | Find all possible matches starting from a given state.
  -- We go in stages rather than concatenating all three results in hopes of narrowing
  -- the search tree. Once we advance in a phase, there should be no going back.
  nextSteps :: State -> [State]

  -- Stage 3 - match typeclasses
  -- (applies once no holes and no arguments remain on either side)
  nextSteps (State [] unresolved@(Sided ([], c1) ([], c2)) scoreAcc usedns) = 
    if null results3 then results4 else results3
    where
    -- try to match a typeclass argument from the left with a typeclass argument from the right
    results3 =
         catMaybes [ unifyQueue (State [] 
         (Sided ([], deleteFromArgList n1 c1)
                ([], map (second subst2for1) (deleteFromArgList n2 c2)))
         scoreAcc usedns) [(ty1, ty2)] 
     | (n1, ty1) <- c1, (n2, ty2) <- c2, let subst2for1 = psubst n2 (P Bound n1 ty1)]

    -- try to hunt match a typeclass constraint by replacing it with an instance
    -- (only one side is modified per resulting state; see 'parallel')
    results4 = [ State [] (both (\(cs, _, _) -> ([], cs)) sds) 
               (scoreAcc `mappend` Score 0 0 (both (\(_, amods, _) -> amods) sds))
               (usedns ++ sided (++) (both (\(_, _, hs) -> hs) sds))
               | sds <- allMods ]
      where
      allMods = parallel defMod mods
      mods :: Sided [( Classes, AsymMods, [Name] )]
      mods = both (instanceMods . snd) unresolved
      defMod :: Sided (Classes, AsymMods, [Name])
      defMod = both (\(_, cs) -> (cs, mempty , [])) unresolved
      -- vary the left side with the right at its default, then vice versa
      parallel :: Sided a -> Sided [a] -> [Sided a]
      parallel (Sided l r) (Sided ls rs) = map (flip Sided r) ls ++ map (Sided l) rs
      instanceMods :: Classes -> [( Classes , AsymMods, [Name] )]
      instanceMods classes = [ ( newClassArgs, mempty { typeClassApp = 1 }, newHoles )
                      | (_, ty) <- classes
                      , inst <- possClassInstances usedns ty
                      , newClassArgs <- maybeToList $ typeclassUnify classInfo ctxt ty inst
                      , let newHoles = map fst newClassArgs ]


  -- Stage 1 - match arguments
  nextSteps (State hs (Sided (dagL, c1) (dagR, c2)) scoreAcc usedns) = results where

    results = concatMap takeSomeClasses results1

    -- we only try to match arguments whose names don't appear in the types
    -- of any other arguments
    canBeFirst :: ArgsDAG -> [(Name, Type)]
    canBeFirst = map fst . filter (S.null . snd . snd)

    -- try to match an argument from the left with an argument from the right
    results1 = catMaybes [ unifyQueue (State (filter (not . (`elem` [n1,n2]) . fst) hs) 
         (Sided (deleteFromDag n1 dagL, c1) 
                (inArgTys subst2for1 $ deleteFromDag n2 dagR, map (second subst2for1) c2))
          scoreAcc usedns) [(ty1, ty2)] 
     | (n1, ty1) <- canBeFirst dagL, (n2, ty2) <- canBeFirst dagR
     , let subst2for1 = psubst n2 (P Bound n1 ty1)]



  -- Stage 2 - simply introduce a subset of the typeclasses
  -- we've finished, so take some classes
  takeSomeClasses (State [] unresolved@(Sided ([], _) ([], _)) scoreAcc usedns) = 
    map statesFromMods . prod $ both (classMods . snd) unresolved
    where
    swap (Sided l r) = Sided r l
    -- note that the modification counters are swapped between the sides
    -- before being added to the score
    statesFromMods :: Sided (Classes, AsymMods) -> State
    statesFromMods sides = let classes = both (\(c, _) -> ([], c)) sides
                               mods    = swap (both snd sides) in
      State [] classes (scoreAcc `mappend` (mempty { asymMods = mods })) usedns
    -- every subset of the constraints, charged for each constraint left out
    classMods :: Classes -> [(Classes, AsymMods)]
    classMods cs = let lcs = length cs in 
      [ (cs', mempty { typeClassIntro = lcs - length cs' }) | cs' <- subsets cs ]
    -- all pairings of a left choice with a right choice
    prod :: Sided [a] -> [Sided a]
    prod (Sided ls rs) = [Sided l r | l <- ls, r <- rs]
  -- still have arguments to match, so just be the identity
  takeSomeClasses s = [s]