module Idris.ProofSearch(trivial, trivialHoles, proofSearch, resolveTC) where
import Idris.Core.Elaborate hiding (Tactic(..))
import Idris.Core.TT
import Idris.Core.Unify
import Idris.Core.Evaluate
import Idris.Core.CaseTree
import Idris.Core.Typecheck
import Idris.AbsSyntax
import Idris.Delaborate
import Idris.Error
import Control.Applicative ((<$>))
import Control.Monad
import Control.Monad.State.Strict
import qualified Data.Set as S
import Data.List
import Debug.Trace
-- | Attempt to solve the current goal with 'trivialHoles', placing no
-- restriction on which local names may be used and exempting no argument
-- positions from the hole check.
trivial :: (PTerm -> ElabD ()) -> IState -> ElabD ()
trivial elab ist = trivialHoles [] [] elab ist
-- | Try to solve the current goal directly: first by elaborating an
-- application of 'eqCon' with placeholder implicit arguments and, if that
-- fails, by elaborating a reference to one of the local variables in the
-- environment.
--
-- A local is only attempted when its type passes the hole check
-- ('holesOK') and, if @psnames@ is non-empty, when it is listed there.
trivialHoles :: [Name]              -- ^ if non-empty, the only locals that may be used
             -> [(Name, Int)]       -- ^ (head name, argument position) pairs
                                    --   that are exempt from the hole check
             -> (PTerm -> ElabD ()) -- ^ elaborator used to build candidates
             -> IState
             -> ElabD ()
trivialHoles psnames ok elab ist
     = try' byEquality byLocals True
  where
    prf = fileFC "prf"

    -- First attempt: the 'eqCon' constructor with both implicits left as
    -- placeholders.
    byEquality = elab (PApp prf (PRef prf [] eqCon)
                            [ pimp (sUN "A") Placeholder False
                            , pimp (sUN "x") Placeholder False ])

    -- Second attempt: walk the environment looking for a usable local.
    byLocals = do env <- get_env
                  _ <- goal
                  tryAll env

    -- Try each environment entry in order, moving on to the next one when
    -- an entry is not usable or elaborating it fails.
    tryAll [] = fail "No trivial solution"
    tryAll ((x, b) : rest)
       = do hs <- get_holes
            _ <- goal
            let usable = holesOK hs (binderTy b)
                           && (null psnames || x `elem` psnames)
            if usable
               then try' (elab (PRef prf [] x)) (tryAll rest) True
               else tryAll rest

    -- True when the term mentions none of the holes in hs, ignoring
    -- argument positions listed in ok.
    holesOK hs tm = case tm of
        App _ f a
          | (P _ n _, args) <- unApply tm -> holeArgsOK hs n 0 args
          | otherwise -> holesOK hs f && holesOK hs a
        P _ n _ -> n `notElem` hs
        Bind _ b sc -> holesOK hs (binderTy b) && holesOK hs sc
        _ -> True

    -- Check the arguments of an application headed by n, numbering them
    -- from p and skipping every position that ok exempts.
    holeArgsOK hs n p args
       = and [ holesOK hs a | (i, a) <- zip [p ..] args, (n, i) `notElem` ok ]
-- | Variant of 'trivialHoles' used during type class resolution. It first
-- tries an application of 'eqCon' with placeholder implicits, then each
-- local variable whose type passes the hole check and whose normalised
-- return type is headed by a name registered in 'idris_classes'.
trivialTCs :: [(Name, Int)]       -- ^ (head name, argument position) pairs
                                  --   that are exempt from the hole check
           -> (PTerm -> ElabD ()) -- ^ elaborator used to build candidates
           -> IState
           -> ElabD ()
trivialTCs ok elab ist
     = try' byEquality byLocals True
  where
    prf = fileFC "prf"

    -- First attempt: the 'eqCon' constructor with both implicits left as
    -- placeholders.
    byEquality = elab (PApp prf (PRef prf [] eqCon)
                            [ pimp (sUN "A") Placeholder False
                            , pimp (sUN "x") Placeholder False ])

    -- Second attempt: walk the environment looking for a usable local.
    byLocals = do env <- get_env
                  _ <- goal
                  tryAll env

    -- Try each environment entry in order, moving on to the next one when
    -- an entry is not usable or elaborating it fails.
    tryAll [] = fail "No trivial solution"
    tryAll ((x, b) : rest)
       = do hs <- get_holes
            _ <- goal
            env <- get_env
            let localTy = binderTy b
                usable = holesOK hs localTy && tcArg env localTy
            if usable
               then try' (elab (PRef prf [] x)) (tryAll rest) True
               else tryAll rest

    -- True when the normalised return type of ty is headed by a name with
    -- an entry in 'idris_classes'.
    tcArg env ty = case unApply (getRetTy (normalise (tt_ctxt ist) env ty)) of
        (P _ n _, _) -> case lookupCtxtExact n (idris_classes ist) of
                             Just _ -> True
                             Nothing -> False
        _ -> False

    -- True when the term mentions none of the holes in hs, ignoring
    -- argument positions listed in ok.
    holesOK hs tm = case tm of
        App _ f a
          | (P _ n _, args) <- unApply tm -> holeArgsOK hs n 0 args
          | otherwise -> holesOK hs f && holesOK hs a
        P _ n _ -> n `notElem` hs
        Bind _ b sc -> holesOK hs (binderTy b) && holesOK hs sc
        _ -> True

    -- Check the arguments of an application headed by n, numbering them
    -- from p and skipping every position that ok exempts.
    holeArgsOK hs n p args
       = and [ holesOK hs a | (i, a) <- zip [p ..] args, (n, i) `notElem` ok ]
-- | Abort elaboration with a 'CantSolveGoal' error that carries the current
-- goal together with the name and type of every entry in the local
-- environment.
cantSolveGoal :: ElabD a
cantSolveGoal = do
    g <- goal
    env <- get_env
    let localTys = [ (n, binderTy b) | (n, b) <- env ]
    lift (tfail (CantSolveGoal g localTys))
-- | Search for a term that solves the current goal.
proofSearch :: Bool -- ^ recursive search; @False@ with exactly one hint only
                    --   applies that hint (first clause)
            -> Bool -- ^ invoked from the prover: on failure report
                    --   'cantSolveGoal' rather than running
                    --   @attack; defer; solve@ on the goal
            -> Bool -- ^ ambiguity is acceptable
            -> Bool -- ^ defer on failure (not inspected by this function)
            -> Int  -- ^ maximum search depth
            -> (PTerm -> ElabD ()) -- ^ elaborator used to build candidates
            -> Maybe Name -- ^ optional function name that may itself be
                          --   tried as a candidate
            -> Name       -- ^ root name passed to @defer@
            -> [Name]     -- ^ local names the search is allowed to use
            -> [Name]     -- ^ hints: extra names to try
            -> IState
            -> ElabD ()
proofSearch False fromProver ambigok deferonfail depth elab _ nroot psnames [fn] ist
       = tryAllFns (lookupCtxtName fn (idris_implicits ist))
  where
    -- Hand the focused goal over via attack/defer/solve.
    deferGoal = do attack; defer [] nroot; solve

    -- Try every definition the hint name resolves to, in order.
    tryAllFns [] | fromProver = cantSolveGoal
                 | otherwise = deferGoal
    tryAllFns (f : fs) = try' (tryFn f) (tryAllFns fs) True

    -- Apply one candidate (falling back to 'match_apply' when 'apply'
    -- fails), then defer every hole the application introduced.
    tryFn (f, fargs)
       = do let imps = map isImp fargs
            -- The results of the state reads below are unused; the calls
            -- are retained so the elaborator interaction is unchanged.
            _ <- get_probs
            holesBefore <- get_holes
            _ <- try' (apply (Var f) imps)
                      (match_apply (Var f) imps) True
            _ <- get_probs
            holesAfter <- get_holes
            _ <- get_term
            if fromProver
               then cantSolveGoal
               else do mapM_ (\ h -> focus h >> deferGoal)
                             (holesAfter \\ holesBefore)
                       solve

    -- Every argument is flagged True here, explicit ones included.
    isImp arg = case arg of
        PImp p _ _ _ _ -> (True, p)
        _ -> (True, priority arg)
-- General case. After normalising the goal with 'compute':
--
--   * if neither ambigok holds nor 'conArgsOK' accepts the goal, the goal
--     is handed straight to 'autoArg';
--   * if a 'TISolution' is recorded for nroot in 'idris_tyinfodata', its
--     first entry is elaborated;
--   * otherwise 'psRec' runs with depth maxDepth; unless ambigok holds,
--     errors recognised by 'cantsolve' fall back to 'autoArg'.
proofSearch rec fromProver ambigok deferonfail maxDepth elab fn nroot psnames hints ist
       = do compute
            ty <- goal
            hs <- get_holes
            env <- get_env
            tm <- get_term
            argsok <- conArgsOK ty
            if ambigok || argsok then
               case lookupCtxt nroot (idris_tyinfodata ist) of
                    [TISolution ts] -> findInferredTy ts
                    _ -> if ambigok then psRec rec maxDepth [] S.empty
                            else handleError cantsolve
                                      (psRec rec maxDepth [] S.empty)
                                      (autoArg (sUN "auto"))
               else autoArg (sUN "auto")
  where
    -- Elaborate the first recorded solution after 'toUN' has renamed its
    -- machine-generated names.
    -- NOTE(review): there is no clause for an empty solution list.
    findInferredTy (t : _) = elab (delab ist (toUN t))

    -- Errors for which the top-level search falls back to 'autoArg',
    -- looking through location and elaboration-context wrappers.
    cantsolve (InternalMsg _) = True
    cantsolve (CantSolveGoal _ _) = True
    cantsolve (IncompleteTerm _) = True
    cantsolve (At _ e) = cantsolve e
    cantsolve (Elaborating _ _ _ e) = cantsolve e
    cantsolve (ElaboratingArg _ _ _ e) = cantsolve e
    cantsolve err = False

    -- Decide whether the goal type is ready to be searched. For a goal
    -- headed by a known datatype, every constructor and auto hint must
    -- satisfy 'conReady'. A head with no datatype entry, or a goal of
    -- 'TType', is accepted. Any other head is reported through
    -- 'typeNotSearchable'.
    conArgsOK ty
       = let (f, as) = unApply ty in
           case f of
              P _ n _ ->
                let autohints = case lookupCtxtExact n (idris_autohints ist) of
                                     Nothing -> []
                                     Just hs -> hs in
                    case lookupCtxtExact n (idris_datatypes ist) of
                              Just t -> do rs <- mapM (conReady as)
                                                      (autohints ++ con_names t)
                                           return (and rs)
                              Nothing ->
                                    return True
              TType _ -> return True
              _ -> typeNotSearchable ty

    -- Pair the goal's arguments with the arguments of the candidate's
    -- return type and require every pair to pass 'notHole'.
    conReady :: [Term] -> Name -> ElabD Bool
    conReady as n
       = case lookupTyExact n (tt_ctxt ist) of
              Just ty -> do let (_, cs) = unApply (getRetTy ty)
                            hs <- get_holes
                            return $ and (map (notHole hs) (zip as cs))
              Nothing -> fail "Can't happen"

    -- A goal argument that is still a hole is rejected when the candidate
    -- has a constructor application or a constant in that position.
    notHole hs (P _ n _, c)
       | (P _ cn _, _) <- unApply c,
         n `elem` hs && isConName cn (tt_ctxt ist) = False
       | Constant _ <- c = not (n `elem` hs)
    -- A goal argument that is an application with at least one argument is
    -- rejected when its head is still a hole.
    notHole hs (fa, c)
       | (P _ fn _, args@(_:_)) <- unApply fa = fn `notElem` hs
    notHole _ _ = True

    -- True when the term is a name listed in hs. The fallback clause was
    -- previously misspelled as a separate function, leaving this partial.
    -- Currently unused within this clause.
    inHS hs (P _ n _) = n `elem` hs
    inHS _ _ = False

    -- Turn machine-generated names into user names, except those whose
    -- text begins with an underscore.
    toUN t@(P nt (MN i n) ty)
       | ('_':xs) <- str n = t
       | otherwise = P nt (UN n) ty
    toUN (App s f a) = App s (toUN f) (toUN a)
    toUN t = t

    -- The search proper. Arguments: whether to search recursively, the
    -- remaining depth, the locals already used on this path, and the goal
    -- types already visited on this path.
    psRec :: Bool -> Int -> [Name] -> S.Set Type -> ElabD ()
    psRec _ 0 locs tys | fromProver = cantSolveGoal
    psRec rec 0 locs tys = do attack; defer [] nroot; solve
    psRec False d locs tys = tryCons d locs tys hints
    psRec True d locs tys
                 = do compute
                      ty <- goal
                      -- Fail on a goal type already seen on this path.
                      when (S.member ty tys) $ fail "Been here before"
                      let tys' = S.insert ty tys
                      -- Order of attempts: trivial solution, type class
                      -- resolution, constructors and hints, then locals,
                      -- each with one less unit of depth.
                      try' (try' (trivialHoles psnames [] elab ist)
                                 (resolveTC False False 20 ty nroot elab ist)
                                 True)
                           (try' (try' (resolveByCon (d - 1) locs tys')
                                       (resolveByLocals (d - 1) locs tys')
                                 True)
                                 -- Last resort once every attempt fails.
                                 (if fromProver
                                     then fail "cantSolveGoal"
                                     else do attack; defer [] nroot; solve) True) True

    -- The optional function name is offered as a candidate only below the
    -- top search level and only when it is a user-written name.
    getFn d (Just f) | d < maxDepth - 1 && usersname f = [f]
                     | otherwise = []
    getFn d _ = []

    usersname (UN _) = True
    usersname (NS n _) = usersname n
    usersname _ = False

    -- Try the hints, the constructors of the goal's datatype, its auto
    -- hints and, when 'getFn' allows it, the optional function name.
    resolveByCon d locs tys
        = do t <- goal
             let (f, _) = unApply t
             case f of
                P _ n _ ->
                   do let autohints = case lookupCtxtExact n (idris_autohints ist) of
                                           Nothing -> []
                                           Just hs -> hs
                      case lookupCtxtExact n (idris_datatypes ist) of
                          Just t -> do
                             let others = hints ++ con_names t ++ autohints
                             tryCons d locs tys (others ++ getFn d fn)
                          Nothing -> typeNotSearchable t
                _ -> typeNotSearchable t

    -- Try the local variables in the environment.
    resolveByLocals d locs tys
        = do env <- get_env
             tryLocals d locs tys env

    -- Skip locals already used on this path and locals not in psnames.
    tryLocals d locs tys [] = fail "Locals failed"
    tryLocals d locs tys ((x, t) : xs)
       | x `elem` locs || x `notElem` psnames = tryLocals d locs tys xs
       | otherwise = try' (tryLocal d (x : locs) tys x t)
                          (tryLocals d locs tys xs) True

    tryCons d locs tys [] = fail "Constructors failed"
    tryCons d locs tys (c : cs)
        = try' (tryCon d locs tys c) (tryCons d locs tys cs) True

    -- Apply a local to as many arguments as 'getPArity' reports for its
    -- delaborated type, searching recursively for each argument.
    tryLocal d locs tys n t
          = do let a = getPArity (delab ist (binderTy t))
               tryLocalArg d locs tys n a

    tryLocalArg d locs tys n 0 = elab (PRef (fileFC "prf") [] n)
    tryLocalArg d locs tys n i
        = simple_app False (tryLocalArg d locs tys n (i - 1))
                (psRec True d locs tys) "proof search local apply"

    -- Apply a candidate name to the goal, then search recursively for each
    -- of its arguments that 'isImp' does not flag as implicit. Fails when
    -- the application adds unification problems.
    tryCon d locs tys n =
         do ty <- goal
            let imps = case lookupCtxtExact n (idris_implicits ist) of
                            Nothing -> []
                            Just args -> map isImp args
            ps <- get_probs
            hs <- get_holes
            args <- map snd <$> try' (apply (Var n) imps)
                                     (match_apply (Var n) imps) True
            ps' <- get_probs
            hs' <- get_holes
            when (length ps < length ps') $ fail "Can't apply constructor"
            let newhs = filter (\ (x, y) -> not x) (zip (map fst imps) args)
            mapM_ (\ (_, h) -> do focus h
                                  aty <- goal
                                  psRec True d locs tys) newhs
            solve

    isImp (PImp p _ _ _ _) = (True, p)
    isImp arg = (False, priority arg)

    -- Report that the given type cannot be handled by proof search, with
    -- an extra hint when it is a function type.
    typeNotSearchable ty =
       lift $ tfail $ FancyMsg $
       [TextPart "Attempted to find an element of type",
        TermPart ty,
        TextPart "using proof search, but proof search only works on datatypes with constructors."] ++
       case ty of
         (Bind _ (Pi _ _ _) _) -> [TextPart "In particular, function types are not supported."]
         _ -> []
-- | Entry point for type class resolution. Captures the holes that exist
-- when resolution starts and hands over to @resTC'@ with an empty list of
-- already-visited classes.
resolveTC :: Bool -- ^ passed to @resTC'@ as its @defaultOn@ flag
          -> Bool -- ^ not inspected by this function
          -> Int  -- ^ remaining resolution depth
          -> Term -- ^ top-level goal, used when reporting errors
          -> Name -- ^ name that is excluded from the candidate instances
          -> (PTerm -> ElabD ()) -- ^ elaborator used to build candidates
          -> IState
          -> ElabD ()
resolveTC def mvok depth top fn elab ist
   = get_holes >>= \ topholes ->
       resTC' [] def topholes depth top fn elab ist
-- | Worker for 'resolveTC'. The arguments are: the class heads already
-- visited on this path, the @defaultOn@ flag, the holes present when
-- resolution started, the remaining depth, the top-level goal (for error
-- reports), the name to exclude from the candidates, the elaborator and
-- the global state.
--
-- Depth 0 fails outright. Depth 1 only attempts 'trivial'. Otherwise the
-- goal is normalised and checked by 'tcArgsOK'; if accepted, 'trivialTCs'
-- is tried and then every candidate instance in turn ('blunderbuss').
resTC' :: [Term] -> Bool -> [Name] -> Int -> Term -> Name
       -> (PTerm -> ElabD ()) -> IState -> ElabD ()
resTC' tcs def topholes 0 topg fn elab ist = fail $ "Can't resolve type class"
resTC' tcs def topholes 1 topg fn elab ist = try' (trivial elab ist) (resolveTC def False 0 topg fn elab ist) True
resTC' tcs defaultOn topholes depth topg fn elab ist
  = do compute
       g <- goal
       -- argsok: whether the goal is determined enough to resolve.
       -- okholePos: argument positions that may still be holes.
       let (argsok, okholePos) = case tcArgsOK g topholes of
                                    Nothing -> (False, [])
                                    Just hs -> (True, hs)
       env <- get_env
       probs <- get_probs
       if not argsok
         then lift $ tfail $ CantResolve True topg (probErr probs)
         else do
           ptm <- get_term
           ulog <- getUnifyLog
           hs <- get_holes
           env <- get_env
           t <- goal
           let (tc, ttypes) = unApply (getRetTy t)
           let okholes = case tc of
                              P _ n _ -> zip (repeat n) okholePos
                              _ -> []
           traceWhen ulog ("Resolving class " ++ show g ++ "\nin" ++ show env ++ "\n" ++ show okholes) $
            try' (trivialTCs okholes elab ist)
                (do addDefault t tc ttypes
                    let stk = map fst (filter snd $ elab_stack ist)
                    let insts = findInstances ist t
                    blunderbuss t depth stk (stk ++ insts)) True
  where
    -- Returns Nothing when the goal must not be resolved yet, otherwise
    -- the argument positions that are allowed to be holes. A goal headed
    -- by numclass is always accepted while defaultOn is set.
    tcArgsOK ty hs | (P _ nc _, as) <- unApply (getRetTy ty), nc == numclass && defaultOn
       = Just []
    tcArgsOK ty hs
       = let (f, as) = unApply (getRetTy ty) in
             case f of
                  P _ cn _ -> case lookupCtxtExact cn (idris_classes ist) of
                                   Just ci -> tcDetArgsOK 0 (class_determiners ci) hs as
                                   Nothing -> if any (isMeta hs) as
                                                 then Nothing
                                                 else Just []
                  _ -> if any (isMeta hs) as
                          then Nothing
                          else Just []

    -- Walk the class arguments by position. An argument at a position
    -- listed in ds must not be one of the holes in hs. Every other
    -- position holding a bare name is collected in the result.
    tcDetArgsOK i ds hs (x : xs)
        | i `elem` ds = if isMeta hs x
                           then Nothing
                           else tcDetArgsOK (i + 1) ds hs xs
        | otherwise = do rs <- tcDetArgsOK (i + 1) ds hs xs
                         case x of
                              P _ n _ -> Just (i : rs)
                              _ -> Just rs
    tcDetArgsOK _ _ _ [] = Just []

    -- Error stored with the first outstanding unification problem, if any.
    probErr [] = Msg ""
    probErr ((_,_,_,_,err,_,_) : _) = err

    isMeta :: [Name] -> Term -> Bool
    isMeta ns (P _ n _) = n `elem` ns
    isMeta _ _ = False

    -- Currently unused within this function.
    notHole hs (P _ n _, c)
       | (P _ cn _, _) <- unApply (getRetTy c),
         n `elem` hs && isConName cn (tt_ctxt ist) = False
       | Constant _ <- c = not (n `elem` hs)
    notHole _ _ = True

    -- Currently unused within this function.
    chaser (UN nm)
        | ('@':'@':_) <- str nm = True
    chaser (SN (ParentN _ _)) = True
    chaser (NS n _) = chaser n
    chaser _ = False

    numclass = sNS (sUN "Num") ["Interfaces","Prelude"]

    -- With defaultOn set, a numclass goal whose only argument is a bound
    -- variable gets that variable filled with the big integer type.
    addDefault t num@(P _ nc _) [P Bound a _] | nc == numclass && defaultOn
        = do focus a
             fill (RConstant (AType (ATInt ITBig)))
             solve
    addDefault t f as
          | all boundVar as = return ()
    addDefault t f a = return ()

    boundVar (P Bound _ _) = True
    boundVar _ = False

    -- Try each candidate in turn, skipping fn itself. A 'CantResolve True'
    -- error from a candidate is re-raised; any other error moves on to the
    -- next candidate.
    blunderbuss t d stk [] = do ps <- get_probs
                                lift $ tfail $ CantResolve False topg (probErr ps)
    blunderbuss t d stk (n:ns)
        | n /= fn
              = tryCatch (resolve n d)
                    (\e -> case e of
                             CantResolve True _ _ -> lift $ tfail e
                             _ -> blunderbuss t d stk ns)
        | otherwise = blunderbuss t d stk ns

    -- Introduce every leading Pi binder of the goal and return how many
    -- were introduced.
    introImps = do g <- goal
                   case g of
                        (Bind _ (Pi _ _ _) sc) -> do attack; intro Nothing
                                                     num <- introImps
                                                     return (num + 1)
                        _ -> return 0

    solven n = replicateM_ n solve

    -- Apply instance n to the goal and resolve each argument that 'isImp'
    -- does not flag as implicit. Fails if the application leaves new or
    -- unrecoverable unification problems.
    resolve n depth
       | depth == 0 = fail $ "Can't resolve type class"
       | otherwise
           = do lams <- introImps
                t <- goal
                let imps = case lookupCtxtName n (idris_implicits ist) of
                                [] -> []
                                [args] -> map isImp (snd args)
                                xs -> error "The impossible happened - overloading is not expected here!"
                ps <- get_probs
                tm <- get_term
                args <- map snd <$> apply (Var n) imps
                solven lams
                ps' <- get_probs
                when (length ps < length ps' || unrecoverable ps') $
                     fail "Can't apply type class"
                mapM_ (\ (_,n) -> do focus n
                                     t' <- goal
                                     let (tc', ttype) = unApply (getRetTy t')
                                     let got = fst (unApply (getRetTy t))
                                     -- Depth is only spent when the
                                     -- sub-goal's class has already been
                                     -- visited on this path.
                                     let depth' = if tc' `elem` tcs
                                                     then depth - 1 else depth
                                     resTC' (got : tcs) defaultOn topholes depth' topg fn elab ist)
                      (filter (\ (x, y) -> not x) (zip (map fst imps) args))
                hs <- get_holes
                ulog <- getUnifyLog
                solve
                traceWhen ulog ("Got " ++ show n) $ return ()
       where isImp (PImp p _ _ _ _) = (True, p)
             isImp arg = (False, priority arg)
-- | List the candidate instances for the class at the head of the return
-- type of @t@. Only instances flagged @True@ in the class record are
-- returned, and instances whose definition is 'Hidden' or 'Private' are
-- left out. The result is empty when the head is not a name or when the
-- name does not resolve to exactly one class.
findInstances :: IState -> Term -> [Name]
findInstances ist t = case unApply (getRetTy t) of
    (P _ cls _, _) ->
        case lookupCtxt cls (idris_classes ist) of
          [CI _ _ _ _ _ ins _] ->
            [ inst | (inst, True) <- ins, accessible inst ]
          _ -> []
    _ -> []
  where
    accessible inst = case lookupDefAccExact inst False (tt_ctxt ist) of
        Just (_, Hidden) -> False
        Just (_, Private) -> False
        _ -> True