{-# LANGUAGE LambdaCase
           , ViewPatterns
           #-}
-- | Functions that convert the value and function definitions of the GHC AST to corresponding elements in the Haskell-tools AST representation
module Language.Haskell.Tools.AST.FromGHC.Binds where

import ApiAnnotation as GHC (AnnKeywordId(..))
import Bag as GHC (bagToList)
import BasicTypes as GHC (FixityDirection(..), Fixity(..))
import BasicTypes as GHC
import HsBinds as GHC
import HsExpr as GHC
import HsPat as GHC (LPat)
import HsTypes as GHC (HsWildCardBndrs(..), HsImplicitBndrs(..))
import Outputable as GHC (Outputable(..), showSDocUnsafe)
import SrcLoc as GHC

import Control.Monad.Reader (Monad(..), mapM, asks)
import Data.Data (Data(..))
import Data.List

import Language.Haskell.Tools.AST.FromGHC.Exprs (trfExpr)
import Language.Haskell.Tools.AST.FromGHC.Monad
import Language.Haskell.Tools.AST.FromGHC.Names
import Language.Haskell.Tools.AST.FromGHC.Patterns (trfPattern)
import Language.Haskell.Tools.AST.FromGHC.Types (trfType)
import Language.Haskell.Tools.AST.FromGHC.Utils

import Language.Haskell.Tools.AST (Ann, AnnMaybeG, AnnListG, Dom, RangeStage)
import qualified Language.Haskell.Tools.AST as AST

-- | Converts a located GHC binding by running 'trfBind'' on its content through 'trfLocNoSema'.
trfBind :: TransformName n r => Located (HsBind n) -> Trf (Ann AST.UValueBind (Dom r) RangeStage)
trfBind bind = trfLocNoSema trfBind' bind
  
-- | Converts the content of a GHC binding into a value binding of the Haskell-tools AST.
--
-- Function bindings and pattern bindings are converted. Pattern synonym bindings and every
-- other kind of binding are rejected with an 'error'.
trfBind' :: TransformName n r => HsBind n -> Trf (AST.UValueBind (Dom r) RangeStage)
-- a function binding with one match, no arguments and one unguarded right-hand side
-- is a value binding (not a function)
trfBind' (FunBind { fun_id = bndName, fun_matches = MG { mg_alts = L _ [L _ (Match { m_pats = [], m_grhss = GRHSs [L _ (GRHS [] body)] (L _ whereBinds) })]} })
  = do boundVar <- copyAnnot AST.UVarPat (define (trfName bndName))
       rhs <- addEmptyScope $ addToScope whereBinds $ annLocNoSema rhsRange (AST.UUnguardedRhs <$> trfExpr body)
       localDefs <- addEmptyScope $ trfWhereLocalBinds bodyLoc whereBinds
       return (AST.USimpleBind boundVar rhs localDefs)
  where bodyLoc = getLoc body
        -- the range of the right-hand side combines the body with the equals token
        rhsRange = combineSrcSpans bodyLoc <$> tokenLoc AnnEqual
-- any other function binding: every alternative becomes a match of the function
trfBind' (FunBind { fun_id = bndName, fun_matches = MG { mg_alts = (unLoc -> alts) } })
  = fmap AST.UFunBind $ makeNonemptyIndentedList $ mapM (trfMatch funName) alts
  where funName = unLoc bndName
-- a pattern binding
trfBind' (PatBind { pat_lhs = boundPat, pat_rhs = GRHSs alts (unLoc -> whereBinds) })
  = do lhs <- trfPattern boundPat
       rhs <- addEmptyScope $ addToScope whereBinds $ trfRhss alts
       localDefs <- addEmptyScope $ trfWhereLocalBinds (collectLocs alts) whereBinds
       return (AST.USimpleBind lhs rhs localDefs)
trfBind' (PatSynBind _) = error "Pattern synonym bindings should be recognized on the declaration level"
trfBind' _ = error "Bindings generated by the compiler cannot be converted"

-- | Converts one located alternative of a function binding. The first argument is the name
-- of the function that the alternative belongs to; it is passed on to 'trfMatch''.
trfMatch :: TransformName n r => n -> Located (Match n (LHsExpr n)) -> Trf (Ann AST.UMatch (Dom r) RangeStage)
trfMatch = trfLocNoSema . trfMatch'

-- | Converts the content of one function alternative: its left-hand side, its right-hand
-- side(s) and its local bindings. The argument patterns are added to the scope while the
-- right-hand sides and the local bindings are converted; the local bindings are additionally
-- added to the scope of the right-hand sides.
trfMatch' :: TransformName n r => n -> Match n (LHsExpr n) -> Trf (AST.UMatch (Dom r) RangeStage)
trfMatch' funName (Match fixity args _typ (GRHSs alts (unLoc -> whereBinds)))
  -- TODO: add the optional typ to pats
  = do lhs <- trfMatchLhs funName fixity args
       rhs <- addToScope args $ addToScope whereBinds $ trfRhss alts
       localDefs <- addToScope args $ trfWhereLocalBinds (collectLocs alts) whereBinds
       return (AST.UMatch lhs rhs localDefs)

-- | Converts the left-hand side of a function alternative.
--
-- The name comes from the match fixity information when it is a 'FunBindMatch'; for a
-- 'NonFunBindMatch' the given name is used with an empty range at the start of the current
-- range. When the match is marked infix and has at least two arguments an infix left-hand
-- side is created, otherwise a normal (prefix) one.
trfMatchLhs :: TransformName n r => n -> MatchFixity n -> [LPat n] -> Trf (Ann AST.UMatchLhs (Dom r) RangeStage)
trfMatchLhs name fb pats = do
  idStart <- atTheStart
  idEnd <- atTheStart
  eqLoc <- tokenLoc AnnEqual
  barLoc <- tokenLoc AnnVbar
  let -- the left-hand side ends where the combined range of the '=' and '|' tokens starts
      lhsEnd = srcSpanStart (combineSrcSpans eqLoc barLoc)
      (funName, infixDef) = case fb of
        NonFunBindMatch -> (L (mkSrcSpan idStart idEnd) name, False)
        FunBindMatch declName isInf -> (declName, isInf)
      restArgs xs = makeList " " (pure lhsEnd) (pure xs)
  args <- mapM trfPattern pats
  annLocNoSema (flip mkSrcSpan lhsEnd <$> atTheStart) $
    case (args, infixDef) of
      (lhsArg : rhsArg : rest, True) -> do
        op <- define (trfOperator funName)
        AST.UInfixLhs lhsArg op rhsArg <$> restArgs rest
      _ -> do
        fun <- define (trfName funName)
        AST.UNormalLhs fun <$> restArgs args

-- | Converts the right-hand side(s) of a binding or a match. A single alternative without
-- guards becomes an unguarded right-hand side, anything else a list of guarded ones.
trfRhss :: TransformName n r => [Located (GRHS n (LHsExpr n))] -> Trf (Ann AST.URhs (Dom r) RangeStage)
trfRhss = \case
  -- the original location on the GRHS misleadingly contains the local bindings,
  -- so the range is rebuilt from the body and the equals token before it
  [unLoc -> GRHS [] body] ->
    let bodyLoc = getLoc body
        rhsRange = combineSrcSpans bodyLoc <$> tokenBefore (srcSpanStart bodyLoc) AnnEqual
     in annLocNoSema rhsRange (AST.UUnguardedRhs <$> trfExpr body)
  alts ->
    annLocNoSema (pure (collectLocs alts))
                 (AST.UGuardedRhss . nonemptyAnnList <$> mapM trfGuardedRhs alts)
                      
-- | Converts one guarded right-hand side. The guards are converted with
-- 'trfScopedSequence' and are added to the scope of the body expression.
trfGuardedRhs :: TransformName n r => Located (GRHS n (LHsExpr n)) -> Trf (Ann AST.UGuardedRhs (Dom r) RangeStage)
trfGuardedRhs = trfLocNoSema $ \case
  GRHS guards body -> do
    conds <- nonemptyAnnList <$> trfScopedSequence trfRhsGuard guards
    AST.UGuardedRhs conds <$> addToScope guards (trfExpr body)
  
-- | Converts a located guard statement by running 'trfRhsGuard'' on its content through 'trfLocNoSema'.
trfRhsGuard :: TransformName n r => Located (Stmt n (LHsExpr n)) -> Trf (Ann AST.URhsGuard (Dom r) RangeStage)
trfRhsGuard guard = trfLocNoSema trfRhsGuard' guard
  
-- | Converts the content of a guard statement. Pattern guards (bind statements), boolean
-- checks (body statements) and @let@ guards are accepted; any other statement is reported
-- with an 'error' that shows the statement and its constructor.
trfRhsGuard' :: TransformName n r => Stmt n (LHsExpr n) -> Trf (AST.URhsGuard (Dom r) RangeStage)
trfRhsGuard' = \case
  BindStmt pat body _ _ _ -> do
    boundPat <- trfPattern pat
    AST.UGuardBind boundPat <$> trfExpr body
  BodyStmt cond _ _ _ -> fmap AST.UGuardCheck (trfExpr cond)
  LetStmt (unLoc -> binds) -> fmap AST.UGuardLet (trfLocalBinds binds)
  other -> error $ concat ["Illegal guard: ", showSDocUnsafe (ppr other), " (ctor: ", show (toConstr other), ")"]
  
-- | Converts the optional local bindings of a binding or a match. The first argument is the
-- range of the right-hand side(s) that precede the local bindings; only its end is used.
-- Empty local bindings result in a missing element located at the end of the current range.
trfWhereLocalBinds :: TransformName n r => SrcSpan -> HsLocalBinds n -> Trf (AnnMaybeG AST.ULocalBinds (Dom r) RangeStage)
trfWhereLocalBinds _ EmptyLocalBinds = nothing "" "" atTheEnd
trfWhereLocalBinds bef binds
  = fmap makeJust $ annLocNoSema whereRange (AST.ULocalBinds <$> addToScope binds (trfLocalBinds binds))
  where -- an empty range at the end of the preceding right-hand side(s)
        afterRhs = srcLocSpan (srcSpanEnd bef)
        bindsRange = afterRhs `combineSrcSpans` getBindLocs binds
        -- the range is extended with the location of the where token found by 'tokenLocBack'
        whereRange = combineSrcSpans bindsRange <$> tokenLocBack AnnWhere

-- | Folds the locations of all bindings and signatures in the local bindings into one range
-- using 'foldLocs'. Empty local bindings have no source range ('noSrcSpan').
getBindLocs :: HsLocalBinds n -> SrcSpan
getBindLocs = \case
  HsValBinds (ValBindsIn binds sigs) -> rangeOf (bagToList binds) sigs
  HsValBinds (ValBindsOut groups sigs) -> rangeOf (concatMap (bagToList . snd) groups) sigs
  HsIPBinds (IPBinds ipBinds _) -> foldLocs (map getLoc ipBinds)
  EmptyLocalBinds -> noSrcSpan
  where rangeOf binds sigs = foldLocs (map getLoc binds ++ map getLoc sigs)
  
-- | Converts local bindings (value bindings with their signatures, or implicit parameter
-- bindings) into an indented list of local definitions.
--
-- Value bindings and signatures are converted separately, concatenated and then passed
-- through 'orderDefs'. Both the renamed ('ValBindsIn') and the grouped ('ValBindsOut')
-- representations are handled. An empty binding group is converted to an empty list.
--
-- NOTE(review): the list is always positioned after the @where@ token, even when the
-- function is called for a @let@ guard — confirm that this is the intended anchor.
trfLocalBinds :: TransformName n r => HsLocalBinds n -> Trf (AnnListG AST.ULocalBind (Dom r) RangeStage)
trfLocalBinds (HsValBinds (ValBindsIn binds sigs)) 
  = makeIndentedListBefore " " (after AnnWhere)
      (orderDefs <$> ((++) <$> mapM (copyAnnot AST.ULocalValBind . trfBind) (bagToList binds) 
                           <*> mapM trfLocalSig sigs))
trfLocalBinds (HsValBinds (ValBindsOut binds sigs)) 
  = makeIndentedListBefore " " (after AnnWhere)
      (orderDefs <$> ((++) <$> (concat <$> mapM (mapM (copyAnnot AST.ULocalValBind . trfBind) . bagToList . snd) binds)
                           <*> mapM trfLocalSig sigs))
trfLocalBinds (HsIPBinds (IPBinds binds _))
  = makeIndentedListBefore " " (after AnnWhere) (mapM trfIpBind binds)
trfLocalBinds EmptyLocalBinds
  -- an empty binding group (for example @let {}@ in a guard) is valid source code,
  -- so it is represented by an empty list instead of aborting the conversion
  = makeIndentedListBefore " " (after AnnWhere) (pure [])

-- | Converts an implicit parameter binding into a local value binding: a simple binding of
-- a variable pattern (the implicit name) to an unguarded right-hand side without local
-- bindings. Only the form with a located implicit name ('Left') is supported; the other
-- form is rejected with an 'error'.
trfIpBind :: TransformName n r => Located (IPBind n) -> Trf (Ann AST.ULocalBind (Dom r) RangeStage)
trfIpBind = trfLocNoSema $ \case
  IPBind (Left (L nameLoc ipName)) value -> do
    bind <- annContNoSema $ do
      var <- focusOn nameLoc $ annContNoSema (AST.UVarPat <$> define (trfImplicitName ipName))
      rhs <- annFromNoSema AnnEqual (AST.UUnguardedRhs <$> trfExpr value)
      noLocals <- nothing " " "" atTheEnd
      return (AST.USimpleBind var rhs noLocals)
    return (AST.ULocalValBind bind)
  IPBind (Right _) _ -> error "trfIpBind: called on typechecked AST"
             
-- | Converts a signature that appears among local bindings. Type signatures, fixity
-- signatures and inline pragmas are accepted; any other signature is reported with an
-- 'error' that shows the signature and its constructor.
trfLocalSig :: TransformName n r => Located (Sig n) -> Trf (Ann AST.ULocalBind (Dom r) RangeStage)
trfLocalSig = trfLocNoSema $ \sig -> case sig of
  TypeSig {} -> fmap AST.ULocalSignature (annContNoSema (trfTypeSig' sig))
  FixSig fixity -> fmap AST.ULocalFixity (annContNoSema (trfFixitySig fixity))
  InlineSig name prag -> fmap AST.ULocalInline (trfInlinePragma name prag)
  _ -> error $ concat ["Illegal local signature: ", showSDocUnsafe (ppr sig), " (ctor: ", show (toConstr sig), ")"]
  
-- | Converts a located type signature by running 'trfTypeSig'' on its content through 'trfLocNoSema'.
trfTypeSig :: TransformName n r => Located (Sig n) -> Trf (Ann AST.UTypeSignature (Dom r) RangeStage)
trfTypeSig sig = trfLocNoSema trfTypeSig' sig

-- | Converts the content of a type signature: the comma-separated list of declared names
-- and the type, which is taken from inside the implicit-binder and wildcard wrappers.
-- The conversion runs inside 'defineTypeVars'. Any signature other than a 'TypeSig' is
-- reported with an 'error'.
trfTypeSig' :: TransformName n r => Sig n -> Trf (AST.UTypeSignature (Dom r) RangeStage)
trfTypeSig' = \case
  TypeSig names typ -> defineTypeVars $ do
    sigNames <- makeNonemptyList ", " (mapM trfName names)
    sigType <- trfType (hswc_body (hsib_body typ))
    return (AST.UTypeSignature sigNames sigType)
  ts -> error $ concat ["Illegal type signature: ", showSDocUnsafe (ppr ts), " (ctor: ", show (toConstr ts), ")"]

-- | Converts a fixity signature into its associativity, optional precedence and the list
-- of declared operators (duplicates are removed with 'nub').
--
-- The precedence is only present when a good source location was found for the 'AnnVal'
-- token; otherwise a missing element is placed at the start of the first operator.
trfFixitySig :: TransformName n r => FixitySig n -> Trf (AST.UFixitySignature (Dom r) RangeStage)
trfFixitySig (FixitySig names (Fixity _ prec dir)) = do
  precLoc <- tokenLoc AnnVal
  assoc <- case dir of
    InfixL -> lastCharOfKeyword AST.AssocLeft
    InfixR -> lastCharOfKeyword AST.AssocRight
    -- no associativity: an empty range at the end of the infix keyword
    InfixN -> annLocNoSema (srcLocSpan <$> keywordEnd) (pure AST.AssocNone)
  precedence <-
    if isGoodSrcSpan precLoc
      then makeJust <$> annLocNoSema (return precLoc) (pure (AST.Precedence prec))
      -- names cannot be empty
      else nothing "" " " (return (srcSpanStart (getLoc (head names))))
  operators <- nonemptyAnnList . nub <$> mapM trfOperator names
  return (AST.UFixitySignature assoc precedence operators)
  where
    keywordEnd = srcSpanEnd <$> tokenLoc AnnInfix
    -- the associativity is located on the last character of the infix keyword
    lastCharOfKeyword assoc = annLocNoSema (oneColumnBefore <$> keywordEnd) (pure assoc)
    oneColumnBefore end = mkSrcSpan (updateCol (subtract 1) end) end

-- | Converts an inline pragma that belongs to the given name.
--
-- * @INLINABLE@ pragmas keep their phase control.
-- * @NOINLINE@ pragmas keep only the name; the phase and the rule match information of
--   the GHC pragma are not used.
-- * @INLINE@ pragmas keep the conlike annotation and the phase control. The source text of
--   the pragma is split with 'splitLocated' to find the ranges used by 'trfConlike'.
--
-- A missing phase control is located at the start of the name. Any other inline
-- specification is reported with an 'error' that shows the pragma.
trfInlinePragma :: TransformName n r => Located n -> InlinePragma -> Trf (Ann AST.UInlinePragma (Dom r) RangeStage)
trfInlinePragma name (InlinePragma _ Inlinable _ phase _) 
  = annContNoSema (AST.UInlinablePragma <$> trfPhase (pure $ srcSpanStart $ getLoc name) phase <*> trfName name)
trfInlinePragma name (InlinePragma _ NoInline _ _ _) = annContNoSema (AST.UNoInlinePragma <$> trfName name)
trfInlinePragma name (InlinePragma src Inline _ phase cl) 
  = annContNoSema $ do rng <- asks contRange
                       let parts = map getLoc $ splitLocated (L rng src)
                       AST.UInlinePragma <$> trfConlike parts cl 
                                         <*> trfPhase (pure $ srcSpanStart (getLoc name)) phase 
                                         <*> trfName name
-- fail with a descriptive message instead of an anonymous pattern match failure
trfInlinePragma _ prag = error $ "Illegal inline pragma: " ++ showSDocUnsafe (ppr prag)

-- | Converts the activation of a pragma into an optional phase control. The first argument
-- is the location used for the missing element when the pragma is always active.
--
-- Otherwise the phase control ranges over the combined locations of the opening and closing
-- square bracket tokens. 'ActiveBefore' additionally gets a phase inversion located on the
-- tilde token; 'ActiveAfter' gets a missing inversion placed before the closing bracket.
-- 'NeverActive' is reported with an 'error' that shows the current range.
trfPhase :: Trf SrcLoc -> Activation -> Trf (AnnMaybeG AST.UPhaseControl (Dom r) RangeStage)
trfPhase noPhaseLoc = \case
  AlwaysActive -> nothing " " "" noPhaseLoc
  ActiveAfter _ pn -> phaseControl (nothing "" "" (before AnnCloseS)) pn
  ActiveBefore _ pn -> phaseControl (makeJust <$> annLocNoSema (tokenLoc AnnTilde) (pure AST.PhaseInvert)) pn
  NeverActive -> asks contRange >>= \range ->
    error $ "NeverActive pragmas should be checked earlier : " ++ show range
  where bracketRange = combineSrcSpans <$> tokenLoc AnnOpenS <*> tokenLoc AnnCloseS
        phaseControl inversion pn
          = makeJust <$> annLocNoSema bracketRange (AST.UPhaseControl <$> inversion <*> trfPhaseNum pn)

-- | Creates the phase number element from GHC's phase number, located on the 'AnnVal' token.
trfPhaseNum ::  PhaseNum -> Trf (Ann AST.PhaseNumber (Dom r) RangeStage)
trfPhaseNum = annLocNoSema (tokenLoc AnnVal) . pure . AST.PhaseNumber . fromIntegral

-- | Converts the rule match information of an inline pragma into an optional conlike
-- annotation. The first argument holds the source ranges of the parts of the pragma text
-- (as produced by the caller with 'splitLocated').
--
-- For 'ConLike' the annotation is located on the third part. For 'FunLike' a missing
-- element is placed at the end of the second part. When the required part does not exist,
-- the failure names the missing part instead of a bare list indexing error.
trfConlike :: [SrcSpan] -> RuleMatchInfo -> Trf (AnnMaybeG AST.UConlikeAnnot (Dom r) RangeStage)
trfConlike parts matchInfo = case matchInfo of
  ConLike -> makeJust <$> annLocNoSema (pure (pragmaPart 2)) (pure AST.UConlikeAnnot)
  FunLike -> nothing " " "" (pure (srcSpanEnd (pragmaPart 1)))
  where
    -- zero-based lookup; stays lazy, so it can only fail when the location is demanded
    pragmaPart :: Int -> SrcSpan
    pragmaPart i = case drop i parts of
      (part : _) -> part
      []         -> error $ "trfConlike: the inline pragma has only " ++ show (length parts)
                             ++ " source part(s), but part with index " ++ show i ++ " is needed"