{-# LANGUAGE NondecreasingIndentation #-}
module Agda.TypeChecking.SizedTypes.Solve where
import Prelude hiding (null)
import Control.Monad hiding (forM, forM_)
import Control.Monad.Trans.Maybe
import Data.Either
import Data.Foldable (foldMap, forM_)
import qualified Data.Foldable as Fold
import Data.Function
import qualified Data.List as List
import Data.List.NonEmpty (NonEmpty(..), nonEmpty)
import qualified Data.List.NonEmpty as NonEmpty
import Data.Monoid
import qualified Data.Map as Map
import Data.Set (Set)
import qualified Data.Set as Set
import Data.Traversable (forM)
import Agda.Syntax.Common
import Agda.Syntax.Internal
import Agda.Syntax.Internal.MetaVars
import Agda.TypeChecking.Monad as TCM hiding (Offset)
import Agda.TypeChecking.Monad.Builtin
import Agda.TypeChecking.Pretty
import Agda.TypeChecking.Free
import Agda.TypeChecking.Reduce
import Agda.TypeChecking.MetaVars
import Agda.TypeChecking.Substitute
import Agda.TypeChecking.Telescope
import Agda.TypeChecking.Constraints as C
import qualified Agda.TypeChecking.SizedTypes as S
import Agda.TypeChecking.SizedTypes.Syntax as Size
import Agda.TypeChecking.SizedTypes.Utils
import Agda.TypeChecking.SizedTypes.WarshallSolver as Size
import Agda.Utils.Cluster
import Agda.Utils.Except ( MonadError(catchError) )
import Agda.Utils.Function
import Agda.Utils.Functor
import Agda.Utils.Lens
import qualified Agda.Utils.List as List
import Agda.Utils.Maybe
import Agda.Utils.Monad
import Agda.Utils.Null
import Agda.Utils.Pretty (Pretty, prettyShow)
import qualified Agda.Utils.Pretty as P
import Agda.Utils.Singleton
import Agda.Utils.Size
import qualified Agda.Utils.VarSet as VarSet
import Agda.Utils.Impossible
-- | A type-checker constraint packaged with the closure it was created in.
--   The closure's environment is accessed via 'clEnv' and the constraint
--   itself via 'clValue' throughout this module.
type CC = Closure TCM.Constraint
-- | Flag telling the size solver what to do with size metas that remain
--   unsolved after constraint solving (see 'solveSizeConstraints' and
--   'solveCluster').
data DefaultToInfty
  = DefaultToInfty      -- ^ Instantiate remaining (unfrozen) size metas to ∞.
  | DontDefaultToInfty  -- ^ Leave remaining size metas unsolved.
  deriving (Eq, Ord, Show)
-- | Solve the currently pending size constraints.
--
--   The procedure is:
--
--   1. Take the @CmpLeq@ size constraints and normalise them.
--   2. Pair every constraint with the size metas occurring in it and
--      cluster the constraints that share metas.
--   3. Solve the meta-free constraints one by one, then each cluster
--      (after casting its constraints into a common context).
--   4. If @flag == 'DefaultToInfty'@, set every size meta that is neither
--      constrained nor frozen to ∞.
--   5. Run 'solveConstraint' on all the original constraints again.
solveSizeConstraints :: DefaultToInfty -> TCM ()
solveSizeConstraints flag = do

  -- 1. Take out the size constraints, normalised.
  --    NOTE(review): 'S.takeSizeConstraints' presumably removes them from
  --    the constraint store, which would be why step 5 re-checks them.
  cs0 <- mapM (mapClosure normalise) =<< S.takeSizeConstraints (== CmpLeq)
  unless (null cs0) $
    reportSDoc "tc.size.solve" 40 $ vcat $
      [ text $ "Solving constraints (" ++ show flag ++ ")"
      ] ++ map prettyTCM cs0

  -- 2. Cluster the constraints by common size metas.

  -- The set of all size metas.
  -- NOTE(review): the 'Bool' argument presumably selects whether interaction
  -- metas are included; confirm against 'S.getSizeMetas'.
  sizeMetaSet <- Set.fromList . map (\ (x, _t, _tel) -> x) <$> S.getSizeMetas True

  -- Pair each constraint with the (numeric ids of the) size metas in it.
  cms <- forM cs0 $ \ cl -> enterClosure cl $ \ c -> do
    return (cl, map metaId . Set.toList $
      sizeMetaSet `Set.intersection` allMetas singleton c)

  -- Constraints without size metas go left, the others go right.
  let classify :: (a, [b]) -> Either a (a, NonEmpty b)
      classify (cl, [])     = Left  cl
      classify (cl, (x:xs)) = Right (cl, x :| xs)
  let (clcs, othercs) = partitionEithers $ map classify cms

  -- Cluster the constraints that mention metas.
  let ccs = cluster' othercs

  -- 3. Solve each cluster.

  -- The meta-free constraints are solved one at a time.
  forM_ clcs $ \ c -> () <$ solveSizeConstraints_ flag [c]

  -- Solve the proper clusters, collecting the metas they constrain.
  constrainedMetas <- Set.unions <$> do
    forM ccs $ \ (cs :: NonEmpty CC) -> do
      reportSDoc "tc.size.solve" 60 $ vcat $ concat
        [ [ "size constraint cluster:" ]
        , map (text . show) $ NonEmpty.toList cs
        ]

      -- Work inside the closure with the longest context of the cluster.
      enterClosure (Fold.maximumBy (compare `on` (length . envContext . clEnv)) cs) $ \ _ -> do

        -- Constraints that cannot be cast to this context are dropped.
        cs' :: [TCM.Constraint] <- catMaybes <$> do
          mapM (runMaybeT . castConstraintToCurrentContext) $ NonEmpty.toList cs
        reportSDoc "tc.size.solve" 20 $ vcat $
          [ "converted size constraints to context: " <+> do
              tel <- getContextTelescope
              inTopContext $ prettyTCM tel
          ] ++ map (nest 2 . prettyTCM) cs'

        solveSizeConstraints_ flag =<< mapM buildClosure cs'

  -- 4. Possibly set the remaining metas to ∞.
  when (flag == DefaultToInfty) $ do
    ms <- S.getSizeMetas False
    unless (null ms) $ do
      inf <- primSizeInf
      forM_ ms $ \ (m, t, tel) ->
        -- Only metas that are neither constrained nor frozen are defaulted.
        unless (m `Set.member` constrainedMetas) $
          unlessM (isFrozen m) $ do
            reportSDoc "tc.size.solve" 20 $
              "solution " <+> prettyTCM (MetaV m []) <+>
              " := "      <+> prettyTCM inf
            assignMeta 0 m t (List.downFrom $ size tel) inf

  -- 5. Re-check all original constraints.
  forM_ cs0 $ \ cl -> enterClosure cl solveConstraint
-- | Try to transport a constraint from the module and context recorded in
--   its closure to the current module and context.
--
--   Fails (in the 'MaybeT' layer, via 'guard') when
--
--   * the current context is shorter than the current module's section
--     telescope,
--   * lowering the constraint would capture one of its free variables, or
--   * the free-variable count of the constraint's module differs from the
--     length of its section telescope.
castConstraintToCurrentContext' :: Closure TCM.Constraint -> MaybeT TCM TCM.Constraint
castConstraintToCurrentContext' cl = do
  -- Module and context in which the constraint was created.
  let modN  = envCurrentModule $ clEnv cl
      delta = envContext $ clEnv cl
  -- Section telescope of the constraint's module (empty if no section is found).
  delta1 <- liftTCM $ maybe empty (^. secTelescope) <$> getSection modN
  -- Number of entries in the constraint's context beyond that telescope.
  let delta2 = size delta - size delta1
  unless (delta2 >= 0) __IMPOSSIBLE__
  -- Current module and the size of the current context.
  modM  <- currentModule
  gamma <- liftTCM $ getContextSize
  -- Section telescope of the current module (empty if no section is found).
  gamma1 <-liftTCM $ maybe empty (^. secTelescope) <$> getSection modM
  -- Number of entries in the current context beyond that telescope.
  let gamma2 = gamma - size gamma1
  -- Module parameter substitution for the constraint's module; identity if none.
  sigma <- liftTCM $ fromMaybe idS <$> getModuleParameterSub modN
  -- Debug output of everything computed so far.
  reportSDoc "tc.constr.cast" 40 $ "casting constraint" $$ do
    tel <- getContextTelescope
    inTopContext $ nest 2 $ vcat $
      [ "current module                = " <+> prettyTCM modM
      , "current module telescope      = " <+> prettyTCM gamma1
      , "current context               = " <+> prettyTCM tel
      , "constraint module             = " <+> prettyTCM modN
      , "constraint module telescope   = " <+> prettyTCM delta1
      , "constraint context            = " <+> (prettyTCM =<< enterClosure cl (const $ getContextTelescope))
      , "constraint                    = " <+> enterClosure cl prettyTCM
      , "module parameter substitution = " <+> prettyTCM sigma
      ]
  -- Give up if the current context is shorter than the current module telescope.
  guard (gamma2 >= 0)
  -- Same module: shift the constraint by the difference in context length.
  -- (The remainder of the block is the @else@ branch; this layout relies on
  -- NondecreasingIndentation.)
  if modN == modM then raiseMaybe (gamma - size delta) $ clValue cl else do
  -- Different module: first strip the @delta2@ local variables ...
  c <- raiseMaybe (-delta2) $ clValue cl
  -- ... then require that the free-variable count recorded for the
  -- constraint's module equals the length of its section telescope ...
  fv <- liftTCM $ getModuleFreeVars modN
  guard $ fv == size delta1
  -- ... and finally apply the module parameter substitution.
  return $ applySubst sigma c
  where
    -- @raiseMaybe n c@ raises @c@ by @n@. For negative @n@ it fails unless
    -- every free variable of @c@ is at least @-n@.
    raiseMaybe n c = do
      guard $ n >= 0 || List.all (>= -n) (VarSet.toList $ allFreeVars c)
      return $ raise n c
-- | Transport the constraint of a closure into the current context.
--
--   If the current environment has an entry for the closure's checkpoint
--   (looked up through 'eCheckpoints'), that entry is applied as the
--   substitution. Otherwise the variables of the closure's context are
--   matched against the current context by name; this fails (in 'MaybeT')
--   if the constraint mentions a variable without a match.
castConstraintToCurrentContext :: Closure TCM.Constraint -> MaybeT TCM TCM.Constraint
castConstraintToCurrentContext cl = do
  let constraint = clValue cl
      checkpoint = envCurrentCheckpoint $ clEnv cl
  sigma <- caseMaybeM (viewTC $ eCheckpoints . key checkpoint)
    (do
      -- No checkpoint entry: fall back to matching variables by name.
      gamma <- asksTC envContext
      let indexInGamma (Dom {unDom = (x, _)}) =
            List.findIndex ((x ==) . fst . unDom) gamma
          -- For each variable of the closure's context, its position in
          -- the current context (if any).
          candidates = map indexInGamma $ envContext $ clEnv cl
          -- Indices (in the closure's context) of the variables with a match.
          covered    = VarSet.fromList [ i | (Just _, i) <- zip candidates [0..] ]
      -- Every free variable of the constraint has to be covered.
      guard $ getAll $
        runFree (All . (`VarSet.member` covered)) IgnoreAll constraint
      -- Uncovered variables do not occur, so a dummy term is fine for them.
      return $ parallelS $ map (maybe __DUMMY_TERM__ var) candidates
    )
    return
  return $ applySubst sigma constraint
  
  
-- | Solve a list of size constraints that all live in the same context.
--
--   The constraints are converted to 'HypSizeConstraint's (those that cannot
--   be converted are dropped), simplified, split into constraints without
--   size metas and clusters of constraints sharing metas, and each part is
--   passed to 'solveCluster'.
--
--   Raises a type error if simplification detects a contradictory constraint.
--   Returns the ids of the size metas mentioned by the clustered constraints.
solveSizeConstraints_ :: DefaultToInfty -> [CC] -> TCM (Set MetaId)
solveSizeConstraints_ flag cs0 = do

  -- Convert; keep the original closure next to each converted constraint.
  converted :: [(CC,HypSizeConstraint)] <-
    catMaybes <$> mapM (\ cl -> fmap (cl,) <$> computeSizeConstraint cl) cs0

  -- Simplify; one constraint may turn into several (or none).
  simplified <- fmap concat $
    forM converted $ \ (cl, HypSizeConstraint cxt hids hs sc) ->
      case simplify1 (\ c -> return [c]) sc of
        Left _    -> typeError . GenericDocError =<<
          ("Contradictory size constraint" <+> prettyTCM cl)
        Right scs -> return [ (cl, HypSizeConstraint cxt hids hs c) | c <- scs ]

  -- Separate the constraints without metas from those with at least one.
  let (closedCs, openCs) = List.partitionMaybe
        (\ p@(_, hc) ->
            (p,) <$> nonEmpty (map (metaId . sizeMetaId) $ Set.toList $ flexs hc))
        simplified

      clusters :: [NonEmpty (CC,HypSizeConstraint)]
      clusters = cluster' openCs

  -- The meta-free constraints form one cluster of their own.
  whenJust (nonEmpty closedCs) $ solveCluster flag

  forM_ clusters $ solveCluster flag

  return $ Set.mapMonotonic sizeMetaId $ flexs $ map (snd . fst) openCs
-- | Solve one cluster of size constraints.
--
--   The constraints are brought into the longest context of the cluster,
--   handed to 'iterateSolver', and the resulting solution is turned into
--   meta assignments. Metas the solver left open may be set to ∞, depending
--   on @flag@. If every meta of the cluster ends up solved, the original
--   constraints are re-checked.
--
--   Raises a type error if the solver fails or if the re-check fails.
solveCluster :: DefaultToInfty -> NonEmpty (CC,HypSizeConstraint) -> TCM ()
solveCluster flag ccs = do
  let cs = fmap snd ccs
  let prettyCs   = map prettyTCM $ NonEmpty.toList cs
  -- Error raised when the solver reports failure with the given reason.
  let err reason = typeError . GenericDocError =<< do
        vcat $
          [ text $ "Cannot solve size constraints" ] ++ prettyCs ++
          [ text $ "Reason: " ++ reason ]
  reportSDoc "tc.size.solve" 20 $ vcat $
    [ "Solving constraint cluster" ] ++ prettyCs

  -- Use the constraint with the longest context as the reference:
  -- its context, hypothesis ids and hypotheses serve the whole cluster.
  let HypSizeConstraint gamma hids hs _ = Fold.maximumBy (compare `on` (length . sizeContext)) cs

  let n = size gamma

      -- Raise every constraint into the reference context.
      csL = for cs $ \ (HypSizeConstraint cxt _ _ c) -> raise (n - size cxt) c

      -- Canonicalize only when there are no hypotheses.
      -- ('canonicalizeSizeConstraint' currently accepts every constraint.)
      csC :: [SizeConstraint]
      csC = applyWhen (null hs) (mapMaybe canonicalizeSizeConstraint) $ NonEmpty.toList csL
  reportSDoc "tc.size.solve" 30 $ vcat $
    [ "Size hypotheses" ] ++
    map (prettyTCM . HypSizeConstraint gamma hids hs) hs ++
    [ "Canonicalized constraints" ] ++
    map (prettyTCM . HypSizeConstraint gamma hids hs) csC

  -- Collect the elements of the constraints via their 'Foldable' instance
  -- (NOTE(review): presumably exactly the flexible variables, i.e. the size
  -- metas), and replace each size meta by its numeric id for the solver.
  let metas :: [SizeMeta]
      metas = concat $ map (foldMap (:[])) csC
      csF   :: [Size.Constraint' NamedRigid Int]
      csF   = map (fmap (metaId . sizeMetaId)) csC

  -- Hypotheses, likewise with metas replaced by their numeric ids.
  let hyps = map (fmap (metaId . sizeMetaId)) hs

  -- Graph of the hypotheses; failure to build it is treated as impossible.
  let hg = either __IMPOSSIBLE__ id $ hypGraph (rigids csF) hyps

  -- Run the solver; a 'Left' result is reported through @err@.
  sol :: Solution NamedRigid Int <- either err return $
    iterateSolver Map.empty hg csF emptySolution

  -- Turn the solution into meta assignments, collecting the metas assigned.
  solved <- fmap Set.unions $ forM (Map.assocs $ theSolution sol) $ \ (m, a) -> do
    unless (validOffset a) __IMPOSSIBLE__
    -- The solution is expected to contain no flexible variables
    -- (any that were present are mapped to @__IMPOSSIBLE__@).
    u <- unSizeExpr $ fmap __IMPOSSIBLE__ a
    let x = MetaId m
    -- Look up the arguments the meta was applied to in the constraints.
    let SizeMeta _ xs = fromMaybe __IMPOSSIBLE__ $
          List.find ((m==) . metaId . sizeMetaId) metas
    -- The solution is well-scoped if it only mentions rigid variables
    -- that are among the arguments of the meta.
    let ys = rigidIndex <$> Set.toList (rigids a)
        ok = all (`elem` xs) ys 
    -- An ill-scoped solution is replaced by ∞.
    u <- if ok then return u else primSizeInf
    t <- getMetaType x
    reportSDoc "tc.size.solve" 20 $ unsafeModifyContext (const gamma) $ do
      let args = map (Apply . defaultArg . var) xs
      "solution " <+> prettyTCM (MetaV x args) <+> " := " <+> prettyTCM u
    reportSDoc "tc.size.solve" 60 $ vcat
      [ text $ "  xs = " ++ show xs
      , text $ "  u  = " ++ show u
      ]
    -- Skip the assignment if the meta is frozen or 'envAssignMetas' is off.
    ifM (isFrozen x `or2M` (not <$> asksTC envAssignMetas)) (return Set.empty) $ do
      assignMeta n x t xs u
      return $ Set.singleton x

  -- Interaction metas.
  ims <- Set.fromList <$> getInteractionMetas

  -- Metas of the cluster that were not assigned above.
  let ms = Set.fromList (map sizeMetaId metas) Set.\\ solved

  -- True if none of the unassigned metas is an interaction meta.
  let noIP = Set.null $ Set.intersection ims ms
  unless (null ms) $ reportSDoc "tc.size.solve" 30 $ fsep $
    [ "cluster did not solve these size metas: " ] ++ map prettyTCM (Set.toList ms)

  -- Decide whether every meta of the cluster is solved, defaulting the open
  -- ones to ∞ where permitted. (The @else do@ chain below relies on
  -- NondecreasingIndentation.)
  solvedAll <- do
    -- Nothing left open: all solved.
    if Set.null ms                then return True  else do
    -- Defaulting to ∞ was not requested.
    if flag == DontDefaultToInfty then return False else do
    -- Some open meta is an interaction meta: do not default.
    if not noIP                   then return False else do
    -- Try to set each open meta to ∞.
    inf <- primSizeInf
    and <$> do
      forM (Set.toList ms) $ \ m -> do
        -- Taken when the meta is frozen or 'envAssignMetas' is off
        -- (the message mentions only the frozen case).
        let no = do
              reportSDoc "tc.size.solve" 30 $
                prettyTCM (MetaV m []) <+> "is frozen, cannot set it to ∞"
              return False
        ifM (isFrozen m `or2M` do not <$> asksTC envAssignMetas) no $  do
          reportSDoc "tc.size.solve" 20 $
            "solution " <+> prettyTCM (MetaV m []) <+>
            " := "      <+> prettyTCM inf
          t <- metaType m
          TelV tel core <- telView t
          -- The target of the meta's type has to be a size type.
          unlessM (isJust <$> isSizeType core) __IMPOSSIBLE__
          assignMeta 0 m t (List.downFrom $ size tel) inf
          return True

  -- With all metas solved, the original constraints must now hold without
  -- producing new constraints; any error is replaced by a generic message.
  when solvedAll $ do
    let cs0 = map fst $ NonEmpty.toList ccs
        -- Error for giving up.
        cannotSolve = typeError . GenericDocError =<<
          vcat ("Cannot solve size constraints" : map prettyTCM cs0)
    flip catchError (const cannotSolve) $
      noConstraints $
        forM_ cs0 $ \ cl -> enterClosure cl solveConstraint
-- | Collect the size hypotheses of a context.
--
--   Working in context @gamma@, every entry whose (raised and reduced) type
--   is the builtin obtained as second component of 'getBuiltinSize', applied
--   to a single argument that 'sizeExpr' can convert, yields a pair of its
--   de Bruijn index @i@ and the constraint @i < a@.
--   Returns the empty list if that builtin is not available.
getSizeHypotheses :: Context -> TCM [(Nat, SizeConstraint)]
getSizeHypotheses gamma = unsafeModifyContext (const gamma) $ do
  (_, msizelt) <- getBuiltinSize
  case msizelt of
    Nothing     -> return []
    Just sizelt -> catMaybes <$> mapM (hypothesis sizelt) (zip [0..] gamma)
  where
    -- The hypothesis contributed by context entry @ce@ at index @i@, if any.
    hypothesis sizelt (i, ce) = do
      let (x, ty) = unDom ce
      -- Raise the type so that it is valid in the whole context.
      ty' <- reduce $ raise (1 + i) $ unEl ty
      case ty' of
        Def d [Apply u] | d == sizelt ->
          fmap (\ a -> (i, Constraint (Rigid (NamedRigid (prettyShow x) i) 0) Lt a))
            <$> sizeExpr (unArg u)
        _ -> return Nothing
-- | Canonicalize a size constraint.
--   At present every constraint is accepted and returned unchanged.
canonicalizeSizeConstraint :: SizeConstraint -> Maybe SizeConstraint
canonicalizeSizeConstraint c@Constraint{} = Just c
-- | A rigid size variable: a de Bruijn index together with a name that is
--   used for printing only.
data NamedRigid = NamedRigid
  { rigidName  :: String   -- ^ Name, used by the 'Pretty' instance.
  , rigidIndex :: Int      -- ^ De Bruijn index; the sole basis for 'Eq' and 'Ord'.
  } deriving (Show)

-- | Equality looks at the index only.
instance Eq NamedRigid where
  r == s = rigidIndex r == rigidIndex s

-- | Ordering looks at the index only.
instance Ord NamedRigid where
  compare r s = compare (rigidIndex r) (rigidIndex s)

-- | Printed by name.
instance Pretty NamedRigid where
  pretty r = P.text (rigidName r)

-- | Adding an 'Int' shifts the index and keeps the name.
instance Plus NamedRigid Int NamedRigid where
  plus r j = r { rigidIndex = rigidIndex r + j }
-- | A size meta variable applied to a list of variables.
data SizeMeta = SizeMeta
  { sizeMetaId   :: MetaId
  , sizeMetaArgs :: [Int]       -- ^ De Bruijn indices of the arguments.
  } deriving (Show)

-- | Equality looks at the meta id only; the arguments are ignored.
instance Eq  SizeMeta where
  m == m' = sizeMetaId m == sizeMetaId m'

-- | Ordering looks at the meta id only; the arguments are ignored.
instance Ord SizeMeta where
  compare m m' = compare (sizeMetaId m) (sizeMetaId m')

instance Pretty SizeMeta where
  pretty m = P.pretty (sizeMetaId m)

-- | Printed as the meta applied to its argument variables.
instance PrettyTCM SizeMeta where
  prettyTCM (SizeMeta x es) =
    prettyTCM $ MetaV x [ Apply (defaultArg (var i)) | i <- es ]

-- | Substitution acts on the arguments and has to map each of them to a
--   variable again; anything else is considered impossible.
instance Subst Term SizeMeta where
  applySubst sigma (SizeMeta x es) = SizeMeta x [ substVar i | i <- es ]
    where
      substVar i
        | Var j [] <- lookupS sigma i = j
        | otherwise                   = __IMPOSSIBLE__
-- | Size expressions over de Bruijn-indexed rigid variables and size metas.
type DBSizeExpr = SizeExpr' NamedRigid SizeMeta

-- | Substitution leaves ∞ and constants alone, is pushed into the meta of a
--   flexible expression, and has to map a rigid variable to a variable
--   again (anything else is considered impossible).
instance Subst Term (SizeExpr' NamedRigid SizeMeta) where
  applySubst _     a@Infty     = a
  applySubst _     a@Const{}   = a
  applySubst sigma (Flex  x n) = Flex (applySubst sigma x) n
  applySubst sigma (Rigid r n)
    | Var j [] <- lookupS sigma (rigidIndex r) = Rigid r{ rigidIndex = j } n
    | otherwise                                = __IMPOSSIBLE__
-- | Size constraints over de Bruijn-indexed rigid variables and size metas.
type SizeConstraint = Constraint' NamedRigid SizeMeta

-- | Substitution is applied to both sides; the comparison is kept.
instance Subst Term SizeConstraint where
  applySubst sigma (Constraint lhs cmp rhs) = Constraint (sub lhs) cmp (sub rhs)
    where
      sub = applySubst sigma

-- | Both sides are converted back to terms (left side first) and printed
--   around the comparison operator.
instance PrettyTCM (SizeConstraint) where
  prettyTCM (Constraint lhs cmp rhs) = do
    l <- unSizeExpr lhs
    r <- unSizeExpr rhs
    prettyTCM l <+> pretty cmp <+> prettyTCM r
-- | A size constraint together with the context it lives in and the size
--   hypotheses of that context.
data HypSizeConstraint = HypSizeConstraint
  { sizeContext    :: Context
  , sizeHypIds     :: [Nat]             -- ^ De Bruijn indices of the hypotheses.
  , sizeHypotheses :: [SizeConstraint]  -- ^ Hypotheses, living in 'sizeContext'.
  , sizeConstraint :: SizeConstraint    -- ^ The constraint, living in 'sizeContext'.
  }

-- | The flexible variables of the hypotheses and of the constraint.
instance Flexs SizeMeta HypSizeConstraint where
  flexs (HypSizeConstraint _ _ hs c) = mappend (flexs hs) (flexs c)

-- | Printed inside the constraint's own context as the list of context
--   names followed by the constraint; if there are hypotheses, they are
--   printed comma-separated before a turnstile.
instance PrettyTCM HypSizeConstraint where
  prettyTCM (HypSizeConstraint cxt _ hs c) =
    unsafeModifyContext (const cxt) $
      prettyList (map prettyTCM cxtNames) <+>
        applyUnless (null hs)
          ((hcat (punctuate ", " $ map prettyTCM hs) <+> "|-") <+>)
          (prettyTCM c)
    where
      -- The context lists the most recent binding first; print oldest first.
      cxtNames = reverse $ map (fst . unDom) cxt
-- | Convert a size constraint closure into a 'HypSizeConstraint'.
--
--   Works in the context of the closure. Returns 'Nothing' if either side
--   of the comparison cannot be turned into a size expression by 'sizeExpr'.
--   Any constraint that is not a @ValueCmp CmpLeq@ is considered impossible.
computeSizeConstraint :: Closure TCM.Constraint -> TCM (Maybe HypSizeConstraint)
computeSizeConstraint c = unsafeModifyContext (const cxt) $
  case clValue c of
    ValueCmp CmpLeq _ u v -> do
      reportSDoc "tc.size.solve" 50 $ sep
        [ "converting size constraint"
        , prettyTCM c
        ]
      lhs <- sizeExpr u
      rhs <- sizeExpr v
      (hids, hs) <- unzip <$> getSizeHypotheses cxt
      -- Both sides have to be convertible for the result to be 'Just'.
      return $ do
        a <- lhs
        b <- rhs
        Just $ HypSizeConstraint cxt hids hs $ Size.Constraint a Le b
    _ -> __IMPOSSIBLE__
  where
    cxt = envContext $ clEnv c
-- | Turn a term into a size expression.
--
--   The term is reduced first. Successful results are:
--
--   * ∞ for the infinite size,
--   * the successor view, handled recursively with the offset increased by one,
--   * a rigid variable for a bare variable (named after the bound variable),
--   * a flexible variable for a meta applied to pairwise distinct variables.
--
--   Any other term yields 'Nothing'.
sizeExpr :: Term -> TCM (Maybe DBSizeExpr)
sizeExpr u = do
  v <- reduce u
  reportSDoc "tc.conv.size" 60 $ "sizeExpr:" <+> prettyTCM v
  view <- sizeView v
  case view of
    SizeInf   -> return $ Just Infty
    SizeSuc w -> fmap (`plus` (1 :: Offset)) <$> sizeExpr w
    OtherSize (Var i []) -> do
      x <- nameOfBV i
      return $ Just $ Rigid (NamedRigid (prettyShow x) i) 0
    OtherSize (MetaV m es)
      | Just xs <- mapM isVar es, List.fastDistinct xs
      -> return $ Just $ Flex (SizeMeta m xs) 0
    OtherSize _ -> return Nothing
  where
    -- The de Bruijn index of an eliminator that applies to a bare variable.
    isVar e = case e of
      Apply a      -> varIndex (unArg a)
      IApply _ _ t -> varIndex t
      Proj{}       -> Nothing

    varIndex (Var i []) = Just i
    varIndex _          = Nothing
-- | Turn a size expression back into a term.
--
--   ∞ becomes the builtin infinite size (its absence is considered
--   impossible). Rigid and flexible expressions become the variable, resp.
--   the meta applied to its argument variables, wrapped by 'sizeSuc' with
--   the offset. Negative offsets and constant expressions are considered
--   impossible.
unSizeExpr :: HasBuiltins m => DBSizeExpr -> m Term
unSizeExpr Infty = fromMaybe __IMPOSSIBLE__ <$> getBuiltin' builtinSizeInf
unSizeExpr (Rigid r (O n)) = do
  unless (n >= 0) __IMPOSSIBLE__
  sizeSuc n (var (rigidIndex r))
unSizeExpr (Flex (SizeMeta x es) (O n)) = do
  unless (n >= 0) __IMPOSSIBLE__
  sizeSuc n (MetaV x [ Apply (defaultArg (var i)) | i <- es ])
unSizeExpr Const{} = __IMPOSSIBLE__