{-# LANGUAGE FlexibleContexts   #-}
{-# LANGUAGE TupleSections      #-}
{-# LANGUAGE OverloadedStrings  #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveTraversable  #-}
{-# LANGUAGE DeriveGeneric      #-}
{-# LANGUAGE DerivingVia        #-}
{-# LANGUAGE NamedFieldPuns     #-}

{-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-}

module Language.Haskell.Liquid.Types.Bounds (

    Bound(..),

    RBound, RRBound, RRBoundV,

    RBEnv, RRBEnv, RRBEnvV,

    makeBound,
    emapBoundM,
    mapBoundTy

    ) where

import Prelude hiding (error)
import Text.PrettyPrint.HughesPJ
import GHC.Generics
import Data.List (partition)
import Data.Maybe
import Data.Hashable
import Data.Bifunctor as Bifunctor
import Data.Data
import qualified Data.Binary         as B
import Data.Traversable
import qualified Data.HashMap.Strict as M

import qualified Language.Fixpoint.Types as F
import Language.Haskell.Liquid.Types.Errors
import Language.Haskell.Liquid.Types.RefType
import Language.Haskell.Liquid.Types.RType
import Language.Haskell.Liquid.Types.RTypeOp
import Language.Haskell.Liquid.Types.Types


-- | A bound declaration, parameterised over the type @t@ used for its
-- type variables / binder types and the type @e@ of its body.
data Bound t e = Bound
  { bname   :: LocSymbol         -- ^ Name of the bound
  , tyvars  :: [t]               -- ^ Type variables appearing in the bound
  , bparams :: [(LocSymbol, t)]  -- ^ Abstract-refinement parameters, each paired with its type
  , bargs   :: [(LocSymbol, t)]  -- ^ Value-variable arguments, each paired with its type
  , bbody   :: e                 -- ^ Body expression of the bound
  }
  deriving stock (Data, Generic, Functor, Foldable, Traversable)
  deriving B.Binary via Generically (Bound t e)

-- | Bound whose type variables are sorts and whose body names are 'F.Symbol's.
type RBound        = RRBound RSort
-- | Bound over type variables of type @tv@ with a 'F.Symbol'-named body.
type RRBound tv    = RRBoundV F.Symbol tv
-- | Most general form: the body is an expression whose names have type @v@.
type RRBoundV v tv = Bound tv (F.ExprV v)
-- | Environments of bounds keyed by 'LocSymbol' (presumably the bound's
-- 'bname' — confirm against the code that builds them).
type RBEnv         = M.HashMap LocSymbol RBound
type RRBEnv tv     = M.HashMap LocSymbol (RRBound tv)
type RRBEnvV v tv  = M.HashMap LocSymbol (RRBoundV v tv)

-- | Monadically transform the types and the body of a 'Bound'.
--
-- @f@ is applied to every type in the bound. It visits the 'tyvars' first,
-- always with an empty scope. It then visits the type of each 'bparams'
-- entry, followed by the type of each 'bargs' entry. @g@ is applied to the
-- 'bbody'.
--
-- The @[F.Symbol]@ argument passed to each callback lists the parameter and
-- argument names bound so far, most recent first. A binder's own type is
-- mapped before its name is in scope. The body sees every name from
-- 'bparams' and 'bargs'.
emapBoundM
  :: Monad m
  => ([F.Symbol] -> t0 -> m t1)
  -> ([F.Symbol] -> e0 -> m e1)
  -> Bound t0 e0
  -> m (Bound t1 e1)
emapBoundM f g b = do
    tyvars <- mapM (f []) $ tyvars b
    (e1, bparams) <- mapAccumM bindSym [] (bparams b)
    (e2, bargs) <- mapAccumM bindSym e1 (bargs b)
    bbody <- g e2 (bbody b)
    return b{tyvars, bparams, bargs, bbody}
  where
    -- Map the binder's type under the current scope, then extend the scope
    -- with the binder's name so that later binders and the body can see it.
    bindSym env (x, t) = (val x : env,) . (x,) <$> f env t

-- | Apply a function to every type in a 'Bound': the 'tyvars' and the types
-- paired with each 'bparams' and 'bargs' entry. The name and body are kept.
--
-- This is exactly the 'Bifunctor' 'Bifunctor.first' of 'Bound'. Reusing it
-- avoids duplicating the traversal and needs no @RecordWildCards@.
mapBoundTy :: (t0 -> t1) -> Bound t0 e -> Bound t1 e
mapBoundTy = Bifunctor.first

-- | Bounds are hashed by name only, consistent with the 'Eq' instance.
instance Hashable (Bound t e) where
  hashWithSalt salt bnd = hashWithSalt salt (bname bnd)

-- | Two bounds are equal exactly when their names ('bname') are equal;
-- the remaining fields are not compared.
instance Eq (Bound t e) where
  (==) lhs rhs = bname lhs == bname rhs

-- | 'Show' delegates to the pretty-printer via 'showpp'.
instance (PPrint e, PPrint t) => Show (Bound t e) where
  show bnd = showpp bnd


-- | Render a bound as
-- @bound NAME forall TYVARS . PARAMS = \\ ARGS -> BODY@.
-- The lambda prefix is replaced by @""@ when there are no value arguments.
instance (PPrint e, PPrint t) => (PPrint (Bound t e)) where
  pprintTidy k (Bound bndName bndVars bndParams bndArgs bndBody) =
    heading <+> "=" <+> lambdaPart <+> pprintTidy k bndBody
    where
      -- Everything up to (but excluding) the "=" sign.
      heading = "bound" <+> pprintTidy k bndName
                <+> "forall" <+> pprintTidy k bndVars <+> "."
                <+> pprintTidy k (map fst bndParams)

      -- Only the names of the value arguments are shown.
      lambdaPart = case map fst bndArgs of
        []      -> ""
        argSyms -> "\\" <+> pprintTidy k argSyms <+> "->"

-- | 'first' maps every type in the bound: the 'tyvars' and the types paired
-- with the 'bparams' and 'bargs'. 'second' maps the body.
instance Bifunctor Bound where
  second = fmap
  first f bnd@Bound{tyvars = vs, bparams = ps, bargs = xs} =
    bnd { tyvars  = map f vs
        , bparams = map (fmap f) ps
        , bargs   = map (fmap f) xs
        }

-- | Instantiate a bound at concrete types and predicate symbols.
--
-- The @tys@ replace the bound's type variables positionally, and the
-- @preds@ are paired positionally with the bound's abstract-refinement
-- parameters. The result is an 'RRTy' with an 'OCons' obligation whose
-- environment is built by 'makeBoundType'. It is partially applied and still
-- awaits the final 'RRType'.
makeBound :: (PPrint r, UReftable r, SubsTy RTyVar (RType RTyCon RTyVar ()) r)
          => RRBound RSort -> [RRType r] -> [F.Symbol] -> RRType r -> RRType r
makeBound (Bound _ boundTvs boundParams boundArgs boundBody) tys preds
  = RRTy (map substBind rawBinds) mempty OCons
  where
    -- Pair each abstract-refinement parameter name with the symbol supplied
    -- for it in @preds@.
    predEnv = zip (map (val . fst) boundParams) preds

    -- Flatten @p1 => p2 => ... => c@ into @[c, ..., p2, p1]@: the conclusion
    -- comes first, followed by the premises in reverse order.
    flattenImps acc (F.PImp lhs rhs) = flattenImps (lhs : acc) rhs
    flattenImps acc conclusion       = conclusion : acc

    rawBinds = makeBoundType predEnv (flattenImps [] boundBody) boundArgs

    -- Substitution from the bound's type variables to the supplied types.
    -- Entries of 'boundTvs' that are not 'RVar's are skipped.
    tySubst = [ (tv, toRSort ty, ty) | (RVar tv _, ty) <- zip boundTvs tys ]

    substBind (x, ty) = (x, foldr subsTyVarMeet ty tySubst)

-- | Build the binder environment for an instantiated bound.
--
-- The first expression is the bound's conclusion and the rest are its
-- premises, in the order 'makeBound' passes them. The premises are split by
-- 'partitionPs' into abstract-refinement applications (grouped by their last
-- argument) and ordinary predicates.
--
-- Every value argument except the last becomes a binder with a trivially
-- true refinement, strengthened by the abstract premises about it. The last
-- argument becomes two 'F.dummySymbol' binders instead:
--
-- * one refined by the conjunction of the ordinary premises (plus its
--   abstract premises);
-- * one refined by the conclusion via 'makeRef'.
--
-- Panics if there are no arguments or no predicates at all.
makeBoundType :: (PPrint r, UReftable r)
              => [(F.Symbol, F.Symbol)]
              -> [F.Expr]
              -> [(LocSymbol, RSort)]
              -> [(F.Symbol, RRType r)]
makeBoundType penv (q:qs) xts = go xts
  where
    -- NV TODO: Turn this into a proper error
    go [] = panic Nothing "Bound with empty symbols"

    go [(x, t)]      = [(F.dummySymbol, tp t x), (F.dummySymbol, tq t x)]
    go ((x, t):xtss) = (val x, mkt t x) : go xtss

    -- Sort @t@ refined by predicate @p@ on @x@, strengthened with the
    -- abstract-refinement premises keyed by @x@ (none if there are none).
    refineArg t x p = ofRSort t `strengthen`
                        ofUReft (MkUReft (F.Reft (val x, p))
                                         (Pr $ M.findWithDefault [] (val x) ps))

    mkt t x = refineArg t x F.PTrue
    tp  t x = refineArg t x (F.pAnd rs)
    tq  t x = ofRSort t `strengthen` makeRef penv x q

    (ps, rs) = partitionPs penv qs


-- NV TODO: Turn this into a proper error
makeBoundType _ _ _           = panic Nothing "Bound with empty predicates"


-- | Split predicates into abstract-refinement applications and ordinary
-- predicates.
--
-- Applications (as recognised by 'isPApp') are converted with 'toUsedPVars'
-- and grouped into a map keyed by their last argument, with lists for the
-- same key concatenated. The ordinary predicates are returned unchanged, in
-- order.
partitionPs :: [(F.Symbol, F.Symbol)] -> [F.Expr] -> (M.HashMap F.Symbol [UsedPVar], [F.Expr])
partitionPs penv qs =
  let (pApps, others) = partition (isPApp penv) qs
      grouped         = M.fromListWith (++) (toUsedPVars penv <$> pApps)
  in  (grouped, others)

-- | Is the expression an application (possibly curried) whose head is a
-- variable listed as a key in @penv@?
isPApp :: [(F.Symbol, a)] -> F.Expr -> Bool
isPApp penv = headIsKnown
  where
    headIsKnown (F.EApp (F.EVar p) _) = any ((p ==) . fst) penv
    headIsKnown (F.EApp fun _)        = headIsKnown fun
    headIsKnown _                     = False

-- | Turn an abstract-refinement application into a singleton map entry.
--
-- The key is the variable in the outermost application's argument position
-- (the last argument); the value is the converted 'PVar'.
toUsedPVars :: [(F.Symbol, F.Symbol)] -> F.Expr -> (F.Symbol, [PVar ()])
toUsedPVars penv app = case app of
  F.EApp _ lastArg -> (keyOf lastArg, [toUsedPVar penv app])
  _                -> impossible Nothing "This cannot happen"
  where
    -- NV : TODO make this a better error
    keyOf (F.EVar x) = x
    keyOf e          = todo Nothing ("Bound fails in " ++ show e)

-- | Convert an application @p a1 ... an x@ of an abstract refinement into a
-- 'PVar'.
--
-- The name is looked up for @p@ in @penv@. The last argument @x@ must be a
-- variable and becomes the 'PVar' argument. The remaining arguments are
-- recorded paired with '()' and 'F.dummySymbol'.
--
-- Malformed input fails through 'impossible' with a message naming the
-- offending expression, rather than through an irrefutable-pattern failure.
toUsedPVar :: [(F.Symbol, F.Symbol)] -> F.Expr -> PVar ()
toUsedPVar penv ee@(F.EApp _ _)
  = PV q () x [ ((), F.dummySymbol, a) | a <- init args ]
  where
    -- 'args' is non-empty because 'ee' is an application.
    (hd, args) = F.splitEApp ee

    p = case hd of
          F.EVar s -> s
          _        -> impossible Nothing
                        ("toUsedPVar: application head is not a variable in " ++ show ee)

    q = fromMaybe
          (impossible Nothing ("toUsedPVar: unknown abstract refinement " ++ show p))
          (lookup p penv)

    x = case last args of
          F.EVar s -> s
          e        -> impossible Nothing
                        ("toUsedPVar: last argument is not a variable: " ++ show e)

toUsedPVar _ _ = impossible Nothing "This cannot happen"

-- | Build the refinement on @v@ for the conclusion of a bound.
--
-- The conclusion may mix concrete predicates and abstract-refinement
-- applications:
--
-- * a conjunction is split: abstract applications become 'PVar's, and the
--   remaining conjuncts form the concrete refinement;
-- * a lone abstract application yields a trivially true concrete refinement
--   plus that 'PVar';
-- * anything else is used directly as the concrete refinement.
makeRef :: (UReftable r) => [(F.Symbol, F.Symbol)] -> LocSymbol -> F.Expr -> r
makeRef penv v expr = case expr of
  F.PAnd conjuncts ->
    let (abstractPs, concretePs) = partition (isPApp penv) conjuncts
    in  refinedWith (F.pAnd concretePs) (toUsedPVar penv <$> abstractPs)
  _ | isPApp penv expr -> refinedWith F.PTrue [toUsedPVar penv expr]
    | otherwise        -> ofReft (F.Reft (val v, expr))
  where
    refinedWith p pvs = ofUReft (MkUReft (F.Reft (val v, p)) (Pr pvs))
