{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskellQuotes #-}

{-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-}

module Language.Haskell.Liquid.WiredIn
       ( wiredTyCons
       , wiredDataCons
       , wiredSortedSyms

       , charDataCon

       -- * Constants for automatic proofs
       , dictionaryVar
       , dictionaryTyVar
       , dictionaryBind
       , proofTyConName
       , combineProofsName

       -- * Built in symbols
       , isWiredIn
       , isWiredInName
       , dcPrefix

       -- * Deriving classes
       , isDerivedInstance
       , derivingClasses
       ) where

import Prelude                                hiding (error)

-- import Language.Fixpoint.Misc           (mapSnd)
import Language.Haskell.Liquid.GHC.Misc
import qualified Liquid.GHC.API as Ghc
import Liquid.GHC.API (Var, Arity, TyVar, Bind(..), Boxity(..), Expr(..), ForAllTyFlag(Required))
import Language.Haskell.Liquid.Types.Errors
import Language.Haskell.Liquid.Types.Names
import Language.Haskell.Liquid.Types.RType
import Language.Haskell.Liquid.Types.Types
import Language.Haskell.Liquid.Types.RefType
import Language.Haskell.Liquid.Types.Variance
import Language.Haskell.Liquid.Types.PredType

-- import Language.Fixpoint.Types hiding (panic)
import qualified Language.Fixpoint.Types.Config as F
import qualified Language.Fixpoint.Smt.Theories as F
import qualified Language.Fixpoint.Types as F
import           Data.Bifunctor (first)
import qualified Data.HashSet as S
import           Data.Maybe

import Language.Haskell.Liquid.GHC.TypeRep ()

-- | Horrible hack to support hardwired symbols like
--      `head`, `tail`, `fst`, `snd`, `len`
--   and other LH-generated symbols that
--   *do not* correspond to GHC Vars and
--   *should not* be resolved to GHC Vars.
--
--   A located symbol counts as wired in when any of these holds
--   (checked in this order):
--
--   * its span starts and ends at line 1, column 1 ('isWiredInLoc');
--   * its name is in the fixed set 'wiredInNames' ('isWiredInName');
--   * its name starts with an LH-generated prefix ('isWiredInShape').
isWiredIn :: F.LocSymbol -> Bool
isWiredIn x = any ($ x) [isWiredInLoc, isWiredInName . val, isWiredInShape]

-- | Does the symbol's span both start and end at line 1, column 1?
--   NOTE(review): presumably this is the dummy location given to symbols
--   that LH synthesizes itself — confirm.
isWiredInLoc :: F.LocSymbol -> Bool
isWiredInLoc sym = all (== F.safePos 1) [startLine, startCol, endLine, endCol]
  where
    (startLine, startCol) = lineAndCol (loc sym)
    (endLine  , endCol  ) = lineAndCol (locE sym)
    lineAndCol p          = let (_, l, col) = F.sourcePosElts p in (l, col)

-- | Is the symbol one of the fixed wired-in names listed in 'wiredInNames'?
isWiredInName :: F.Symbol -> Bool
isWiredInName = (`S.member` wiredInNames)

-- | Plain names that are always treated as wired in (see 'isWiredInName').
wiredInNames :: S.HashSet F.Symbol
wiredInNames =
  S.fromList
    [ "head"
    , "tail"
    , "fst"
    , "snd"
    , "len"
    ]

-- | Does the symbol start with one of the prefixes used for LH-generated
--   names: 'F.anfPrefix', 'F.tempPrefix' or 'dcPrefix'?
isWiredInShape :: F.LocSymbol -> Bool
isWiredInShape lsym = any hasPrefix [F.anfPrefix, F.tempPrefix, dcPrefix]
  where
    hasPrefix pre = pre `F.isPrefixOfSym` val lsym

-- | Symbol prefix (@lqdc@) that 'isWiredInShape' treats as marking an
--   LH-generated name.
--   NOTE(review): the @dc@ presumably stands for "data constructor" — confirm.
dcPrefix :: F.Symbol
dcPrefix =
  "lqdc"

-- | Sorted symbols wired into the logic, in this order:
--
--   * 'selfSymbol' with sort @F.FAbs 1 (F.FVar 0)@;
--   * @pappSym n@ with sort @pappSort n@, for @n@ from 1 to 'pappArity';
--   * the SMT theory symbols of 'wiredTheorySortedSyms'.
wiredSortedSyms :: [(F.Symbol, F.Sort)]
wiredSortedSyms = concat [selfEntry, pappEntries, wiredTheorySortedSyms]
  where
    selfEntry   = [(selfSymbol, F.FAbs 1 (F.FVar 0))]
    pappEntries = map (\n -> (pappSym n, pappSort n)) [1 .. pappArity]

-- | Selected SMT theory symbols (maps, sets, bags, finite fields, string
--   length), each paired with the sort it has in the combined Z3 and cvc5
--   theory symbol tables.
--
--   Panics with @unknown symbol: ...@ if a listed name is found in neither
--   table.
wiredTheorySortedSyms :: [(F.Symbol, F.Sort)]
wiredTheorySortedSyms = map (\s -> (s, sortOf s)) wiredTheorySyms
  where
    -- Built once and shared by every lookup below.
    theoryEnv = F.theorySymbols F.Z3 <> F.theorySymbols F.Cvc5

    sortOf s = maybe (panic Nothing ("unknown symbol: " ++ show s))
                     F.tsSort
                     (F.lookupSEnv s theoryEnv)

    wiredTheorySyms =
      [ "Map_default"
      , "Map_select"
      , "Map_store"

      , "Set_cup"
      , "Set_cap"
      , "Set_dif"
      , "Set_sng"
      , "Set_emp"
      , "Set_empty"
      , "Set_mem"
      , "Set_sub"
      , "Set_add"
      , "Set_com"
      , "Set_card"

      , "Bag_count"
      , "Bag_empty"
      , "Bag_inter_min"
      , "Bag_sng"
      , "Bag_sub"
      , "Bag_union"
      , "Bag_union_max"

      , "FF_val"
      , "FF_add"
      , "FF_mul"

      , "strLen"
      ]

--------------------------------------------------------------------------------
-- | LH Dictionary Variables ---------------------------------------------------
--------------------------------------------------------------------------------

-- Placeholder variable named "tmp_dictionary_var" whose type binds
-- 'dictionaryTyVar' with a 'Required' (visible) binder and returns that same
-- type variable.
dictionaryVar :: Var
dictionaryVar = stringVar "tmp_dictionary_var" dictTy
  where
    binder = Ghc.Bndr dictionaryTyVar Required
    dictTy = Ghc.ForAllTy binder (Ghc.TyVarTy dictionaryTyVar)

-- | Type variable named @da@, quantified in the type of 'dictionaryVar'.
dictionaryTyVar :: TyVar
dictionaryTyVar =
  stringTyVar "da"

-- | Recursive Core binding of 'dictionaryVar' to a type lambda over
--   'dictionaryTyVar' whose body applies 'dictionaryVar' to that type
--   variable, i.e. @dictionaryVar = \\\@da -> dictionaryVar \@da@.
dictionaryBind :: Bind Var
dictionaryBind = Rec [(dictionaryVar, rhs)]
  where
    tyArg = Type (Ghc.TyVarTy dictionaryTyVar)
    rhs   = Lam dictionaryTyVar (App (Var dictionaryVar) tyArg)

-----------------------------------------------------------------------
-- | LH Primitive TyCons ----------------------------------------------
-----------------------------------------------------------------------


-- Name of the proof-combining function, @combineProofs@, as a plain String.
combineProofsName :: String
combineProofsName =
  "combineProofs"

-- | Name of the @Proof@ type constructor, as a fixpoint symbol.
proofTyConName :: F.Symbol
proofTyConName =
  "Proof"

--------------------------------------------------------------------------------
-- | Predicate Types for WiredIns ----------------------------------------------
--------------------------------------------------------------------------------

-- Largest boxed-tuple arity for which 'wiredTyDataCons' builds a wired-in
-- TyCon/DataCon pair (arities 2 up to and including this value).
maxArity :: Arity
maxArity =
  7

-- | The wired-in type constructors (the TyCon half of 'wiredTyDataCons').
wiredTyCons :: [TyConP]
wiredTyCons = let (tcs, _) = wiredTyDataCons in tcs

-- | The wired-in data constructors (the DataCon half of 'wiredTyDataCons').
wiredDataCons :: [Located DataConP]
wiredDataCons = dcs
  where
    (_, dcs) = wiredTyDataCons

-- | All wired-in TyCons and DataCons: those of the list type followed by
--   those of boxed tuples of arity 2 up to 'maxArity'. Every DataCon is
--   wrapped with a dummy location.
wiredTyDataCons :: ([TyConP] , [Located DataConP])
wiredTyDataCons = (concatMap fst groups, concatMap (map dummyLoc . snd) groups)
  where
    groups = listTyDataCons : [ tupleTyDataCons arity | arity <- [2 .. maxArity] ]

-- | Wired-in description of 'Ghc.charDataCon'. It has no type variables or
--   abstract refinements and a single field, generated logic name @charX@,
--   whose type, like the result type, is the unrefined application of
--   'Ghc.charTyCon'.
charDataCon :: Located DataConP
charDataCon = dummyLoc dcp
  where
    pos    = F.dummyPos "LH.Bare.charTyDataCons"
    charTy = rApp Ghc.charTyCon [] [] mempty
    fields = [(makeGeneratedLogicLHName "charX", charTy)]
    dcp    = DataConP pos Ghc.charDataCon [] [] [] fields charTy False wiredInName pos

-- | Wired-in LiquidHaskell description of the list type constructor and of
--   its two data constructors, 'Ghc.nilDataCon' and 'Ghc.consDataCon'.
--
--   * The TyCon has one type variable and one abstract refinement @p@, both
--     marked 'Covariant', and uses @GHC.Types_LHAssumptions.len@ as its size
--     function.
--   * In the cons constructor, the tail's element type is refined by @p@
--     instantiated at the head binder. So @p@ relates the head to every
--     element of the tail.
--   * Constructor fields are listed tail first, then head.
listTyDataCons :: ([TyConP] , [DataConP])
listTyDataCons   = ( [TyConP l0 c [RTV tyv] [p] [Covariant] [Covariant] (Just fsize)]
                   , [DataConP l0 Ghc.nilDataCon  [RTV tyv] [p] [] []    lt False wiredInName l0
                   ,  DataConP l0 Ghc.consDataCon [RTV tyv] [p] [] cargs lt False wiredInName l0])
    where
      l0         = F.dummyPos "LH.Bare.listTyDataCons"
      c          = Ghc.listTyCon
      -- The list TyCon has exactly one type parameter. Match it explicitly
      -- so that a violated assumption produces a descriptive panic rather
      -- than an anonymous irrefutable-pattern failure.
      tyv        = case tyConTyVarsDef c of
                     [v] -> v
                     vs  -> panic Nothing ( "listTyDataCons: expected exactly one type variable "
                                         ++ "for the list TyCon, found " ++ show (length vs) )
      t          = rVar tyv :: RSort
      fld        = "fldList"
      xHead      = "head"
      xTail      = "tail"
      -- abstract refinement over the element type, parameterised by 'fld'
      p          = PV "p" t (F.vv Nothing) [(t, fld, F.EVar fld)]
      -- the same refinement with its argument instantiated to the head binder
      px         = pdVarReft $ PV "p" t (F.vv Nothing) [(t, fld, F.EVar xHead)]
      lt         = rApp c [xt] [rPropP [] $ pdVarReft p] mempty
      xt         = rVar tyv
      -- tail type: a list whose elements are refined by @p head@
      xst        = rApp c [RVar (RTV tyv) px] [rPropP [] $ pdVarReft p] mempty
      cargs      = map (first makeGeneratedLogicLHName) [(xTail, xst), (xHead, xt)]
      fsize      = SymSizeFun (dummyLoc "GHC.Types_LHAssumptions.len")

-- | Symbol @WiredIn@, passed in the same argument position of every
--   'DataConP' constructed in this module.
wiredInName :: F.Symbol
wiredInName =
  "WiredIn"

-- | Wired-in LiquidHaskell description of the boxed @n@-tuple type
--   constructor and its single data constructor.
--
--   The tuple gets @n-1@ abstract refinements @p2 .. pn@. The refinement on
--   component @i@ is parameterised by all earlier components (see 'mkps').
--   All type variables and all abstract refinements are marked 'Covariant'.
--
--   This module calls it with @n@ in @[2 .. maxArity]@. The partial patterns
--   below (@tv:tvs@, @ta:ts@, @x1:xs@) assume at least one type variable and
--   one component.
tupleTyDataCons :: Int -> ([TyConP] , [DataConP])
tupleTyDataCons n = ( [TyConP   l0 c  (RTV <$> tyvs) ps tyvarinfo pdvarinfo Nothing]
                    , [DataConP l0 dc (RTV <$> tyvs) ps []  cargs  lt False wiredInName l0])
  where
    -- variance of each of the n type variables
    tyvarinfo     = replicate n     Covariant
    -- variance of each of the n-1 abstract refinements
    pdvarinfo     = replicate (n-1) Covariant
    l0            = F.dummyPos "LH.Bare.tupleTyDataCons"
    c             = Ghc.tupleTyCon   Boxed n
    dc            = Ghc.tupleDataCon Boxed n
    tyvs@(tv:tvs) = tyConTyVarsDef c
    (ta:ts)       = (rVar <$> tyvs) :: [RSort]
    -- field symbols fld_Tuple1 .. fld_Tuplen used as abstract-refinement arguments
    flds          = mks "fld_Tuple"
    fld           = "fld_Tuple"
    -- binder names for the data constructor's n fields
    x1:xs         = mks ("x_Tuple" ++ show n)
    -- abstract refinements of the TyCon, with arguments named by field symbols
    ps            = mkps pnames (ta:ts) ((fld, F.EVar fld) : zip flds (F.EVar <$> flds))
    ups           = uPVar <$> ps
    -- the same refinements with arguments instantiated to the constructor binders
    pxs           = mkps pnames (ta:ts) ((fld, F.EVar x1) : zip flds (F.EVar <$> xs))
    lt            = rApp c (rVar <$> tyvs) (rPropP [] . pdVarReft <$> ups) mempty
    -- components 2..n: type variable refined by its instantiated abstract refinement
    xts           = zipWith (\v p -> RVar (RTV v) (pdVarReft p)) tvs pxs
    -- constructor fields, listed last component first
    cargs         = map (first makeGeneratedLogicLHName) $ reverse $ (x1, rVar tv) : zip xs xts
    -- abstract refinement names p2 .. pn
    pnames        = mks_ "p"
    -- x1 .. xn
    mks  x        = (\i -> F.symbol (x++ show i)) <$> [1..n]
    -- x2 .. xn
    mks_ x        = (\i -> F.symbol (x++ show i)) <$> [2..n]


-- | Build a chain of abstract refinements ('PVar's), one per name.
--
--   The head of the type list and the head of the @(field, expr)@ list form
--   the seed argument. Each remaining name, paired with the next type and
--   @(field, expr)@, yields a 'PVar' whose arguments are all previously seen
--   @(type, field, expr)@ triples, most recent first. The result is in the
--   same order as the names.
--
--   Panics if either list is empty. 'mkps_' panics if the remaining types
--   or pairs run out before the names do.
mkps :: [F.Symbol]
     -> [t] -> [(F.Symbol, F.Expr)] -> [PVar t]
mkps names tys fxs =
  case (tys, fxs) of
    (t0 : restTys, (f0, e0) : restFxs) ->
      reverse (mkps_ names restTys restFxs [(t0, f0, e0)] [])
    _ ->
      panic Nothing "Bare : mkps"

-- | Worker for 'mkps'. For each name it consumes one type and one
--   @(field, expr)@ pair. It emits a 'PVar' over the arguments accumulated
--   so far, then pushes the current @(type, field, expr)@ onto those
--   arguments. Emitted 'PVar's are accumulated in reverse order.
--
--   Stops as soon as the names run out. Panics if the types or pairs run out
--   first.
mkps_ :: [F.Symbol]
      -> [t]
      -> [(F.Symbol, F.Expr)]
      -> [(t, F.Symbol, F.Expr)]
      -> [PVar t]
      -> [PVar t]
mkps_ names tys fxs args acc =
  case names of
    [] -> acc
    nm : nms ->
      case (tys, fxs) of
        (ty : tys', (fld, e) : fxs') ->
          let pv   = PV nm ty (F.vv Nothing) args
              arg' = (ty, fld, e)
          in  mkps_ nms tys' fxs' (arg' : args) (pv : acc)
        _ ->
          panic Nothing "Bare : mkps_"


--------------------------------------------------------------------------------
-- | Is the instance's class one of 'derivingClasses'?
--
--   Only the class is inspected (by comparing its symbol against
--   'derivingClassesSet'), not how this particular instance was produced.
--   The result is passed through 'F.notracepp' with an @IS-DERIVED:@ label.
isDerivedInstance :: Ghc.ClsInst -> Bool
--------------------------------------------------------------------------------
isDerivedInstance inst =
  F.notracepp ("IS-DERIVED: " ++ F.showpp cls) (cls `S.member` derivingClassesSet)
  where
    cls = F.symbol (Ghc.is_cls inst)

-- | 'derivingClasses' converted to symbols, for membership tests.
derivingClassesSet :: S.HashSet F.Symbol
derivingClassesSet = S.fromList (F.symbol <$> derivingClasses)

-- | Names of the classes whose instances 'isDerivedInstance' accepts, each
--   rendered with 'show' on its Template Haskell 'Name' quote.
derivingClasses :: [String]
derivingClasses = map show
  [ ''Eq
  , ''Ord
  , ''Enum
  , ''Show
  , ''Read
  , ''Monad
  , ''Applicative
  , ''Functor
  , ''Foldable
  , ''Traversable
  , ''Fractional
  -- , "GHC.Enum.Bounded"
  -- , "GHC.Base.Monoid"
  ]
