{-# LANGUAGE BangPatterns, CPP, DataKinds, FlexibleContexts               #-}
{-# LANGUAGE FlexibleInstances, LambdaCase, MultiWayIf, OverloadedStrings #-}
{-# LANGUAGE PatternGuards, RankNTypes, TypeOperators, ViewPatterns       #-}
-- | Since 0.3.0.0
module GHC.TypeLits.Presburger.Types
  ( pluginWith
  , defaultTranslation
  , Translation(..), ParseEnv, Machine
  , module Data.Integer.SAT
  ) where
import           Class                          (classTyCon)
import           Control.Applicative            ((<|>))
import           Control.Arrow                  (second)
import           Control.Monad                  (forM_, guard, mzero, unless)
import           Control.Monad.State.Class
import           Control.Monad.Trans.Class
import           Control.Monad.Trans.Maybe      (MaybeT (..))
import           Control.Monad.Trans.RWS.Strict (runRWS, tell)
import           Control.Monad.Trans.State      (StateT, runStateT)
import           Data.Foldable                  (asum)
import           Data.Integer.SAT               (Expr (..), Prop (..), PropSet,
                                                 assert)
import           Data.Integer.SAT               (checkSat, noProps, toName)
import qualified Data.Integer.SAT               as SAT
import           Data.List                      (nub)
import qualified Data.Map.Strict                as M
import           Data.Maybe                     (catMaybes, fromMaybe,
                                                 isNothing)
import           Data.Reflection                (Given, give, given)
import qualified Data.Set                       as Set
import           GHC.TypeLits.Presburger.Compat
import           PrelNames
import           TcPluginM                      (lookupOrig, newFlexiTyVar,
                                                 newWanted, tcLookupClass)
import           Type                           (mkPrimEqPredRole, mkTyVarTy)
import           TysWiredIn                     (promotedEQDataCon,
                                                 promotedGTDataCon,
                                                 promotedLTDataCon)
import           Var
#if MIN_VERSION_ghc(8,6,0)
import Plugins (purePlugin)
#endif

-- | Adds a proposition to a 'PropSet' together with a @0 <= v@ constraint
-- for every variable @v@ that occurs in the proposition.
assert' :: Prop -> PropSet -> PropSet
assert' p ps = foldr assert ps (p : map nonNegative (varsProp p))
  where
    nonNegative v = K 0 :<= Var v

-- | Outcome of testing a goal against a set of premises (see 'testIf').
--
-- 'Disproved' carries whatever 'checkSat' returned for the negated goal;
-- presumably a variable\/value assignment acting as a counterexample
-- (TODO confirm against the solver's documentation).
data Proof = Proved | Disproved [(Int, Integer)]
           deriving (Read, Show, Eq, Ord)

-- | 'True' exactly for 'Proved'; any 'Disproved' result yields 'False'.
isProved :: Proof -> Bool
isProved = \case
  Proved      -> True
  Disproved _ -> False

-- | Collects the variable names occurring in a proposition, without
-- duplicates. Propositions with no sub-terms yield the empty list.
varsProp :: Prop -> [SAT.Name]
varsProp prop = case prop of
  p :|| q  -> ofProps p q
  p :&& q  -> ofProps p q
  Not p    -> varsProp p
  e :== v  -> ofExprs e v
  e :/= v  -> ofExprs e v
  e :< v   -> ofExprs e v
  e :> v   -> ofExprs e v
  e :<= v  -> ofExprs e v
  e :>= v  -> ofExprs e v
  _        -> []
  where
    -- Variables of two sub-propositions, deduplicated.
    ofProps p q = nub (varsProp p ++ varsProp q)
    -- Variables of two sub-expressions, deduplicated.
    ofExprs e v = nub (varsExpr e ++ varsExpr v)

-- | Collects the variable names occurring in an expression, without
-- duplicates within each combined sub-term.
varsExpr :: Expr -> [SAT.Name]
varsExpr expr = case expr of
  e :+ v    -> ofBoth e v
  e :- v    -> ofBoth e v
  _ :* v    -> varsExpr v
  Negate e  -> varsExpr e
  Var i     -> [i]
  K _       -> []
  If p e v  -> nub (varsProp p ++ varsExpr e ++ varsExpr v)
  Div e _   -> varsExpr e
  Mod e _   -> varsExpr e
  where
    ofBoth l r = nub (varsExpr l ++ varsExpr r)

-- | How the plugin treats subtraction in goals.
--
-- 'pluginWith' selects 'AllowNegatives' when the plugin option
-- @allow-negated-numbers@ is given and 'DisallowNegatives' otherwise;
-- 'handleSubtraction' only rewrites goals in 'DisallowNegatives' mode.
data PluginMode = DisallowNegatives
                | AllowNegatives
                deriving (Read, Show, Eq, Ord)

-- | Builds a GHC 'Plugin' whose type-checker plugin uses the given
-- translation. The command-line options decide the 'PluginMode'.
pluginWith :: TcPluginM Translation -> Plugin
pluginWith trans = defaultPlugin
    { tcPlugin = \opts -> Just (presburgerPlugin trans (modeFor opts))
#if MIN_VERSION_ghc(8,6,0)
    , pluginRecompile = purePlugin
#endif
    }
  where
    -- Negative intermediate results are only tolerated on explicit request.
    modeFor opts =
      if "allow-negated-numbers" `elem` opts
        then AllowNegatives
        else DisallowNegatives

-- | The type-checker plugin itself: stateless (unit state), with
-- 'decidePresburger' as its solver, wrapped in 'tracePlugin'.
presburgerPlugin :: TcPluginM Translation -> PluginMode -> TcPlugin
presburgerPlugin trans mode = tracePlugin "typelits-presburger" solver
  where
    solver =
      TcPlugin
        { tcPluginInit  = return ()
        , tcPluginSolve = decidePresburger mode trans
        , tcPluginStop  = \_ -> return ()
        }

-- | Tests whether @q@ follows from the premises @ps@: the negation of @q@
-- is asserted on top of @ps@, and the goal counts as 'Proved' when the
-- solver finds no satisfying assignment.
testIf :: PropSet -> Prop -> Proof
testIf ps q =
  case checkSat (assert' (Not q) ps) of
    Nothing      -> Proved
    Just witness -> Disproved witness

-- | Strengthens a goal so that every subtraction (and negation) it contains
-- is required to be non-negative.
--
-- In 'AllowNegatives' mode the proposition is returned untouched. In
-- 'DisallowNegatives' mode the proposition is traversed; for each distinct
-- subtraction or negation @e@ found, the obligation @e ':>=' 'K' 0@ is
-- collected, and all collected obligations are conjoined with the rewritten
-- proposition.
handleSubtraction :: PluginMode -> Prop -> Prop
handleSubtraction AllowNegatives p = p
handleSubtraction DisallowNegatives p0 =
  let (p, _, w) = runRWS (loop p0) () Set.empty
  in foldr (:&&) p w
  where
    -- Structural traversal of propositions; every relation is rebuilt with
    -- the same operator it was taken apart with.
    loop PTrue     = return PTrue
    loop PFalse    = return PFalse
    loop (q :|| r) = (:||) <$> loop q <*> loop r
    loop (q :&& r) = (:&&) <$> loop q <*> loop r
    loop (Not q)   = Not <$> loop q
    loop (l :<= r) = (:<=) <$> loopExp l <*> loopExp r
    loop (l :< r)  = (:<) <$> loopExp l <*> loopExp r
    loop (l :>= r) = (:>=) <$> loopExp l <*> loopExp r
    loop (l :> r)  = (:>) <$> loopExp l <*> loopExp r
    loop (l :== r) = (:==) <$> loopExp l <*> loopExp r
    loop (l :/= r) = (:/=) <$> loopExp l <*> loopExp r

    -- Records the obligation @pos >= 0@ once per distinct expression; the
    -- state is the set of expressions already handled.
    withPositive pos = do
      dic <- get
      unless (Set.member pos dic) $ do
        modify $ Set.insert pos
        tell $ Set.fromList [pos :>= K 0]
      return pos

    -- Recurse into the negated sub-expression (not the negation itself,
    -- which would never terminate).
    loopExp (Negate e)   = withPositive . Negate =<< loopExp e
    loopExp (l :- r)     = do
      e <- (:-) <$> loopExp l <*> loopExp r
      withPositive e
    loopExp (l :+ r)     = (:+) <$> loopExp l <*> loopExp r
    loopExp v@Var {}     = return v
    loopExp (c :* e)
      -- A zero coefficient needs no negation, so it takes the plain branch.
      | c >= 0 = (c :*) <$> loopExp e
      | otherwise = (negate c :*) <$> loopExp (Negate e)
    loopExp e@(K _)      = return e
    -- Remaining constructors are traversed structurally so that the
    -- function is total on 'Expr'.
    loopExp (If c t e)   = If <$> loop c <*> loopExp t <*> loopExp e
    loopExp (Div e n)    = (`Div` n) <$> loopExp e
    loopExp (Mod e n)    = (`Mod` n) <$> loopExp e

-- | Tells the plugin which type constructors play which role when types
-- are translated into Presburger propositions and expressions. Each list
-- may hold several constructors with the same meaning; translations are
-- combined with the 'Semigroup' instance.
data Translation =
  Translation
    { isEmpty     :: [TyCon]
      -- ^ Unary constructors whose argument is translated and then negated.
    , isTrue      :: [TyCon]
      -- ^ Unary constructors whose argument is translated as is.
    , trueData    :: [TyCon]
      -- ^ Nullary constructors translated to 'PTrue'.
    , falseData   :: [TyCon]
      -- ^ Nullary constructors translated to 'PFalse'.
    , voids       :: [TyCon]
      -- ^ Not consulted by the translation code in this module.
    , tyEq        :: [TyCon]
      -- ^ Binary equality constructors.
    , tyEqBool    :: [TyCon]
      -- ^ Binary equality constructors (handled like 'tyEq').
    , tyEqWitness :: [TyCon]
      -- ^ Equality constructors taking a kind and two types.
    , tyNeqBool   :: [TyCon]
      -- ^ Binary constructors translated to ':/='.
    , natPlus     :: [TyCon]
      -- ^ Addition.
    , natMinus    :: [TyCon]
      -- ^ Subtraction.
    , natExp      :: [TyCon]
      -- ^ Exponentiation; only used for folding literal operands.
    , natTimes    :: [TyCon]
      -- ^ Multiplication; translated only when one operand is a literal.
    , natLeq      :: [TyCon]
      -- ^ Constructors translated to ':<='.
    , natLeqBool  :: [TyCon]
      -- ^ Constructors translated to ':<='.
    , natGeq      :: [TyCon]
      -- ^ Constructors translated to ':>='.
    , natGeqBool  :: [TyCon]
      -- ^ Constructors translated to ':>='.
    , natLt       :: [TyCon]
      -- ^ Constructors translated to ':<'.
    , natLtBool   :: [TyCon]
      -- ^ Constructors translated to ':<'.
    , natGt       :: [TyCon]
      -- ^ Constructors translated to ':>'.
    , natGtBool   :: [TyCon]
      -- ^ Constructors translated to ':>'.
    , orderingLT  :: [TyCon]
      -- ^ Comparison results meaning "less than".
    , orderingGT  :: [TyCon]
      -- ^ Comparison results meaning "greater than".
    , orderingEQ  :: [TyCon]
      -- ^ Comparison results meaning "equal".
    , natCompare  :: [TyCon]
      -- ^ Comparison functions whose result is matched against the
      -- @ordering*@ constructors.
    , parsePred   :: (Type -> Machine Expr) -> Type -> Machine Prop
      -- ^ Fallback predicate parser, tried when no built-in rule applies;
      -- it receives the expression translator as its first argument.
    , parseExpr   :: Type -> Machine Expr
      -- ^ Fallback expression parser for types that are neither variables,
      -- constructor applications nor literals.
    }

-- | Field-wise combination: constructor lists are concatenated (left
-- operand first) and the fallback parsers try the left operand before the
-- right one.
instance Semigroup Translation where
  l <> r =
    Translation
      { isEmpty = both isEmpty
      , isTrue = both isTrue
      , voids = both voids
      , tyEq = both tyEq
      , tyEqBool = both tyEqBool
      , tyEqWitness = both tyEqWitness
      , tyNeqBool = both tyNeqBool
      , natPlus = both natPlus
      , natMinus = both natMinus
      , natTimes = both natTimes
      , natExp = both natExp
      , natLeq = both natLeq
      , natGeq = both natGeq
      , natLt = both natLt
      , natGt = both natGt
      , natLeqBool = both natLeqBool
      , natGeqBool = both natGeqBool
      , natLtBool = both natLtBool
      , natGtBool = both natGtBool
      , orderingLT = both orderingLT
      , orderingGT = both orderingGT
      , orderingEQ = both orderingEQ
      , natCompare = both natCompare
      , trueData = both trueData
      , falseData = both falseData
      , parsePred = \f ty -> parsePred l f ty <|> parsePred r f ty
      , parseExpr = \ty -> parseExpr l ty <|> parseExpr r ty
      }
    where
      -- Concatenates one list-valued field of both operands.
      both :: (Translation -> [TyCon]) -> [TyCon]
      both field = field l ++ field r

-- | The empty translation: no constructors are recognised and both
-- fallback parsers always fail.
instance Monoid Translation where
  mempty = Translation
    { isEmpty = []
    , isTrue = []
    , tyEq  = []
    , tyEqBool = []
    , tyEqWitness = []
    , tyNeqBool = []
    , voids = []
    , natPlus = []
    , natMinus = []
    , natTimes = []
    , natExp = []
    , natLeq = []
    , natGeq = []
    , natLt = []
    , natGt = []
    , natLeqBool = []
    , natGeqBool = []
    , natLtBool = []
    , natGtBool = []
    , orderingLT = []
    , orderingGT = []
    , orderingEQ = []
    , natCompare = []
    , trueData = []
    , falseData = []
    , parsePred = \_ _ -> mzero
    , parseExpr = \_ -> mzero
    }

-- | The solver callback of the plugin.
--
-- Arguments, in order: the plugin mode, the action producing the
-- 'Translation', the (unit) plugin state, the given constraints, the
-- derived constraints and the wanted constraints.
--
-- The first equation handles the call with no deriveds and no wanteds:
-- it checks the translatable givens for consistency. The second equation
-- tries to prove the translatable wanteds from the givens and deriveds.
decidePresburger :: PluginMode -> TcPluginM Translation -> () -> [Ct] -> [Ct] -> [Ct] -> TcPluginM TcPluginResult
decidePresburger _ genTrans _ gs [] [] = do
  tcPluginTrace "Started givens with: " (ppr $ map (ctEvPred . ctEvidence) gs)
  trans <- genTrans
  give trans $ do
    -- Keep only the givens that can be translated, paired with their Prop.
    ngs <- mapM (\a -> runMachine $ (,) a <$> toPresburgerPred (deconsPred a)) gs
    let givens = catMaybes ngs
        prems0 = map snd givens
        prems  = foldr assert' noProps prems0
        (solved, _) = foldr go ([], noProps) givens
    -- Unsatisfiable premises mean the givens contradict each other.
    if isNothing (checkSat prems)
      then return $ TcPluginContradiction gs
      else return $ TcPluginOk (map withEv solved) []
    where
      -- Folding from the right: a given already implied by the premises
      -- accumulated so far is reported as solved; otherwise it is added to
      -- the premises.
      go (ct, p) (ss, prem)
        | Proved <- testIf prem p = (ct : ss, prem)
        | otherwise = (ss, assert' p prem)
decidePresburger mode genTrans _ gs ds ws = do
  trans <- genTrans
  give trans $ do
    gs' <- normaliseGivens gs
    let subst = mkSubstitution (gs' ++ ds)
    tcPluginTrace "Current subst" (ppr subst)
    tcPluginTrace "wanteds" $ ppr $ map deconsPred ws
    tcPluginTrace "givens" $ ppr $ map (subsType subst . deconsPred) gs
    tcPluginTrace "deriveds" $ ppr $ map deconsPred ds
    (prems, wants, prems0) <- do
      -- Translatable wanteds, after applying the substitution.
      wants <- catMaybes <$>
              mapM
              (\ct -> runMachine $ (,) ct <$> toPresburgerPred
                  ( subsType subst
                  $ deconsPred $ subsCt subst ct))
              (filter (isWanted . ctEvidence) ws)

      -- Translatable givens and deriveds become the premises.
      resls <- mapM (runMachine . toPresburgerPred . subsType subst . deconsPred)
                      (gs ++ ds)
      let prems = foldr assert' noProps $ catMaybes resls
      return (prems, map (second $ handleSubtraction mode) wants, catMaybes resls)
    -- Wanteds provable on their own; only nominal equalities among them
    -- receive evidence.
    let solved = map fst $ filter (isProved . testIf prems . snd) wants
        coerced = [(evByFiat "ghc-typelits-presburger" t1 t2, ct)
                  | ct <- solved
                  , EqPred NomEq t1 t2 <- return (classifyPredType $ deconsPred ct)
                  ]
    tcPluginTrace "final premises" (text $ show prems0)
    tcPluginTrace "final goals" (text $ show $ map snd wants)
    -- The overall verdict is taken on the conjunction of all wanteds.
    case testIf prems (foldr ((:&&) . snd) PTrue wants) of
      Proved -> do
        tcPluginTrace "Proved" (text $ show $ map snd wants)
        tcPluginTrace "... with coercions" (ppr coerced)
        return $ TcPluginOk coerced []
      Disproved wit -> do
        tcPluginTrace "Failed! " (text $ show wit)
        return $ TcPluginContradiction $ map fst wants

-- | The translation used by default: it looks up @Empty@ and @IsTrue@ from
-- the @equational-reasoning@ package, @Void@ from @base@ and the @<=@
-- constructor from the type-level naturals module, and combines them with
-- the wired-in arithmetic, boolean and ordering constructors.
defaultTranslation :: TcPluginM Translation
defaultTranslation = do
  emptyMod <- lookupModule (mkModuleName "Proof.Propositional.Empty") (fsLit "equational-reasoning")
  emptyCls <- tcLookupClass =<< lookupOrig emptyMod (mkTcOcc "Empty")
  eqCon <- getEqTyCon
  eqWitnessCon <- getEqWitnessTyCon
  propMod <- lookupModule (mkModuleName "Proof.Propositional") (fsLit "equational-reasoning")
  isTrueCon <- tyConIn propMod "IsTrue"
  voidMod <- lookupModule (mkModuleName "Data.Void") (fsLit "base")
  voidCon <- tyConIn voidMod "Void"
  leqCon <- tyConIn gHC_TYPENATS "<="
  pure
    ( mempty
        { isEmpty = [classTyCon emptyCls]
        , tyEq = [eqCon]
        , tyEqWitness = [eqWitnessCon]
        , isTrue = [isTrueCon]
        , voids = [voidCon]
        , natMinus = [typeNatSubTyCon]
        , natPlus = [typeNatAddTyCon]
        , natTimes = [typeNatMulTyCon]
        , natExp = [typeNatExpTyCon]
        , falseData = [promotedFalseDataCon]
        , trueData = [promotedTrueDataCon]
        , natLeqBool = [typeNatLeqTyCon]
        , natLeq = [leqCon]
        , natCompare = [typeNatCmpTyCon]
        , orderingEQ = [promotedEQDataCon]
        , orderingLT = [promotedLTDataCon]
        , orderingGT = [promotedGTDataCon]
        }
    )
  where
    -- Looks up a type constructor by its name within the given module.
    tyConIn md occ = tcLookupTyCon =<< lookupOrig md (mkTcOcc occ)

-- | Logical equivalence: both propositions hold, or both fail.
(<=>) :: Prop -> Prop -> Prop
p <=> q = bothHold :|| bothFail
  where
    bothHold = p :&& q
    bothFail = Not p :&& Not q

-- | Pairs a constraint with by-fiat evidence for it.
--
-- Evidence can only be built for equality constraints. For any other
-- constraint the pair itself is still well defined (so the constraint can
-- be inspected), and only forcing the evidence component raises an error
-- with a descriptive message.
withEv :: Ct -> (EvTerm, Ct)
withEv ct = (ev, ct)
  where
    ev = case classifyPredType (deconsPred ct) of
      EqPred _ t1 t2 -> evByFiat "ghc-typelits-presburger" t1 t2
      _ -> error "GHC.TypeLits.Presburger.Types.withEv: cannot build evidence for a non-equality constraint"

-- | Maps each ordering constructor of the translation to the relation it
-- denotes: "less than", "equal" and "greater than", in that order.
orderingDic :: Given Translation => [(TyCon, Expr -> Expr -> Prop)]
orderingDic =
  concat
    [ tagged (:<)  (orderingLT given)
    , tagged (:==) (orderingEQ given)
    , tagged (:>)  (orderingGT given)
    ]
  where
    tagged rel cons = [(con, rel) | con <- cons]

-- | The predicate type carried by a constraint's evidence.
deconsPred :: Ct -> Type
deconsPred ct = ctEvPred (ctEvidence ct)

-- | Translates a type into a proposition. Rules are tried from top to
-- bottom; the translation's 'parsePred' is the last resort.
toPresburgerPred :: Given Translation => Type -> Machine Prop
toPresburgerPred (TyConApp con [t1, t2])
  | con `elem` (natLeq given ++ natLeqBool given)
  = (:<=) <$> toPresburgerExp t1 <*> toPresburgerExp t2
toPresburgerPred ty = case splitTyConApp_maybe ty of
  -- Promoted boolean literals.
  Just (con, [])
    | con `elem` trueData given  -> return PTrue
    | con `elem` falseData given -> return PFalse
  -- Built-in equality predicates.
  _ | isEqPred ty -> toPresburgerPredTree (classifyPredType ty)
  -- l ~ r
  Just (con, [l, r])
    | con `elem` (tyEq given ++ tyEqBool given) ->
        toPresburgerPredTree (EqPred NomEq l r)
  -- l (:~: {k}) r
  Just (con, [_k, l, r])
    | con `elem` tyEqWitness given ->
        toPresburgerPredTree (EqPred NomEq l r)
  Just (con, [l])
    -- Empty l => ...
    | con `elem` isEmpty given -> Not <$> toPresburgerPred l
    -- IsTrue l =>
    | con `elem` isTrue given  -> toPresburgerPred l
  _ -> parsePred given toPresburgerExp ty

-- | Splits a constructor application and keeps only its last two
-- arguments; fails when the type is not a constructor application or has
-- fewer than two arguments.
splitTyConAppLastBin :: Type -> Maybe (TyCon, [Type])
splitTyConAppLastBin t =
  case splitTyConApp_maybe t of
    Just (con, ts)
      | let n = length ts
      , n >= 2 -> Just (con, drop (n - 2) ts)
    _ -> Nothing

-- | Translates a classified predicate into a proposition. Clauses are
-- tried in order; predicates matching none of them fail with 'mzero'.
toPresburgerPredTree :: Given Translation => PredTree -> Machine Prop
toPresburgerPredTree (EqPred NomEq p false) -- P ~ 'False <=> Not P ~ 'True
  | maybe False (`elem` falseData given) $ tyConAppTyCon_maybe false =
    Not <$> toPresburgerPredTree (EqPred NomEq p (mkTyConTy promotedTrueDataCon))
toPresburgerPredTree (EqPred NomEq p b)  -- (n :<=? m) ~ 'True
  | maybe False (`elem` trueData given) $ tyConAppTyCon_maybe b
  , Just (con, [t1, t2]) <- splitTyConAppLastBin p
  , con `elem` natLeqBool given = (:<=) <$> toPresburgerExp t1  <*> toPresburgerExp t2
toPresburgerPredTree (EqPred NomEq p q)  -- (p :: Bool) ~ (q :: Bool)
    | typeKind p `eqType` mkTyConTy promotedBoolTyCon = do
      lift $ lift $ tcPluginTrace "EQBOOL:" $ ppr (p, q)
      (<=>) <$> toPresburgerPred p
            <*> toPresburgerPred q
toPresburgerPredTree (EqPred NomEq n m)  -- (n :: Nat) ~ (m :: Nat)
  | typeKind n `eqType` typeNatKind =
    (:==) <$> toPresburgerExp n
          <*> toPresburgerExp m
toPresburgerPredTree (EqPred _ t1 t2) -- CmpNat a b ~ CmpNat c d
  | Just (con,  lastTwo -> [a, b]) <- splitTyConAppLastBin t1
  , Just (con', lastTwo -> [c, d]) <- splitTyConAppLastBin t2
  , con `elem` natCompare given, con' `elem` natCompare given
  = do
      ea <- toPresburgerExp a
      eb <- toPresburgerExp b
      ec <- toPresburgerExp c
      ed <- toPresburgerExp d
      -- A comparison has three possible outcomes, so two comparisons agree
      -- only if both the "less than" and the "greater than" outcomes agree;
      -- checking "less than" alone would identify "equal" with "greater".
      return $ ((ea :< eb) <=> (ec :< ed)) :&& ((ea :> eb) <=> (ec :> ed))
toPresburgerPredTree (EqPred NomEq t1 t2) -- CmpNat a b ~ x
  | Just (con, lastTwo -> [a, b]) <- splitTyConAppLastBin t1
  , con `elem` natCompare given
  , Just cmp <- tyConAppTyCon_maybe t2 =
    MaybeT (return $ lookup cmp orderingDic)
       <*> toPresburgerExp a
       <*> toPresburgerExp b
toPresburgerPredTree (EqPred NomEq t1 t2) -- x ~ CmpNat a b
  | Just (con, lastTwo -> [a, b]) <- splitTyConAppLastBin t2
  , con `elem` natCompare given
  , Just cmp <- tyConAppTyCon_maybe t1 =
    MaybeT (return $ lookup cmp orderingDic)
       <*> toPresburgerExp a
       <*> toPresburgerExp b
toPresburgerPredTree (ClassPred con ts)
  -- (n :: Nat) (<=| < | > | >= | == | /=) (m :: Nat)
  | let n = length ts, n >= 2
  , [t1, t2] <- drop (n - 2) ts
  , typeKind t1 `eqType` typeNatKind
  , typeKind t2 `eqType` typeNatKind =
    let p = lookup (classTyCon con) binPropDic
    in MaybeT (return p) <*> toPresburgerExp t1 <*> toPresburgerExp t2
toPresburgerPredTree _ = mzero

-- | Maps every binary relation constructor of the translation to the
-- relation it denotes. Entries are listed in the order: less-than,
-- greater-than, less-or-equal, greater-or-equal, equal, not-equal.
binPropDic :: Given Translation => [(TyCon, Expr -> Expr -> Prop)]
binPropDic =
  concat
    [ tagged (:<)  (natLt given ++ natLtBool given)
    , tagged (:>)  (natGt given ++ natGtBool given)
    , tagged (:<=) (natLeq given ++ natLeqBool given)
    , tagged (:>=) (natGeq given ++ natGeqBool given)
    , tagged (:==) (tyEq given ++ tyEqBool given)
    , tagged (:/=) (tyNeqBool given)
    ]
  where
    tagged rel cons = [(con, rel) | con <- cons]

-- | Translates a type into an expression.
--
-- Type variables become solver variables, numeric literals become
-- constants, and applications of the translation's arithmetic
-- constructors become the matching operations (multiplication only when
-- one operand simplifies to a literal). Any other constructor application
-- is replaced by a variable obtained from 'toVar'; other types are first
-- offered to the translation's 'parseExpr'.
toPresburgerExp :: Given Translation => Type -> Machine Expr
toPresburgerExp ty = case ty of
  TyVarTy v          -> return (asVar v)
  TyConApp tc ts     -> arith tc ts <|> opaque ty
  LitTy (NumTyLit n) -> return (K n)
  LitTy _            -> mzero
  _                  -> parseExpr given ty <|> opaque ty
  where
    -- Solver variable named after the unique of a type variable.
    asVar v = Var (toName (getKey (getUnique v)))

    -- Stand-in variable for a type that has no arithmetic reading.
    opaque t = asVar <$> toVar t

    arith tc ts
      | [tl, tr] <- ts
      , tc `elem` natTimes given = scaled (simpleExp tl) (simpleExp tr)
      | otherwise =
          asum (map (binary (:+)) (natPlus given)
                  ++ map (binary (:-)) (natMinus given))
      where
        -- Applies when the constructor matches and, after dropping any
        -- leading arguments, exactly two operands remain.
        binary op con
          | tc == con
          , [tl, tr] <- lastTwo ts =
              op <$> toPresburgerExp tl <*> toPresburgerExp tr
          | otherwise = mzero

    -- Multiplication is only expressible with a literal coefficient.
    scaled (LitTy (NumTyLit n)) (LitTy (NumTyLit m)) = return (K (n * m))
    scaled (LitTy (NumTyLit n)) x = (n :*) <$> toPresburgerExp x
    scaled x (LitTy (NumTyLit n)) = (n :*) <$> toPresburgerExp x
    scaled _ _ = mzero


-- simplTypeCmp :: Type -> Type

-- | The last two elements of a list; lists with fewer than two elements
-- are returned unchanged.
lastTwo :: [a] -> [a]
lastTwo xs = drop (length xs - 2) xs

-- | Simplifies a type by folding applications of the translation's
-- arithmetic constructors whose last two arguments are both numeric
-- literals into a single literal. Sub-terms are simplified recursively.
--
-- Arguments preceding the last two (for instance kind arguments) are kept
-- when an application is rebuilt, so the result has the same number of
-- arguments as the input.
simpleExp :: Given Translation => Type -> Type
simpleExp (AppTy t1 t2) = AppTy (simpleExp t1) (simpleExp t2)
simpleExp (FunTy t1 t2) = FunTy (simpleExp t1) (simpleExp t2)
simpleExp (ForAllTy t1 t2) = ForAllTy t1 (simpleExp t2)
simpleExp (TyConApp tc args) = fromMaybe (TyConApp tc (map simpleExp args)) $
  asum (map simpler
        $ [(c, (+)) | c <- natPlus given] ++
          [(c, (-)) | c <- natMinus given] ++
          [(c, (*)) | c <- natTimes given] ++
          [(c, (^)) | c <- natExp given]
      )
  where
    -- The operands are the last two arguments; with fewer than two
    -- arguments the prefix is empty and all arguments count as operands.
    (prefix, operands) = splitAt (length args - 2) args
    simpler (con, op)
      | con == tc, [tl, tr] <- map simpleExp operands =
        Just $
        case (tl, tr) of
          (LitTy (NumTyLit n), LitTy (NumTyLit m)) -> LitTy (NumTyLit (op n m))
          _ -> TyConApp con (map simpleExp prefix ++ [tl, tr])
      | otherwise = Nothing
simpleExp t = t

-- | State of a translation run: the types that were replaced by a
-- stand-in type variable (see 'toVar'), keyed by the 'TypeEq' wrapper.
type ParseEnv = M.Map TypeEq TyVar

-- | The translation monad: may fail ('MaybeT'), carries a 'ParseEnv' as
-- state, and runs inside 'TcPluginM'.
type Machine = MaybeT (StateT ParseEnv TcPluginM)

-- | Runs a translation starting from an empty environment. For every
-- stand-in variable recorded during the run, a wanted nominal equality
-- between the variable and the type it replaces is created (its result is
-- discarded). Returns 'Nothing' when the translation failed.
--
-- NOTE(review): the constraint location passed to 'newWanted' is
-- 'undefined' — confirm that it is never forced.
runMachine :: Machine a -> TcPluginM (Maybe a)
runMachine act = do
  (result, env) <- runStateT (runMaybeT act) M.empty
  mapM_ emitEquality (M.toList env)
  pure result
  where
    emitEquality (TypeEq ty, var) =
      newWanted undefined (mkPrimEqPredRole Nominal (mkTyVarTy var) ty)

-- | Returns the stand-in type variable for a type, creating a fresh
-- flexible variable of the type's kind (and remembering it in the
-- environment) the first time the type is seen.
toVar :: Type -> Machine TyVar
toVar ty = do
  known <- gets (M.lookup key)
  case known of
    Just v  -> return v
    Nothing -> do
      fresh <- lift (lift (newFlexiTyVar (typeKind ty)))
      modify (M.insert key fresh)
      return fresh
  where
    key = TypeEq ty